mirror of https://github.com/ospab/ostp.git
feat: implement robust multiplexing, high-latency timeouts, and dynamic background reconnects for mobile network stability
This commit is contained in:
parent
3a4b5a8c63
commit
8a2af5d73d
Binary file not shown.
|
|
@ -16,6 +16,22 @@ use crate::app::{BridgeCommand, ConnectionStatus, UiEvent};
|
|||
use crate::config::ClientConfig;
|
||||
use crate::tunnel::{ProxyEvent, ProxyToClientMsg};
|
||||
|
||||
static SOCKET_PROTECTOR: std::sync::OnceLock<Box<dyn Fn(i32) -> bool + Send + Sync>> = std::sync::OnceLock::new();
|
||||
|
||||
pub fn set_socket_protector<F>(f: F)
|
||||
where
|
||||
F: Fn(i32) -> bool + Send + Sync + 'static,
|
||||
{
|
||||
let _ = SOCKET_PROTECTOR.set(Box::new(f));
|
||||
}
|
||||
|
||||
pub fn protect_socket(fd: i32) -> bool {
|
||||
if let Some(f) = SOCKET_PROTECTOR.get() {
|
||||
return f(fd);
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
pub struct BridgeMetrics {
|
||||
pub bytes_sent: AtomicU64,
|
||||
pub bytes_recv: AtomicU64,
|
||||
|
|
@ -238,12 +254,13 @@ impl Bridge {
|
|||
let (udp_tx, udp_rx) = mpsc::channel(100000); // Increased for high-speed traffic stability
|
||||
let mut sessions = Vec::with_capacity(session_count);
|
||||
let mut rtt_sum = 0.0;
|
||||
let mut successful_sessions = 0;
|
||||
|
||||
let mut handshake_error = None;
|
||||
for idx in 0..session_count {
|
||||
let session_id: u32 = rand::thread_rng().gen();
|
||||
match self.perform_handshake_with_id(&tx, session_id).await {
|
||||
Ok((sock, mach, rtt)) => {
|
||||
let session_index = sessions.len();
|
||||
let socket = Arc::new(sock);
|
||||
let socket_clone = socket.clone();
|
||||
let udp_tx_clone = udp_tx.clone();
|
||||
|
|
@ -263,12 +280,16 @@ impl Bridge {
|
|||
} else {
|
||||
Bytes::copy_from_slice(&buf[..n])
|
||||
};
|
||||
if udp_tx_clone.send((idx, inbound)).await.is_err() {
|
||||
if udp_tx_clone.send((session_index, inbound)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
break;
|
||||
Err(e) => {
|
||||
// Under Windows/Winsock, transient UDP socket errors (like WSAECONNRESET) are returned
|
||||
// as Err(ConnectionReset). We MUST NOT break the loop on transient errors, otherwise the
|
||||
// download path will be permanently killed while the upload path keeps running.
|
||||
tracing::warn!("UDP socket recv error (session {}): {}", session_index, e);
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -276,17 +297,17 @@ impl Bridge {
|
|||
|
||||
sessions.push(SessionState { socket, machine: mach });
|
||||
rtt_sum += rtt;
|
||||
successful_sessions += 1;
|
||||
}
|
||||
Err(err) => {
|
||||
handshake_error = Some(err);
|
||||
break;
|
||||
tx.send(UiEvent::Log(format!("Multiplex session {}/{} handshake failed: {}. Continuing with remaining sessions...", idx + 1, session_count, err))).await.ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(err) = handshake_error {
|
||||
if sessions.is_empty() {
|
||||
_proxy_guard = None;
|
||||
tx.send(UiEvent::Log(format!("Connection failed: {err}"))).await.ok();
|
||||
tx.send(UiEvent::Log("All multiplexed handshake attempts failed. Connection aborted.".to_string())).await.ok();
|
||||
tx.send(UiEvent::TunnelStopped).await.ok();
|
||||
self.metrics.connection_state.store(0, Ordering::Relaxed);
|
||||
continue;
|
||||
|
|
@ -294,7 +315,7 @@ impl Bridge {
|
|||
|
||||
udp_rx_opt = Some(udp_rx);
|
||||
sessions_opt = Some(sessions);
|
||||
self.last_rtt_ms = rtt_sum / session_count as f64;
|
||||
self.last_rtt_ms = rtt_sum / successful_sessions as f64;
|
||||
self.running = true;
|
||||
self.last_sample_at = Instant::now();
|
||||
self.last_valid_recv = Instant::now();
|
||||
|
|
@ -352,19 +373,89 @@ impl Bridge {
|
|||
}
|
||||
_ = keepalive_tick.tick() => {
|
||||
if self.running {
|
||||
// 1. Connection Liveness Check
|
||||
if self.last_valid_recv.elapsed().as_secs() > 60 {
|
||||
let _ = tx.send(UiEvent::Log("Connection lost (timeout). Reconnecting...".into())).await;
|
||||
// 1. Connection Liveness Check & Silent Background Reconnect
|
||||
if self.last_valid_recv.elapsed().as_secs() > 25 {
|
||||
let elapsed = self.last_valid_recv.elapsed().as_secs();
|
||||
if elapsed > 180 {
|
||||
// Hard timeout after 3 minutes of total silence
|
||||
let _ = tx.send(UiEvent::Log("Connection permanently lost (3-minute hard timeout). Stopping tunnel.".into())).await;
|
||||
self.running = false;
|
||||
_proxy_guard = None;
|
||||
sessions_opt = None;
|
||||
stream_map.clear();
|
||||
self.reset_proxy_streams(&tx, &proxy_tx, "keepalive timeout");
|
||||
self.reset_proxy_streams(&tx, &proxy_tx, "keepalive hard timeout");
|
||||
let _ = tx.send(UiEvent::TunnelStopped).await;
|
||||
self.metrics.connection_state.store(0, Ordering::Relaxed);
|
||||
continue;
|
||||
}
|
||||
|
||||
let _ = tx.send(UiEvent::Log(format!("Connection stall detected ({}s silence). Attempting background reconnect...", elapsed))).await;
|
||||
self.metrics.connection_state.store(1, Ordering::Relaxed); // State: Connecting (Handshake)
|
||||
|
||||
let session_count = if self.mux_enabled { self.mux_sessions.max(1) } else { 1 };
|
||||
let (udp_tx, udp_rx) = mpsc::channel(100000);
|
||||
let mut new_sessions = Vec::with_capacity(session_count);
|
||||
let mut successful_sessions = 0;
|
||||
let mut rtt_sum = 0.0;
|
||||
|
||||
for idx in 0..session_count {
|
||||
let session_id: u32 = rand::thread_rng().gen();
|
||||
match self.perform_handshake_with_id(&tx, session_id).await {
|
||||
Ok((sock, mach, rtt)) => {
|
||||
let session_index = new_sessions.len();
|
||||
let socket = Arc::new(sock);
|
||||
let socket_clone = socket.clone();
|
||||
let udp_tx_clone = udp_tx.clone();
|
||||
let is_turn = self.turn_enabled;
|
||||
tokio::spawn(async move {
|
||||
let mut buf = vec![0_u8; 65535];
|
||||
loop {
|
||||
match socket_clone.recv(&mut buf).await {
|
||||
Ok(n) => {
|
||||
let inbound = if is_turn && n >= 4 && buf[0] == 0x40 && buf[1] == 0x00 {
|
||||
let len = u16::from_be_bytes([buf[2], buf[3]]) as usize;
|
||||
if 4 + len <= n {
|
||||
Bytes::copy_from_slice(&buf[4..4+len])
|
||||
} else {
|
||||
Bytes::copy_from_slice(&buf[..n])
|
||||
}
|
||||
} else {
|
||||
Bytes::copy_from_slice(&buf[..n])
|
||||
};
|
||||
if udp_tx_clone.send((session_index, inbound)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("UDP socket recv error (reconnect session {}): {}", session_index, e);
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
new_sessions.push(SessionState { socket, machine: mach });
|
||||
rtt_sum += rtt;
|
||||
successful_sessions += 1;
|
||||
}
|
||||
Err(err) => {
|
||||
let _ = tx.send(UiEvent::Log(format!("Background reconnect session {}/{} failed: {}", idx + 1, session_count, err))).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !new_sessions.is_empty() {
|
||||
sessions_opt = Some(new_sessions);
|
||||
udp_rx_opt = Some(udp_rx);
|
||||
self.last_rtt_ms = rtt_sum / successful_sessions as f64;
|
||||
self.last_valid_recv = Instant::now();
|
||||
self.metrics.connection_state.store(2, Ordering::Relaxed); // State: Connected
|
||||
let _ = tx.send(UiEvent::Log("Background reconnect successful! Connection restored.".into())).await;
|
||||
} else {
|
||||
let _ = tx.send(UiEvent::Log("Background reconnect failed. Will retry on next tick...".into())).await;
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Active Keep-Alive / Heartbeat
|
||||
if let Some(sessions) = sessions_opt.as_mut() {
|
||||
for session in sessions.iter_mut() {
|
||||
|
|
@ -612,17 +703,33 @@ impl Bridge {
|
|||
handshake_pad_max: secrets.handshake_pad_max,
|
||||
})?;
|
||||
|
||||
let addr = self.local_bind_addr.parse::<std::net::SocketAddr>().map_err(|e| anyhow::anyhow!("invalid bind addr: {}", e))?;
|
||||
let domain = if addr.is_ipv6() { socket2::Domain::IPV6 } else { socket2::Domain::IPV4 };
|
||||
let sock = socket2::Socket::new(domain, socket2::Type::DGRAM, Some(socket2::Protocol::UDP))?;
|
||||
let _ = sock.set_recv_buffer_size(33554432); // 32MB
|
||||
let _ = sock.set_send_buffer_size(33554432); // 32MB
|
||||
let actual_recv = sock.recv_buffer_size().unwrap_or(0);
|
||||
let actual_send = sock.send_buffer_size().unwrap_or(0);
|
||||
tracing::info!("UDP socket buffers: recv={}KB send={}KB", actual_recv / 1024, actual_send / 1024);
|
||||
sock.bind(&addr.into())?;
|
||||
sock.set_nonblocking(true)?;
|
||||
let socket = UdpSocket::from_std(sock.into())?;
|
||||
let resolved_addrs: Vec<std::net::SocketAddr> = match tokio::net::lookup_host(&self.server_addr).await {
|
||||
Ok(addrs) => addrs.collect(),
|
||||
Err(e) => return Err(anyhow::anyhow!("failed to resolve server address {}: {}", self.server_addr, e)),
|
||||
};
|
||||
let target_addr = resolved_addrs.first().ok_or_else(|| anyhow::anyhow!("no IP addresses resolved for {}", self.server_addr))?;
|
||||
let target_ip = target_addr.ip();
|
||||
let port = target_addr.port();
|
||||
|
||||
tx.send(UiEvent::Log(format!("Connecting to remote server: {}...", target_addr))).await.ok();
|
||||
|
||||
let socket = match self.try_connect_socket(target_ip, port).await {
|
||||
Ok(sock) => sock,
|
||||
Err(e) => {
|
||||
if let std::net::IpAddr::V4(ipv4) = target_ip {
|
||||
tx.send(UiEvent::Log(format!("Direct IPv4 connection failed: {}. Trying NAT64 fallback...", e))).await.ok();
|
||||
let nat64_ipv6 = synthesize_nat64(ipv4);
|
||||
match self.try_connect_socket(std::net::IpAddr::V6(nat64_ipv6), port).await {
|
||||
Ok(sock) => sock,
|
||||
Err(fallback_err) => {
|
||||
return Err(anyhow::anyhow!("Direct IPv4 failed: {}. NAT64 fallback failed: {}", e, fallback_err));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if self.turn_enabled {
|
||||
let turn_addr = if self.turn_server.contains(':') {
|
||||
|
|
@ -635,27 +742,46 @@ impl Bridge {
|
|||
match crate::turn::perform_turn_allocation(&socket, &turn_addr, &self.turn_username, &self.turn_password, &self.server_addr).await {
|
||||
Ok(relay_addr) => {
|
||||
tx.send(UiEvent::Log(format!("TURN relay allocated ({})", relay_addr))).await.ok();
|
||||
// Re-connect the UDP socket to the TURN server so all sends go through it.
|
||||
// The TURN server forwards ChannelData to the OSTP server transparently.
|
||||
|
||||
let resolved_turn: Vec<std::net::SocketAddr> = tokio::net::lookup_host(&turn_addr).await
|
||||
.map_err(|e| anyhow::anyhow!("failed to resolve TURN {}: {}", turn_addr, e))?
|
||||
.collect();
|
||||
let turn_target = resolved_turn.first().ok_or_else(|| anyhow::anyhow!("no IP resolved for TURN {}", turn_addr))?;
|
||||
|
||||
let connect_ip = if socket.local_addr().map(|a| a.is_ipv6()).unwrap_or(false) && turn_target.is_ipv4() {
|
||||
if let std::net::IpAddr::V4(ipv4) = turn_target.ip() {
|
||||
std::net::IpAddr::V6(synthesize_nat64(ipv4))
|
||||
} else {
|
||||
turn_target.ip()
|
||||
}
|
||||
} else {
|
||||
turn_target.ip()
|
||||
};
|
||||
|
||||
let connect_addr = std::net::SocketAddr::new(connect_ip, turn_target.port());
|
||||
socket
|
||||
.connect(&turn_addr)
|
||||
.connect(connect_addr)
|
||||
.await
|
||||
.with_context(|| format!("failed to re-connect to TURN {}", turn_addr))?;
|
||||
.with_context(|| format!("failed to re-connect to TURN {}", connect_addr))?;
|
||||
}
|
||||
Err(e) => {
|
||||
tx.send(UiEvent::Log(format!("TURN allocation failed: {}. Using direct UDP.", e))).await.ok();
|
||||
socket
|
||||
.connect(&self.server_addr)
|
||||
.await
|
||||
.with_context(|| format!("failed to connect udp to {}", self.server_addr))?;
|
||||
}
|
||||
let connect_ip = if socket.local_addr().map(|a| a.is_ipv6()).unwrap_or(false) && target_ip.is_ipv4() {
|
||||
if let std::net::IpAddr::V4(ipv4) = target_ip {
|
||||
std::net::IpAddr::V6(synthesize_nat64(ipv4))
|
||||
} else {
|
||||
target_ip
|
||||
}
|
||||
} else {
|
||||
tx.send(UiEvent::Log(format!("Connected to {}", self.server_addr))).await.ok();
|
||||
target_ip
|
||||
};
|
||||
let connect_addr = std::net::SocketAddr::new(connect_ip, port);
|
||||
socket
|
||||
.connect(&self.server_addr)
|
||||
.connect(connect_addr)
|
||||
.await
|
||||
.with_context(|| format!("failed to connect udp to {}", self.server_addr))?;
|
||||
.with_context(|| format!("failed to connect udp to {}", connect_addr))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Connection to remote is handled inside the TURN/direct branches above
|
||||
|
|
@ -666,16 +792,65 @@ impl Bridge {
|
|||
ProtocolAction::SendDatagram(frame) => frame,
|
||||
_ => anyhow::bail!("protocol did not emit handshake datagram"),
|
||||
};
|
||||
let mut buf = vec![0_u8; 4096];
|
||||
let mut size = 0;
|
||||
let mut success = false;
|
||||
|
||||
// Retransmit handshake up to 4 times with 1200ms timeout to survive packet loss on mobile
|
||||
for attempt in 0..4 {
|
||||
if attempt > 0 {
|
||||
tx.send(UiEvent::Log(format!("Handshake attempt {} lost. Retransmitting...", attempt))).await.ok();
|
||||
}
|
||||
send_datagram(&socket, &handshake_frame, self.turn_enabled).await?;
|
||||
self.metrics.bytes_sent.fetch_add(handshake_frame.len() as u64, Ordering::Relaxed);
|
||||
|
||||
let mut buf = vec![0_u8; 4096];
|
||||
let size = timeout(
|
||||
Duration::from_millis(self.handshake_timeout_ms.max(1)),
|
||||
socket.recv(&mut buf),
|
||||
)
|
||||
.await
|
||||
.context("handshake timeout waiting server response")??;
|
||||
match timeout(Duration::from_millis(1200), socket.recv(&mut buf)).await {
|
||||
Ok(Ok(n)) => {
|
||||
size = n;
|
||||
success = true;
|
||||
break;
|
||||
}
|
||||
_ => {} // retry on timeout or error
|
||||
}
|
||||
}
|
||||
|
||||
let (final_socket, size) = if success {
|
||||
(socket, size)
|
||||
} else {
|
||||
if let std::net::IpAddr::V4(ipv4) = target_ip {
|
||||
tx.send(UiEvent::Log("Direct IPv4 handshake timed out. Trying NAT64 fallback...".to_string())).await.ok();
|
||||
let nat64_ipv6 = synthesize_nat64(ipv4);
|
||||
match self.try_connect_socket(std::net::IpAddr::V6(nat64_ipv6), port).await {
|
||||
Ok(fallback_socket) => {
|
||||
let mut fallback_success = false;
|
||||
for attempt in 0..4 {
|
||||
if attempt > 0 {
|
||||
tx.send(UiEvent::Log(format!("NAT64 handshake attempt {} lost. Retransmitting...", attempt))).await.ok();
|
||||
}
|
||||
send_datagram(&fallback_socket, &handshake_frame, self.turn_enabled).await?;
|
||||
match timeout(Duration::from_millis(1200), fallback_socket.recv(&mut buf)).await {
|
||||
Ok(Ok(n)) => {
|
||||
size = n;
|
||||
fallback_success = true;
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
if fallback_success {
|
||||
tx.send(UiEvent::Log("NAT64 fallback handshake successful!".to_string())).await.ok();
|
||||
(fallback_socket, size)
|
||||
} else {
|
||||
return Err(anyhow::anyhow!("NAT64 handshake failed after 3 attempts"));
|
||||
}
|
||||
}
|
||||
Err(e) => return Err(anyhow::anyhow!("NAT64 fallback socket creation failed: {}", e)),
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow::anyhow!("Direct handshake failed after 3 attempts"));
|
||||
}
|
||||
};
|
||||
let socket = final_socket;
|
||||
self.metrics.bytes_recv.fetch_add(size as u64, Ordering::Relaxed);
|
||||
tracing::info!("Handshake response received: {} bytes", size);
|
||||
|
||||
|
|
@ -711,6 +886,39 @@ impl Bridge {
|
|||
self.mux_enabled = cfg.multiplex.enabled;
|
||||
self.mux_sessions = cfg.multiplex.sessions.max(1);
|
||||
}
|
||||
|
||||
async fn try_connect_socket(
|
||||
&self,
|
||||
target_ip: std::net::IpAddr,
|
||||
port: u16,
|
||||
) -> Result<UdpSocket> {
|
||||
let is_ipv6 = target_ip.is_ipv6();
|
||||
let domain = if is_ipv6 { socket2::Domain::IPV6 } else { socket2::Domain::IPV4 };
|
||||
let bind_addr = if is_ipv6 {
|
||||
std::net::SocketAddr::new(std::net::IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED), 0)
|
||||
} else {
|
||||
std::net::SocketAddr::new(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED), 0)
|
||||
};
|
||||
|
||||
let sock = socket2::Socket::new(domain, socket2::Type::DGRAM, Some(socket2::Protocol::UDP))?;
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::io::AsRawFd;
|
||||
protect_socket(sock.as_raw_fd());
|
||||
}
|
||||
let _ = sock.set_recv_buffer_size(33554432); // 32MB
|
||||
let _ = sock.set_send_buffer_size(33554432); // 32MB
|
||||
let actual_recv = sock.recv_buffer_size().unwrap_or(0);
|
||||
let actual_send = sock.send_buffer_size().unwrap_or(0);
|
||||
tracing::info!("UDP socket buffers: recv={}KB send={}KB", actual_recv / 1024, actual_send / 1024);
|
||||
sock.bind(&bind_addr.into())?;
|
||||
sock.set_nonblocking(true)?;
|
||||
let socket = UdpSocket::from_std(sock.into())?;
|
||||
|
||||
let connect_addr = std::net::SocketAddr::new(target_ip, port);
|
||||
socket.connect(connect_addr).await.with_context(|| format!("failed to connect udp to {}", connect_addr))?;
|
||||
Ok(socket)
|
||||
}
|
||||
}
|
||||
|
||||
fn next_profile(current: TrafficProfile) -> TrafficProfile {
|
||||
|
|
@ -721,3 +929,12 @@ fn next_profile(current: TrafficProfile) -> TrafficProfile {
|
|||
}
|
||||
}
|
||||
|
||||
fn synthesize_nat64(ip: std::net::Ipv4Addr) -> std::net::Ipv6Addr {
|
||||
let octets = ip.octets();
|
||||
std::net::Ipv6Addr::new(
|
||||
0x0064, 0xff9b, 0, 0, 0, 0,
|
||||
((octets[0] as u16) << 8) | octets[1] as u16,
|
||||
((octets[2] as u16) << 8) | octets[3] as u16,
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -123,6 +123,21 @@ fn extract_host_port(uri: &str, default_port: u16) -> String {
|
|||
}
|
||||
}
|
||||
|
||||
struct StreamGuard {
|
||||
stream_id: u16,
|
||||
close_tx: mpsc::Sender<u16>,
|
||||
}
|
||||
|
||||
impl Drop for StreamGuard {
|
||||
fn drop(&mut self) {
|
||||
let tx = self.close_tx.clone();
|
||||
let id = self.stream_id;
|
||||
tokio::spawn(async move {
|
||||
let _ = tx.send(id).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_proxy_client(
|
||||
mut client: TcpStream,
|
||||
stream_id: u16,
|
||||
|
|
@ -133,6 +148,8 @@ async fn handle_proxy_client(
|
|||
debug: bool,
|
||||
matcher: ExclusionMatcher,
|
||||
) -> Result<()> {
|
||||
let _guard = StreamGuard { stream_id, close_tx: close_tx.clone() };
|
||||
|
||||
// Peek the first byte to distinguish SOCKS5 (0x05) from HTTP (any printable ASCII)
|
||||
let mut first_byte = [0_u8; 1];
|
||||
client.read_exact(&mut first_byte).await?;
|
||||
|
|
|
|||
|
|
@ -87,7 +87,10 @@ pub async fn run_wintun_tunnel(
|
|||
let setup_script = format!(
|
||||
"$remote_ip = '{}'\n\
|
||||
$exe_path = '{}'\n\
|
||||
$route = Find-NetRoute -RemoteIPAddress $remote_ip -ErrorAction SilentlyContinue | Select-Object -First 1\n\
|
||||
if (-not $route) {{\n\
|
||||
$route = Get-NetRoute -DestinationPrefix '0.0.0.0/0' | Where-Object {{ $_.InterfaceAlias -notmatch 'tun' -and $_.InterfaceAlias -notmatch 'wintun' }} | Sort-Object RouteMetric | Select-Object -First 1\n\
|
||||
}}\n\
|
||||
$gw = $route.NextHop\n\
|
||||
$ifIndex = $route.InterfaceIndex\n\
|
||||
New-NetRoute -DestinationPrefix \"$remote_ip/32\" -NextHop $gw -InterfaceIndex $ifIndex -RouteMetric 1 -ErrorAction SilentlyContinue\n\
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@ struct SdkState {
|
|||
runtime: Option<Runtime>,
|
||||
shutdown_tx: Option<watch::Sender<bool>>,
|
||||
metrics: Option<Arc<BridgeMetrics>>,
|
||||
tun_child: Option<std::process::Child>,
|
||||
cmd_tx: Option<mpsc::Sender<BridgeCommand>>,
|
||||
}
|
||||
|
||||
lazy_static! {
|
||||
|
|
@ -22,8 +24,12 @@ lazy_static! {
|
|||
runtime: None,
|
||||
shutdown_tx: None,
|
||||
metrics: None,
|
||||
tun_child: None,
|
||||
cmd_tx: None,
|
||||
});
|
||||
static ref LOGS: Mutex<VecDeque<String>> = Mutex::new(VecDeque::new());
|
||||
static ref JVM: Mutex<Option<jni::JavaVM>> = Mutex::new(None);
|
||||
static ref CLASS_REF: Mutex<Option<jni::objects::GlobalRef>> = Mutex::new(None);
|
||||
}
|
||||
|
||||
fn add_log(text: String) {
|
||||
|
|
@ -40,6 +46,9 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_startClient(
|
|||
mut env: JNIEnv,
|
||||
_class: JClass,
|
||||
config_json: JString,
|
||||
fd: jni::sys::jint,
|
||||
t2s_bin_path: JString,
|
||||
local_proxy: JString,
|
||||
) -> jboolean {
|
||||
let mut state = match STATE.lock() {
|
||||
Ok(s) => s,
|
||||
|
|
@ -51,11 +60,61 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_startClient(
|
|||
return jni::sys::JNI_TRUE;
|
||||
}
|
||||
|
||||
if let Ok(jvm) = env.get_java_vm() {
|
||||
if let Ok(mut guard) = JVM.lock() {
|
||||
*guard = Some(jvm);
|
||||
}
|
||||
}
|
||||
|
||||
if let Ok(cls) = env.find_class("net/ostp/client/OstpClientSdk") {
|
||||
if let Ok(global_cls) = env.new_global_ref(cls) {
|
||||
if let Ok(mut guard) = CLASS_REF.lock() {
|
||||
*guard = Some(global_cls);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ostp_client::bridge::set_socket_protector(|fd| {
|
||||
let jvm_guard = match JVM.lock() {
|
||||
Ok(g) => g,
|
||||
Err(_) => return false,
|
||||
};
|
||||
let class_guard = match CLASS_REF.lock() {
|
||||
Ok(g) => g,
|
||||
Err(_) => return false,
|
||||
};
|
||||
if let (Some(ref jvm), Some(ref class_ref)) = (&*jvm_guard, &*class_guard) {
|
||||
if let Ok(mut env) = jvm.attach_current_thread() {
|
||||
let class_obj = unsafe { jni::objects::JClass::from_raw(class_ref.as_obj().as_raw()) };
|
||||
let val = env.call_static_method(
|
||||
&class_obj,
|
||||
"protectSocket",
|
||||
"(I)Z",
|
||||
&[jni::objects::JValue::from(fd)],
|
||||
);
|
||||
if let Ok(jval) = val {
|
||||
return jval.z().unwrap_or(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
});
|
||||
|
||||
let config_str: String = match env.get_string(&config_json) {
|
||||
Ok(s) => s.into(),
|
||||
Err(_) => return jni::sys::JNI_FALSE,
|
||||
};
|
||||
|
||||
let t2s_path: String = match env.get_string(&t2s_bin_path) {
|
||||
Ok(s) => s.into(),
|
||||
Err(_) => return jni::sys::JNI_FALSE,
|
||||
};
|
||||
|
||||
let proxy_addr: String = match env.get_string(&local_proxy) {
|
||||
Ok(s) => s.into(),
|
||||
Err(_) => return jni::sys::JNI_FALSE,
|
||||
};
|
||||
|
||||
// Parse config from JSON
|
||||
let config: ClientConfig = match serde_json::from_str(&config_str) {
|
||||
Ok(cfg) => cfg,
|
||||
|
|
@ -65,6 +124,8 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_startClient(
|
|||
}
|
||||
};
|
||||
|
||||
let debug = config.debug;
|
||||
|
||||
// Create tokio runtime
|
||||
let rt = match Runtime::new() {
|
||||
Ok(r) => r,
|
||||
|
|
@ -135,9 +196,64 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_startClient(
|
|||
let _ = cmd_tx_clone.send(BridgeCommand::ToggleTunnel).await;
|
||||
});
|
||||
|
||||
// Spawn tun2socks
|
||||
let fd_str = format!("fd://{}", fd);
|
||||
let proxy_str = format!("socks5://{}", proxy_addr);
|
||||
|
||||
if debug {
|
||||
add_log(format!("Spawning tun2socks: {} -device {} -proxy {}", t2s_path, fd_str, proxy_str));
|
||||
}
|
||||
|
||||
let mut cmd = std::process::Command::new(&t2s_path);
|
||||
cmd.arg("-device")
|
||||
.arg(&fd_str)
|
||||
.arg("-proxy")
|
||||
.arg(&proxy_str)
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.stderr(std::process::Stdio::piped());
|
||||
|
||||
let mut child = match cmd.spawn() {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
add_log(format!("Failed to spawn tun2socks from Rust: {e}"));
|
||||
return jni::sys::JNI_FALSE;
|
||||
}
|
||||
};
|
||||
|
||||
let stdout = child.stdout.take().unwrap();
|
||||
let stderr = child.stderr.take().unwrap();
|
||||
|
||||
// Read stdout
|
||||
std::thread::spawn(move || {
|
||||
use std::io::{BufRead, BufReader};
|
||||
let reader = BufReader::new(stdout);
|
||||
for line in reader.lines() {
|
||||
if let Ok(l) = line {
|
||||
if debug {
|
||||
add_log(format!("tun2socks: {}", l));
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Read stderr & wait
|
||||
std::thread::spawn(move || {
|
||||
use std::io::{BufRead, BufReader};
|
||||
let reader = BufReader::new(stderr);
|
||||
for line in reader.lines() {
|
||||
if let Ok(l) = line {
|
||||
if debug {
|
||||
add_log(format!("tun2socks ERROR: {}", l));
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
state.runtime = Some(rt);
|
||||
state.shutdown_tx = Some(shutdown_tx);
|
||||
state.metrics = Some(metrics_clone);
|
||||
state.tun_child = Some(child);
|
||||
state.cmd_tx = Some(cmd_tx);
|
||||
|
||||
add_log("OSTP SDK: Client successfully started".to_string());
|
||||
jni::sys::JNI_TRUE
|
||||
|
|
@ -153,6 +269,11 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_stopClient(
|
|||
Err(_) => return jni::sys::JNI_FALSE,
|
||||
};
|
||||
|
||||
if let Some(mut child) = state.tun_child.take() {
|
||||
let _ = child.kill();
|
||||
add_log("Killed tun2socks process".to_string());
|
||||
}
|
||||
|
||||
if let Some(shutdown_tx) = state.shutdown_tx.take() {
|
||||
let _ = shutdown_tx.send(true);
|
||||
}
|
||||
|
|
@ -161,6 +282,7 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_stopClient(
|
|||
rt.shutdown_timeout(std::time::Duration::from_secs(3));
|
||||
}
|
||||
|
||||
state.cmd_tx = None;
|
||||
state.metrics = None;
|
||||
add_log("OSTP SDK: Client successfully stopped".to_string());
|
||||
jni::sys::JNI_TRUE
|
||||
|
|
@ -182,13 +304,17 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_getMetrics(
|
|||
if let Some(m) = &state.metrics {
|
||||
let sent = m.bytes_sent.load(Ordering::Relaxed);
|
||||
let recv = m.bytes_recv.load(Ordering::Relaxed);
|
||||
let json = format!(r#"{{"bytes_sent": {}, "bytes_recv": {}}}"#, sent, recv);
|
||||
let conn_state = m.connection_state.load(Ordering::Relaxed);
|
||||
let json = format!(
|
||||
r#"{{"bytes_sent": {}, "bytes_recv": {}, "connection_state": {}}}"#,
|
||||
sent, recv, conn_state
|
||||
);
|
||||
match env.new_string(json) {
|
||||
Ok(s) => s.into_raw(),
|
||||
Err(_) => std::ptr::null_mut(),
|
||||
}
|
||||
} else {
|
||||
match env.new_string(r#"{"bytes_sent": 0, "bytes_recv": 0}"#) {
|
||||
match env.new_string(r#"{"bytes_sent": 0, "bytes_recv": 0, "connection_state": 0}"#) {
|
||||
Ok(s) => s.into_raw(),
|
||||
Err(_) => std::ptr::null_mut(),
|
||||
}
|
||||
|
|
@ -215,3 +341,15 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_getLogs(
|
|||
Err(_) => std::ptr::null_mut(),
|
||||
}
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "system" fn Java_net_ostp_client_OstpClientSdk_addLog(
|
||||
mut env: JNIEnv,
|
||||
_class: JClass,
|
||||
log_msg: JString,
|
||||
) {
|
||||
if let Ok(s) = env.get_string(&log_msg) {
|
||||
let text: String = s.into();
|
||||
add_log(text);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue