diff --git a/.gitignore b/.gitignore index e963be2..b780b72 100644 Binary files a/.gitignore and b/.gitignore differ diff --git a/ostp-client/src/bridge.rs b/ostp-client/src/bridge.rs index 3de949e..0a208cf 100644 --- a/ostp-client/src/bridge.rs +++ b/ostp-client/src/bridge.rs @@ -16,6 +16,22 @@ use crate::app::{BridgeCommand, ConnectionStatus, UiEvent}; use crate::config::ClientConfig; use crate::tunnel::{ProxyEvent, ProxyToClientMsg}; +static SOCKET_PROTECTOR: std::sync::OnceLock bool + Send + Sync>> = std::sync::OnceLock::new(); + +pub fn set_socket_protector(f: F) +where + F: Fn(i32) -> bool + Send + Sync + 'static, +{ + let _ = SOCKET_PROTECTOR.set(Box::new(f)); +} + +pub fn protect_socket(fd: i32) -> bool { + if let Some(f) = SOCKET_PROTECTOR.get() { + return f(fd); + } + true +} + pub struct BridgeMetrics { pub bytes_sent: AtomicU64, pub bytes_recv: AtomicU64, @@ -238,12 +254,13 @@ impl Bridge { let (udp_tx, udp_rx) = mpsc::channel(100000); // Increased for high-speed traffic stability let mut sessions = Vec::with_capacity(session_count); let mut rtt_sum = 0.0; + let mut successful_sessions = 0; - let mut handshake_error = None; for idx in 0..session_count { let session_id: u32 = rand::thread_rng().gen(); match self.perform_handshake_with_id(&tx, session_id).await { Ok((sock, mach, rtt)) => { + let session_index = sessions.len(); let socket = Arc::new(sock); let socket_clone = socket.clone(); let udp_tx_clone = udp_tx.clone(); @@ -251,42 +268,46 @@ impl Bridge { tokio::spawn(async move { let mut buf = vec![0_u8; 65535]; loop { - match socket_clone.recv(&mut buf).await { - Ok(n) => { - let inbound = if is_turn && n >= 4 && buf[0] == 0x40 && buf[1] == 0x00 { - let len = u16::from_be_bytes([buf[2], buf[3]]) as usize; - if 4 + len <= n { - Bytes::copy_from_slice(&buf[4..4+len]) - } else { - Bytes::copy_from_slice(&buf[..n]) - } - } else { - Bytes::copy_from_slice(&buf[..n]) - }; - if udp_tx_clone.send((idx, inbound)).await.is_err() { - break; - } - } - Err(_) => { - break; - } - } + match socket_clone.recv(&mut buf).await { + Ok(n) => { + let inbound = if is_turn && n >= 4 && buf[0] == 0x40 && buf[1] == 0x00 { + let len = u16::from_be_bytes([buf[2], buf[3]]) as usize; + if 4 + len <= n { + Bytes::copy_from_slice(&buf[4..4+len]) + } else { + Bytes::copy_from_slice(&buf[..n]) + } + } else { + Bytes::copy_from_slice(&buf[..n]) + }; + if udp_tx_clone.send((session_index, inbound)).await.is_err() { + break; + } + } + Err(e) => { + // Under Windows/Winsock, transient UDP socket errors (like WSAECONNRESET) are returned + // as Err(ConnectionReset). We MUST NOT break the loop on transient errors, otherwise the + // download path will be permanently killed while the upload path keeps running. + tracing::warn!("UDP socket recv error (session {}): {}", session_index, e); + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + } + } } }); sessions.push(SessionState { socket, machine: mach }); rtt_sum += rtt; + successful_sessions += 1; } Err(err) => { - handshake_error = Some(err); - break; + tx.send(UiEvent::Log(format!("Multiplex session {}/{} handshake failed: {}. Continuing with remaining sessions...", idx + 1, session_count, err))).await.ok(); } } } - if let Some(err) = handshake_error { + if sessions.is_empty() { _proxy_guard = None; - tx.send(UiEvent::Log(format!("Connection failed: {err}"))).await.ok(); + tx.send(UiEvent::Log("All multiplexed handshake attempts failed. Connection aborted.".to_string())).await.ok(); tx.send(UiEvent::TunnelStopped).await.ok(); self.metrics.connection_state.store(0, Ordering::Relaxed); continue; @@ -294,7 +315,7 @@ impl Bridge { udp_rx_opt = Some(udp_rx); sessions_opt = Some(sessions); - self.last_rtt_ms = rtt_sum / session_count as f64; + self.last_rtt_ms = rtt_sum / successful_sessions as f64; self.running = true; self.last_sample_at = Instant::now(); self.last_valid_recv = Instant::now(); @@ -352,17 +373,87 @@ impl Bridge { } _ = keepalive_tick.tick() => { if self.running { - // 1. Connection Liveness Check - if self.last_valid_recv.elapsed().as_secs() > 60 { - let _ = tx.send(UiEvent::Log("Connection lost (timeout). Reconnecting...".into())).await; - self.running = false; - _proxy_guard = None; - sessions_opt = None; - stream_map.clear(); - self.reset_proxy_streams(&tx, &proxy_tx, "keepalive timeout"); - let _ = tx.send(UiEvent::TunnelStopped).await; - self.metrics.connection_state.store(0, Ordering::Relaxed); - continue; + // 1. Connection Liveness Check & Silent Background Reconnect + if self.last_valid_recv.elapsed().as_secs() > 25 { + let elapsed = self.last_valid_recv.elapsed().as_secs(); + if elapsed > 180 { + // Hard timeout after 3 minutes of total silence + let _ = tx.send(UiEvent::Log("Connection permanently lost (3-minute hard timeout). Stopping tunnel.".into())).await; + self.running = false; + _proxy_guard = None; + sessions_opt = None; + stream_map.clear(); + self.reset_proxy_streams(&tx, &proxy_tx, "keepalive hard timeout"); + let _ = tx.send(UiEvent::TunnelStopped).await; + self.metrics.connection_state.store(0, Ordering::Relaxed); + continue; + } + + let _ = tx.send(UiEvent::Log(format!("Connection stall detected ({}s silence). Attempting background reconnect...", elapsed))).await; + self.metrics.connection_state.store(1, Ordering::Relaxed); // State: Connecting (Handshake) + + let session_count = if self.mux_enabled { self.mux_sessions.max(1) } else { 1 }; + let (udp_tx, udp_rx) = mpsc::channel(100000); + let mut new_sessions = Vec::with_capacity(session_count); + let mut successful_sessions = 0; + let mut rtt_sum = 0.0; + + for idx in 0..session_count { + let session_id: u32 = rand::thread_rng().gen(); + match self.perform_handshake_with_id(&tx, session_id).await { + Ok((sock, mach, rtt)) => { + let session_index = new_sessions.len(); + let socket = Arc::new(sock); + let socket_clone = socket.clone(); + let udp_tx_clone = udp_tx.clone(); + let is_turn = self.turn_enabled; + tokio::spawn(async move { + let mut buf = vec![0_u8; 65535]; + loop { + match socket_clone.recv(&mut buf).await { + Ok(n) => { + let inbound = if is_turn && n >= 4 && buf[0] == 0x40 && buf[1] == 0x00 { + let len = u16::from_be_bytes([buf[2], buf[3]]) as usize; + if 4 + len <= n { + Bytes::copy_from_slice(&buf[4..4+len]) + } else { + Bytes::copy_from_slice(&buf[..n]) + } + } else { + Bytes::copy_from_slice(&buf[..n]) + }; + if udp_tx_clone.send((session_index, inbound)).await.is_err() { + break; + } + } + Err(e) => { + tracing::warn!("UDP socket recv error (reconnect session {}): {}", session_index, e); + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + } + } + } + }); + + new_sessions.push(SessionState { socket, machine: mach }); + rtt_sum += rtt; + successful_sessions += 1; + } + Err(err) => { + let _ = tx.send(UiEvent::Log(format!("Background reconnect session {}/{} failed: {}", idx + 1, session_count, err))).await; + } + } + } + + if !new_sessions.is_empty() { + sessions_opt = Some(new_sessions); + udp_rx_opt = Some(udp_rx); + self.last_rtt_ms = rtt_sum / successful_sessions as f64; + self.last_valid_recv = Instant::now(); + self.metrics.connection_state.store(2, Ordering::Relaxed); // State: Connected + let _ = tx.send(UiEvent::Log("Background reconnect successful! Connection restored.".into())).await; + } else { + let _ = tx.send(UiEvent::Log("Background reconnect failed. Will retry on next tick...".into())).await; + } } // 2. Active Keep-Alive / Heartbeat @@ -612,17 +703,33 @@ impl Bridge { handshake_pad_max: secrets.handshake_pad_max, })?; - let addr = self.local_bind_addr.parse::().map_err(|e| anyhow::anyhow!("invalid bind addr: {}", e))?; - let domain = if addr.is_ipv6() { socket2::Domain::IPV6 } else { socket2::Domain::IPV4 }; - let sock = socket2::Socket::new(domain, socket2::Type::DGRAM, Some(socket2::Protocol::UDP))?; - let _ = sock.set_recv_buffer_size(33554432); // 32MB - let _ = sock.set_send_buffer_size(33554432); // 32MB - let actual_recv = sock.recv_buffer_size().unwrap_or(0); - let actual_send = sock.send_buffer_size().unwrap_or(0); - tracing::info!("UDP socket buffers: recv={}KB send={}KB", actual_recv / 1024, actual_send / 1024); - sock.bind(&addr.into())?; - sock.set_nonblocking(true)?; - let socket = UdpSocket::from_std(sock.into())?; + let resolved_addrs: Vec = match tokio::net::lookup_host(&self.server_addr).await { + Ok(addrs) => addrs.collect(), + Err(e) => return Err(anyhow::anyhow!("failed to resolve server address {}: {}", self.server_addr, e)), + }; + let target_addr = resolved_addrs.first().ok_or_else(|| anyhow::anyhow!("no IP addresses resolved for {}", self.server_addr))?; + let target_ip = target_addr.ip(); + let port = target_addr.port(); + + tx.send(UiEvent::Log(format!("Connecting to remote server: {}...", target_addr))).await.ok(); + + let socket = match self.try_connect_socket(target_ip, port).await { + Ok(sock) => sock, + Err(e) => { + if let std::net::IpAddr::V4(ipv4) = target_ip { + tx.send(UiEvent::Log(format!("Direct IPv4 connection failed: {}. Trying NAT64 fallback...", e))).await.ok(); + let nat64_ipv6 = synthesize_nat64(ipv4); + match self.try_connect_socket(std::net::IpAddr::V6(nat64_ipv6), port).await { + Ok(sock) => sock, + Err(fallback_err) => { + return Err(anyhow::anyhow!("Direct IPv4 failed: {}. NAT64 fallback failed: {}", e, fallback_err)); + } + } + } else { + return Err(e); + } + } + }; if self.turn_enabled { let turn_addr = if self.turn_server.contains(':') { @@ -635,27 +742,46 @@ impl Bridge { match crate::turn::perform_turn_allocation(&socket, &turn_addr, &self.turn_username, &self.turn_password, &self.server_addr).await { Ok(relay_addr) => { tx.send(UiEvent::Log(format!("TURN relay allocated ({})", relay_addr))).await.ok(); - // Re-connect the UDP socket to the TURN server so all sends go through it. - // The TURN server forwards ChannelData to the OSTP server transparently. + + let resolved_turn: Vec = tokio::net::lookup_host(&turn_addr).await + .map_err(|e| anyhow::anyhow!("failed to resolve TURN {}: {}", turn_addr, e))? + .collect(); + let turn_target = resolved_turn.first().ok_or_else(|| anyhow::anyhow!("no IP resolved for TURN {}", turn_addr))?; + + let connect_ip = if socket.local_addr().map(|a| a.is_ipv6()).unwrap_or(false) && turn_target.is_ipv4() { + if let std::net::IpAddr::V4(ipv4) = turn_target.ip() { + std::net::IpAddr::V6(synthesize_nat64(ipv4)) + } else { + turn_target.ip() + } + } else { + turn_target.ip() + }; + + let connect_addr = std::net::SocketAddr::new(connect_ip, turn_target.port()); socket - .connect(&turn_addr) + .connect(connect_addr) .await - .with_context(|| format!("failed to re-connect to TURN {}", turn_addr))?; + .with_context(|| format!("failed to re-connect to TURN {}", connect_addr))?; } Err(e) => { tx.send(UiEvent::Log(format!("TURN allocation failed: {}. Using direct UDP.", e))).await.ok(); + let connect_ip = if socket.local_addr().map(|a| a.is_ipv6()).unwrap_or(false) && target_ip.is_ipv4() { + if let std::net::IpAddr::V4(ipv4) = target_ip { + std::net::IpAddr::V6(synthesize_nat64(ipv4)) + } else { + target_ip + } + } else { + target_ip + }; + let connect_addr = std::net::SocketAddr::new(connect_ip, port); socket - .connect(&self.server_addr) + .connect(connect_addr) .await - .with_context(|| format!("failed to connect udp to {}", self.server_addr))?; + .with_context(|| format!("failed to connect udp to {}", connect_addr))?; } } - } else { - tx.send(UiEvent::Log(format!("Connected to {}", self.server_addr))).await.ok(); - socket - .connect(&self.server_addr) - .await - .with_context(|| format!("failed to connect udp to {}", self.server_addr))?; } // Connection to remote is handled inside the TURN/direct branches above @@ -666,16 +792,65 @@ impl Bridge { ProtocolAction::SendDatagram(frame) => frame, _ => anyhow::bail!("protocol did not emit handshake datagram"), }; - send_datagram(&socket, &handshake_frame, self.turn_enabled).await?; - self.metrics.bytes_sent.fetch_add(handshake_frame.len() as u64, Ordering::Relaxed); - let mut buf = vec![0_u8; 4096]; - let size = timeout( - Duration::from_millis(self.handshake_timeout_ms.max(1)), - socket.recv(&mut buf), - ) - .await - .context("handshake timeout waiting server response")??; + let mut size = 0; + let mut success = false; + + // Retransmit handshake up to 4 times with 1200ms timeout to survive packet loss on mobile + for attempt in 0..4 { + if attempt > 0 { + tx.send(UiEvent::Log(format!("Handshake attempt {} lost. Retransmitting...", attempt))).await.ok(); + } + send_datagram(&socket, &handshake_frame, self.turn_enabled).await?; + self.metrics.bytes_sent.fetch_add(handshake_frame.len() as u64, Ordering::Relaxed); + + match timeout(Duration::from_millis(1200), socket.recv(&mut buf)).await { + Ok(Ok(n)) => { + size = n; + success = true; + break; + } + _ => {} // retry on timeout or error + } + } + + let (final_socket, size) = if success { + (socket, size) + } else { + if let std::net::IpAddr::V4(ipv4) = target_ip { + tx.send(UiEvent::Log("Direct IPv4 handshake timed out. Trying NAT64 fallback...".to_string())).await.ok(); + let nat64_ipv6 = synthesize_nat64(ipv4); + match self.try_connect_socket(std::net::IpAddr::V6(nat64_ipv6), port).await { + Ok(fallback_socket) => { + let mut fallback_success = false; + for attempt in 0..4 { + if attempt > 0 { + tx.send(UiEvent::Log(format!("NAT64 handshake attempt {} lost. Retransmitting...", attempt))).await.ok(); + } + send_datagram(&fallback_socket, &handshake_frame, self.turn_enabled).await?; + match timeout(Duration::from_millis(1200), fallback_socket.recv(&mut buf)).await { + Ok(Ok(n)) => { + size = n; + fallback_success = true; + break; + } + _ => {} + } + } + if fallback_success { + tx.send(UiEvent::Log("NAT64 fallback handshake successful!".to_string())).await.ok(); + (fallback_socket, size) + } else { + return Err(anyhow::anyhow!("NAT64 handshake failed after 3 attempts")); + } + } + Err(e) => return Err(anyhow::anyhow!("NAT64 fallback socket creation failed: {}", e)), + } + } else { + return Err(anyhow::anyhow!("Direct handshake failed after 3 attempts")); + } + }; + let socket = final_socket; self.metrics.bytes_recv.fetch_add(size as u64, Ordering::Relaxed); tracing::info!("Handshake response received: {} bytes", size); @@ -711,6 +886,39 @@ impl Bridge { self.mux_enabled = cfg.multiplex.enabled; self.mux_sessions = cfg.multiplex.sessions.max(1); } + + async fn try_connect_socket( + &self, + target_ip: std::net::IpAddr, + port: u16, + ) -> Result { + let is_ipv6 = target_ip.is_ipv6(); + let domain = if is_ipv6 { socket2::Domain::IPV6 } else { socket2::Domain::IPV4 }; + let bind_addr = if is_ipv6 { + std::net::SocketAddr::new(std::net::IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED), 0) + } else { + std::net::SocketAddr::new(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED), 0) + }; + + let sock = socket2::Socket::new(domain, socket2::Type::DGRAM, Some(socket2::Protocol::UDP))?; + #[cfg(unix)] + { + use std::os::unix::io::AsRawFd; + protect_socket(sock.as_raw_fd()); + } + let _ = sock.set_recv_buffer_size(33554432); // 32MB + let _ = sock.set_send_buffer_size(33554432); // 32MB + let actual_recv = sock.recv_buffer_size().unwrap_or(0); + let actual_send = sock.send_buffer_size().unwrap_or(0); + tracing::info!("UDP socket buffers: recv={}KB send={}KB", actual_recv / 1024, actual_send / 1024); + sock.bind(&bind_addr.into())?; + sock.set_nonblocking(true)?; + let socket = UdpSocket::from_std(sock.into())?; + + let connect_addr = std::net::SocketAddr::new(target_ip, port); + socket.connect(connect_addr).await.with_context(|| format!("failed to connect udp to {}", connect_addr))?; + Ok(socket) + } } fn next_profile(current: TrafficProfile) -> TrafficProfile { @@ -721,3 +929,12 @@ fn next_profile(current: TrafficProfile) -> TrafficProfile { } } +fn synthesize_nat64(ip: std::net::Ipv4Addr) -> std::net::Ipv6Addr { + let octets = ip.octets(); + std::net::Ipv6Addr::new( + 0x0064, 0xff9b, 0, 0, 0, 0, + ((octets[0] as u16) << 8) | octets[1] as u16, + ((octets[2] as u16) << 8) | octets[3] as u16, + ) +} + diff --git a/ostp-client/src/tunnel/proxy.rs b/ostp-client/src/tunnel/proxy.rs index bac23b3..6c8731b 100644 --- a/ostp-client/src/tunnel/proxy.rs +++ b/ostp-client/src/tunnel/proxy.rs @@ -123,6 +123,21 @@ fn extract_host_port(uri: &str, default_port: u16) -> String { } } +struct StreamGuard { + stream_id: u16, + close_tx: mpsc::Sender, +} + +impl Drop for StreamGuard { + fn drop(&mut self) { + let tx = self.close_tx.clone(); + let id = self.stream_id; + tokio::spawn(async move { + let _ = tx.send(id).await; + }); + } +} + async fn handle_proxy_client( mut client: TcpStream, stream_id: u16, @@ -133,6 +148,8 @@ async fn handle_proxy_client( debug: bool, matcher: ExclusionMatcher, ) -> Result<()> { + let _guard = StreamGuard { stream_id, close_tx: close_tx.clone() }; + // Peek the first byte to distinguish SOCKS5 (0x05) from HTTP (any printable ASCII) let mut first_byte = [0_u8; 1]; client.read_exact(&mut first_byte).await?; diff --git a/ostp-client/src/tunnel/wintun_handler.rs b/ostp-client/src/tunnel/wintun_handler.rs index 01a141d..7c5d5d9 100644 --- a/ostp-client/src/tunnel/wintun_handler.rs +++ b/ostp-client/src/tunnel/wintun_handler.rs @@ -87,7 +87,10 @@ pub async fn run_wintun_tunnel( let setup_script = format!( "$remote_ip = '{}'\n\ $exe_path = '{}'\n\ - $route = Get-NetRoute -DestinationPrefix '0.0.0.0/0' | Where-Object {{ $_.InterfaceAlias -notmatch 'tun' -and $_.InterfaceAlias -notmatch 'wintun' }} | Sort-Object RouteMetric | Select-Object -First 1\n\ + $route = Find-NetRoute -RemoteIPAddress $remote_ip -ErrorAction SilentlyContinue | Select-Object -First 1\n\ + if (-not $route) {{\n\ + $route = Get-NetRoute -DestinationPrefix '0.0.0.0/0' | Where-Object {{ $_.InterfaceAlias -notmatch 'tun' -and $_.InterfaceAlias -notmatch 'wintun' }} | Sort-Object RouteMetric | Select-Object -First 1\n\ + }}\n\ $gw = $route.NextHop\n\ $ifIndex = $route.InterfaceIndex\n\ New-NetRoute -DestinationPrefix \"$remote_ip/32\" -NextHop $gw -InterfaceIndex $ifIndex -RouteMetric 1 -ErrorAction SilentlyContinue\n\ diff --git a/ostp-jni/src/lib.rs b/ostp-jni/src/lib.rs index fb1bf95..3d4119c 100644 --- a/ostp-jni/src/lib.rs +++ b/ostp-jni/src/lib.rs @@ -15,6 +15,8 @@ struct SdkState { runtime: Option, shutdown_tx: Option>, metrics: Option>, + tun_child: Option, + cmd_tx: Option>, } lazy_static! { @@ -22,8 +24,12 @@ lazy_static! { runtime: None, shutdown_tx: None, metrics: None, + tun_child: None, + cmd_tx: None, }); static ref LOGS: Mutex> = Mutex::new(VecDeque::new()); + static ref JVM: Mutex> = Mutex::new(None); + static ref CLASS_REF: Mutex> = Mutex::new(None); } fn add_log(text: String) { @@ -40,6 +46,9 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_startClient( mut env: JNIEnv, _class: JClass, config_json: JString, + fd: jni::sys::jint, + t2s_bin_path: JString, + local_proxy: JString, ) -> jboolean { let mut state = match STATE.lock() { Ok(s) => s, @@ -51,11 +60,61 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_startClient( return jni::sys::JNI_TRUE; } + if let Ok(jvm) = env.get_java_vm() { + if let Ok(mut guard) = JVM.lock() { + *guard = Some(jvm); + } + } + + if let Ok(cls) = env.find_class("net/ostp/client/OstpClientSdk") { + if let Ok(global_cls) = env.new_global_ref(cls) { + if let Ok(mut guard) = CLASS_REF.lock() { + *guard = Some(global_cls); + } + } + } + + ostp_client::bridge::set_socket_protector(|fd| { + let jvm_guard = match JVM.lock() { + Ok(g) => g, + Err(_) => return false, + }; + let class_guard = match CLASS_REF.lock() { + Ok(g) => g, + Err(_) => return false, + }; + if let (Some(ref jvm), Some(ref class_ref)) = (&*jvm_guard, &*class_guard) { + if let Ok(mut env) = jvm.attach_current_thread() { + let class_obj = unsafe { jni::objects::JClass::from_raw(class_ref.as_obj().as_raw()) }; + let val = env.call_static_method( + &class_obj, + "protectSocket", + "(I)Z", + &[jni::objects::JValue::from(fd)], + ); + if let Ok(jval) = val { + return jval.z().unwrap_or(false); + } + } + } + false + }); + let config_str: String = match env.get_string(&config_json) { Ok(s) => s.into(), Err(_) => return jni::sys::JNI_FALSE, }; + let t2s_path: String = match env.get_string(&t2s_bin_path) { + Ok(s) => s.into(), + Err(_) => return jni::sys::JNI_FALSE, + }; + + let proxy_addr: String = match env.get_string(&local_proxy) { + Ok(s) => s.into(), + Err(_) => return jni::sys::JNI_FALSE, + }; + // Parse config from JSON let config: ClientConfig = match serde_json::from_str(&config_str) { Ok(cfg) => cfg, @@ -65,6 +124,8 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_startClient( } }; + let debug = config.debug; + // Create tokio runtime let rt = match Runtime::new() { Ok(r) => r, @@ -135,9 +196,64 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_startClient( let _ = cmd_tx_clone.send(BridgeCommand::ToggleTunnel).await; }); + // Spawn tun2socks + let fd_str = format!("fd://{}", fd); + let proxy_str = format!("socks5://{}", proxy_addr); + + if debug { + add_log(format!("Spawning tun2socks: {} -device {} -proxy {}", t2s_path, fd_str, proxy_str)); + } + + let mut cmd = std::process::Command::new(&t2s_path); + cmd.arg("-device") + .arg(&fd_str) + .arg("-proxy") + .arg(&proxy_str) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()); + + let mut child = match cmd.spawn() { + Ok(c) => c, + Err(e) => { + add_log(format!("Failed to spawn tun2socks from Rust: {e}")); + return jni::sys::JNI_FALSE; + } + }; + + let stdout = child.stdout.take().unwrap(); + let stderr = child.stderr.take().unwrap(); + + // Read stdout + std::thread::spawn(move || { + use std::io::{BufRead, BufReader}; + let reader = BufReader::new(stdout); + for line in reader.lines() { + if let Ok(l) = line { + if debug { + add_log(format!("tun2socks: {}", l)); + } + } + } + }); + + // Read stderr & wait + std::thread::spawn(move || { + use std::io::{BufRead, BufReader}; + let reader = BufReader::new(stderr); + for line in reader.lines() { + if let Ok(l) = line { + if debug { + add_log(format!("tun2socks ERROR: {}", l)); + } + } + } + }); + state.runtime = Some(rt); state.shutdown_tx = Some(shutdown_tx); state.metrics = Some(metrics_clone); + state.tun_child = Some(child); + state.cmd_tx = Some(cmd_tx); add_log("OSTP SDK: Client successfully started".to_string()); jni::sys::JNI_TRUE @@ -153,6 +269,11 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_stopClient( Err(_) => return jni::sys::JNI_FALSE, }; + if let Some(mut child) = state.tun_child.take() { + let _ = child.kill(); + add_log("Killed tun2socks process".to_string()); + } + if let Some(shutdown_tx) = state.shutdown_tx.take() { let _ = shutdown_tx.send(true); } @@ -161,6 +282,7 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_stopClient( rt.shutdown_timeout(std::time::Duration::from_secs(3)); } + state.cmd_tx = None; state.metrics = None; add_log("OSTP SDK: Client successfully stopped".to_string()); jni::sys::JNI_TRUE @@ -182,13 +304,17 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_getMetrics( if let Some(m) = &state.metrics { let sent = m.bytes_sent.load(Ordering::Relaxed); let recv = m.bytes_recv.load(Ordering::Relaxed); - let json = format!(r#"{{"bytes_sent": {}, "bytes_recv": {}}}"#, sent, recv); + let conn_state = m.connection_state.load(Ordering::Relaxed); + let json = format!( + r#"{{"bytes_sent": {}, "bytes_recv": {}, "connection_state": {}}}"#, + sent, recv, conn_state + ); match env.new_string(json) { Ok(s) => s.into_raw(), Err(_) => std::ptr::null_mut(), } } else { - match env.new_string(r#"{"bytes_sent": 0, "bytes_recv": 0}"#) { + match env.new_string(r#"{"bytes_sent": 0, "bytes_recv": 0, "connection_state": 0}"#) { Ok(s) => s.into_raw(), Err(_) => std::ptr::null_mut(), } @@ -215,3 +341,15 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_getLogs( Err(_) => std::ptr::null_mut(), } } + +#[no_mangle] +pub extern "system" fn Java_net_ostp_client_OstpClientSdk_addLog( + mut env: JNIEnv, + _class: JClass, + log_msg: JString, +) { + if let Ok(s) = env.get_string(&log_msg) { + let text: String = s.into(); + add_log(text); + } +}