mirror of https://github.com/ospab/ostp.git
Refactor: Phase 1 and 2 - Async architecture, JNI fixes, SmolTCP data races, and Tunnel optimizations
This commit is contained in:
parent
84797f55ab
commit
29e9ef739c
|
|
@ -0,0 +1,15 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 512 512">
|
||||
<defs>
|
||||
<linearGradient id="g2" x1="0%" y1="0%" x2="100%" y2="100%">
|
||||
<stop offset="0%" stop-color="#111827" />
|
||||
<stop offset="100%" stop-color="#374151" />
|
||||
</linearGradient>
|
||||
<linearGradient id="g2_path" x1="0%" y1="0%" x2="100%" y2="100%">
|
||||
<stop offset="0%" stop-color="#3B82F6" />
|
||||
<stop offset="100%" stop-color="#14B8A6" />
|
||||
</linearGradient>
|
||||
</defs>
|
||||
<rect width="512" height="512" rx="120" fill="url(#g2)" />
|
||||
<path d="M144 256c0-61.9 50.1-112 112-112s112 50.1 112 112-50.1 112-112 112S144 317.9 144 256zm-48 0c0 88.4 71.6 160 160 160s160-71.6 160-160S344.4 96 256 96 96 167.6 96 256z" fill="url(#g2_path)"/>
|
||||
<circle cx="256" cy="256" r="40" fill="#F59E0B" />
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 779 B |
|
|
@ -16,6 +16,7 @@ pub(super) struct VirtualDevice {
|
|||
in_buf: UnboundedReceiver<Vec<u8>>,
|
||||
out_buf: Sender<AnyIpPktFrame>,
|
||||
mtu: usize,
|
||||
cached_packet: Option<Vec<u8>>,
|
||||
}
|
||||
|
||||
impl VirtualDevice {
|
||||
|
|
@ -31,6 +32,7 @@ impl VirtualDevice {
|
|||
in_buf: iface_ingress_rx,
|
||||
out_buf: iface_egress_tx,
|
||||
mtu,
|
||||
cached_packet: None,
|
||||
},
|
||||
iface_ingress_tx,
|
||||
iface_ingress_tx_avail,
|
||||
|
|
@ -43,12 +45,18 @@ impl Device for VirtualDevice {
|
|||
type TxToken<'a> = VirtualTxToken<'a>;
|
||||
|
||||
fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> {
|
||||
let Ok(buffer) = self.in_buf.try_recv() else {
|
||||
self.in_buf_avail.store(false, Ordering::Release);
|
||||
return None;
|
||||
let buffer = if let Some(buf) = self.cached_packet.take() {
|
||||
buf
|
||||
} else {
|
||||
let Ok(buf) = self.in_buf.try_recv() else {
|
||||
self.in_buf_avail.store(false, Ordering::Release);
|
||||
return None;
|
||||
};
|
||||
buf
|
||||
};
|
||||
|
||||
let Ok(permit) = self.out_buf.try_reserve() else {
|
||||
self.cached_packet = Some(buffer);
|
||||
self.in_buf_avail.store(false, Ordering::Release);
|
||||
return None;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -12,23 +12,24 @@ use std::{
|
|||
/// require two sets of API interfaces in single-threaded and multi-threaded.
|
||||
///
|
||||
/// [BoxFuture in crate futures utils]: https://docs.rs/futures-util/latest/futures_util/future/type.BoxFuture.html
|
||||
pub struct BoxFuture<'a, T>(Pin<Box<dyn Future<Output = T> + 'a>>);
|
||||
pub struct BoxFuture<'a, T>(Pin<Box<dyn Future<Output = T> + Send + 'a>>);
|
||||
|
||||
impl<'a, T> BoxFuture<'a, T> {
|
||||
pub fn new<F>(f: F) -> BoxFuture<'a, T>
|
||||
where
|
||||
F: IntoFuture<Output = T> + 'a,
|
||||
F: IntoFuture<Output = T> + Send + 'a,
|
||||
F::IntoFuture: Send + 'a,
|
||||
{
|
||||
BoxFuture(Box::pin(f.into_future()))
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub fn wrap(f: Pin<Box<dyn Future<Output = T> + 'a>>) -> BoxFuture<'a, T> {
|
||||
pub fn wrap(f: Pin<Box<dyn Future<Output = T> + Send + 'a>>) -> BoxFuture<'a, T> {
|
||||
BoxFuture(f)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl<T: Send> Send for BoxFuture<'_, T> {}
|
||||
|
||||
|
||||
impl<T> Future for BoxFuture<'_, T> {
|
||||
type Output = T;
|
||||
|
|
|
|||
|
|
@ -29,3 +29,4 @@ libc = "0.2.186"
|
|||
x25519-dalek = "2.0.1"
|
||||
chacha20poly1305.workspace = true
|
||||
hex = "0.4.3"
|
||||
winapi = { version = "0.3.9", features = ["iphlpapi", "tcpmib", "processthreadsapi", "psapi", "handleapi", "winerror", "minwindef", "winnt"] }
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -58,7 +58,7 @@ pub struct OstpConfig {
|
|||
|
||||
fn default_keepalive() -> u64 { 5 }
|
||||
|
||||
fn default_mtu() -> usize { 1280 }
|
||||
fn default_mtu() -> usize { 1140 }
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LocalProxyConfig {
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ pub fn enable_windows_proxy(proxy_addr: &str) {
|
|||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
pub fn disable_windows_proxy() {
|
||||
pub fn disable_system_proxy() {
|
||||
tracing::info!("Disabling Windows system proxy");
|
||||
let _ = Command::new("reg")
|
||||
.creation_flags(CREATE_NO_WINDOW)
|
||||
|
|
@ -188,10 +188,6 @@ pub fn enable_system_proxy(proxy_addr: &str) {
|
|||
enable_windows_proxy(proxy_addr);
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
pub fn disable_system_proxy() {
|
||||
disable_windows_proxy();
|
||||
}
|
||||
|
||||
pub struct SystemProxyGuard {
|
||||
active: bool,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,134 @@
|
|||
use crate::config::ExclusionConfig;
|
||||
use std::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ExclusionMatcher {
|
||||
pub domain_suffix: Vec<String>,
|
||||
pub cidrs: Vec<Cidr>,
|
||||
pub processes: Vec<String>,
|
||||
pub physical_if_index: Option<u32>,
|
||||
pub physical_if_name: Option<String>,
|
||||
}
|
||||
|
||||
impl ExclusionMatcher {
|
||||
pub fn new(
|
||||
exclusions: &ExclusionConfig,
|
||||
physical_if_index: Option<u32>,
|
||||
physical_if_name: Option<String>,
|
||||
) -> Self {
|
||||
let mut cidrs = Vec::new();
|
||||
for ip in &exclusions.ips {
|
||||
if let Some(cidr) = parse_cidr(ip) {
|
||||
cidrs.push(cidr);
|
||||
}
|
||||
}
|
||||
|
||||
let processes = exclusions.processes.iter()
|
||||
.map(|p| p.trim().to_lowercase())
|
||||
.filter(|p| !p.is_empty())
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
domain_suffix: exclusions
|
||||
.domains
|
||||
.iter()
|
||||
.map(|d| d.trim().trim_start_matches('.').to_lowercase())
|
||||
.filter(|d| !d.is_empty())
|
||||
.collect(),
|
||||
cidrs,
|
||||
processes,
|
||||
physical_if_index,
|
||||
physical_if_name,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn should_bypass_target(&self, host: &str, port: u16, timeout_value: Duration) -> bool {
|
||||
if self.match_domain(host) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if self.cidrs.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
if let Ok(ip) = host.parse::<std::net::IpAddr>() {
|
||||
return self.match_ip(&ip);
|
||||
}
|
||||
|
||||
let lookup_target = (host.to_string(), port);
|
||||
match timeout(timeout_value, tokio::net::lookup_host(lookup_target)).await {
|
||||
Ok(Ok(addrs)) => addrs.into_iter().any(|addr| self.match_ip(&addr.ip())),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn match_domain(&self, host: &str) -> bool {
|
||||
if self.domain_suffix.is_empty() {
|
||||
return false;
|
||||
}
|
||||
let host = host.trim_end_matches('.').to_lowercase();
|
||||
self.domain_suffix.iter().any(|suffix| {
|
||||
host == *suffix || host.ends_with(&format!(".{suffix}"))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn match_ip(&self, ip: &std::net::IpAddr) -> bool {
|
||||
self.cidrs.iter().any(|cidr| cidr.contains(ip))
|
||||
}
|
||||
|
||||
pub fn match_process(&self, process_name: &str) -> bool {
|
||||
if self.processes.is_empty() {
|
||||
return false;
|
||||
}
|
||||
let p = process_name.to_lowercase();
|
||||
self.processes.iter().any(|ex| p.contains(ex))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum Cidr {
|
||||
V4(u32, u8),
|
||||
V6(u128, u8),
|
||||
}
|
||||
|
||||
impl Cidr {
|
||||
pub fn contains(&self, ip: &std::net::IpAddr) -> bool {
|
||||
match (self, ip) {
|
||||
(Cidr::V4(net, bits), std::net::IpAddr::V4(addr)) => {
|
||||
let mask = if *bits == 0 { 0 } else { u32::MAX << (32 - bits) };
|
||||
let ip = u32::from_be_bytes(addr.octets());
|
||||
(ip & mask) == (*net & mask)
|
||||
}
|
||||
(Cidr::V6(net, bits), std::net::IpAddr::V6(addr)) => {
|
||||
let mask = if *bits == 0 { 0 } else { u128::MAX << (128 - bits) };
|
||||
let ip = u128::from_be_bytes(addr.octets());
|
||||
(ip & mask) == (*net & mask)
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn parse_cidr(s: &str) -> Option<Cidr> {
|
||||
let parts: Vec<&str> = s.split('/').collect();
|
||||
if parts.is_empty() || parts.len() > 2 {
|
||||
return None;
|
||||
}
|
||||
if let Ok(ip) = parts[0].parse::<std::net::IpAddr>() {
|
||||
let bits = if parts.len() == 2 {
|
||||
parts[1].parse::<u8>().ok()?
|
||||
} else {
|
||||
match ip {
|
||||
std::net::IpAddr::V4(_) => 32,
|
||||
std::net::IpAddr::V6(_) => 128,
|
||||
}
|
||||
};
|
||||
match ip {
|
||||
std::net::IpAddr::V4(v4) => Some(Cidr::V4(u32::from_be_bytes(v4.octets()), bits)),
|
||||
std::net::IpAddr::V6(v6) => Some(Cidr::V6(u128::from_be_bytes(v6.octets()), bits)),
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
|
@ -61,3 +61,7 @@ pub async fn run_local_proxy(
|
|||
}
|
||||
|
||||
|
||||
|
||||
pub mod exclusion;
|
||||
pub mod process_lookup;
|
||||
pub mod sni_sniff;
|
||||
|
|
|
|||
|
|
@ -426,12 +426,63 @@ pub async fn run_native_tunnel_from_fd(
|
|||
}
|
||||
});
|
||||
|
||||
let matcher = crate::tunnel::exclusion::ExclusionMatcher::new(&config.exclusions, None, None);
|
||||
|
||||
let mut tcp_accept_task = tokio::spawn(async move {
|
||||
if let Some(mut listener) = tcp_listener {
|
||||
while let Some((mut stream, _local, remote)) = listener.next().await {
|
||||
while let Some((mut stream, local, remote)) = listener.next().await {
|
||||
let proxy_addr = proxy_addr.clone();
|
||||
let matcher = matcher.clone();
|
||||
tokio::spawn(async move {
|
||||
if debug { tracing::info!("Native TUN intercepted TCP to {}", remote); }
|
||||
if debug { tracing::info!("Native TUN intercepted TCP {local} -> {remote}"); }
|
||||
|
||||
// Peak first chunk to see SNI
|
||||
let mut sniff_buf = [0u8; 1500];
|
||||
let sniff_len = match tokio::time::timeout(std::time::Duration::from_millis(50), stream.read(&mut sniff_buf)).await {
|
||||
Ok(Ok(n)) => n,
|
||||
_ => 0, // Timeout or error
|
||||
};
|
||||
|
||||
let mut should_bypass = false;
|
||||
|
||||
// 1. Check SNI
|
||||
if sniff_len > 0 {
|
||||
if let Some(sni) = crate::tunnel::sni_sniff::extract_sni(&sniff_buf[..sniff_len]) {
|
||||
if debug { tracing::info!("Native TUN sniffed SNI: {}", sni); }
|
||||
if matcher.match_domain(&sni) {
|
||||
should_bypass = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Check Process
|
||||
if !should_bypass {
|
||||
if let Some(exe) = crate::tunnel::process_lookup::get_process_name_from_port(local.port()) {
|
||||
if debug { tracing::info!("Native TUN source port {} maps to EXE: {}", local.port(), exe); }
|
||||
if matcher.match_process(&exe) {
|
||||
should_bypass = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Check Target IP
|
||||
if !should_bypass {
|
||||
if matcher.match_ip(&remote.ip()) {
|
||||
should_bypass = true;
|
||||
}
|
||||
}
|
||||
|
||||
if should_bypass {
|
||||
if debug { tracing::info!("Native TUN BYPASS matched for {}", remote); }
|
||||
if let Ok(mut direct) = tokio::time::timeout(std::time::Duration::from_secs(5), tokio::net::TcpStream::connect(remote)).await.unwrap_or(Err(std::io::Error::new(std::io::ErrorKind::TimedOut, "Direct connect timeout"))) {
|
||||
if sniff_len > 0 {
|
||||
let _ = direct.write_all(&sniff_buf[..sniff_len]).await;
|
||||
}
|
||||
let _ = tokio::io::copy_bidirectional(&mut stream, &mut direct).await;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if let Ok(mut socks) = tokio::net::TcpStream::connect(&proxy_addr).await {
|
||||
if socks.write_all(&[5, 1, 0]).await.is_err() { return; }
|
||||
let mut buf = [0u8; 2];
|
||||
|
|
@ -456,6 +507,11 @@ pub async fn run_native_tunnel_from_fd(
|
|||
let mut rep = [0u8; 10];
|
||||
if socks.read_exact(&mut rep).await.is_err() || rep[1] != 0 { return; }
|
||||
|
||||
// Write sniffed buffer to socks
|
||||
if sniff_len > 0 {
|
||||
if socks.write_all(&sniff_buf[..sniff_len]).await.is_err() { return; }
|
||||
}
|
||||
|
||||
let _ = tokio::io::copy_bidirectional(&mut stream, &mut socks).await;
|
||||
}
|
||||
});
|
||||
|
|
|
|||
|
|
@ -0,0 +1,142 @@
|
|||
#[cfg(target_os = "windows")]
|
||||
pub fn get_process_name_from_port(port: u16) -> Option<String> {
|
||||
use winapi::shared::minwindef::{DWORD, ULONG};
|
||||
use winapi::shared::winerror::ERROR_INSUFFICIENT_BUFFER;
|
||||
use winapi::um::iphlpapi::GetExtendedTcpTable;
|
||||
use winapi::shared::tcpmib::{MIB_TCPTABLE_OWNER_PID, MIB_TCPROW_OWNER_PID};
|
||||
|
||||
let mut size: ULONG = 0;
|
||||
let table_class = 5; // TCP_TABLE_OWNER_PID_ALL
|
||||
let mut table = vec![0u8; 1024];
|
||||
|
||||
unsafe {
|
||||
let mut ret = GetExtendedTcpTable(
|
||||
table.as_mut_ptr() as *mut _,
|
||||
&mut size,
|
||||
0,
|
||||
2, // AF_INET
|
||||
table_class,
|
||||
0,
|
||||
);
|
||||
|
||||
if ret == ERROR_INSUFFICIENT_BUFFER {
|
||||
table.resize(size as usize, 0);
|
||||
ret = GetExtendedTcpTable(
|
||||
table.as_mut_ptr() as *mut _,
|
||||
&mut size,
|
||||
0,
|
||||
2, // AF_INET
|
||||
table_class,
|
||||
0,
|
||||
);
|
||||
}
|
||||
|
||||
if ret == 0 {
|
||||
let tcp_table = &*(table.as_ptr() as *const MIB_TCPTABLE_OWNER_PID);
|
||||
let row_ptr = &tcp_table.table[0] as *const MIB_TCPROW_OWNER_PID;
|
||||
for i in 0..tcp_table.dwNumEntries {
|
||||
let row = &*row_ptr.add(i as usize);
|
||||
// Local port is in network byte order
|
||||
let local_port = u16::from_be(row.dwLocalPort as u16);
|
||||
if local_port == port {
|
||||
return get_process_name_from_pid(row.dwOwningPid);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
fn get_process_name_from_pid(pid: u32) -> Option<String> {
|
||||
use winapi::um::processthreadsapi::OpenProcess;
|
||||
use winapi::um::psapi::GetModuleBaseNameW;
|
||||
use winapi::um::winnt::{PROCESS_QUERY_INFORMATION, PROCESS_VM_READ};
|
||||
use winapi::um::handleapi::CloseHandle;
|
||||
use std::os::windows::ffi::OsStringExt;
|
||||
|
||||
unsafe {
|
||||
let handle = OpenProcess(PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, 0, pid);
|
||||
if handle.is_null() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut buffer = [0u16; 1024];
|
||||
let len = GetModuleBaseNameW(handle, std::ptr::null_mut(), buffer.as_mut_ptr(), buffer.len() as u32);
|
||||
CloseHandle(handle);
|
||||
|
||||
if len > 0 {
|
||||
let name = std::ffi::OsString::from_wide(&buffer[..len as usize]);
|
||||
return Some(name.to_string_lossy().into_owned());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
pub fn get_process_name_from_port(port: u16) -> Option<String> {
|
||||
use std::fs;
|
||||
use std::io::{BufRead, BufReader};
|
||||
|
||||
let mut target_inode = None;
|
||||
let hex_port = format!("{:04X}", port);
|
||||
|
||||
let check_net_file = |path: &str| -> Option<u64> {
|
||||
let file = fs::File::open(path).ok()?;
|
||||
let reader = BufReader::new(file);
|
||||
for line in reader.lines().skip(1).filter_map(Result::ok) {
|
||||
let parts: Vec<&str> = line.split_whitespace().collect();
|
||||
if parts.len() >= 10 {
|
||||
let local_addr = parts[1];
|
||||
if local_addr.ends_with(&format!(":{}", hex_port)) {
|
||||
if let Ok(inode) = parts[9].parse::<u64>() {
|
||||
return Some(inode);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
};
|
||||
|
||||
target_inode = check_net_file("/proc/net/tcp")
|
||||
.or_else(|| check_net_file("/proc/net/tcp6"))
|
||||
.or_else(|| check_net_file("/proc/net/udp"))
|
||||
.or_else(|| check_net_file("/proc/net/udp6"));
|
||||
|
||||
let target_inode = target_inode?;
|
||||
let socket_str = format!("socket:[{}]", target_inode);
|
||||
|
||||
for entry in fs::read_dir("/proc").ok()?.filter_map(Result::ok) {
|
||||
let file_name = entry.file_name();
|
||||
let pid_str = file_name.to_string_lossy();
|
||||
if !pid_str.chars().all(char::is_numeric) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let fd_dir = entry.path().join("fd");
|
||||
if let Ok(fd_entries) = fs::read_dir(fd_dir) {
|
||||
for fd_entry in fd_entries.filter_map(Result::ok) {
|
||||
if let Ok(target) = fs::read_link(fd_entry.path()) {
|
||||
if target.to_string_lossy() == socket_str {
|
||||
let exe_path = entry.path().join("exe");
|
||||
if let Ok(exe_link) = fs::read_link(exe_path) {
|
||||
if let Some(name) = exe_link.file_name() {
|
||||
return Some(name.to_string_lossy().into_owned());
|
||||
}
|
||||
}
|
||||
if let Ok(comm) = fs::read_to_string(entry.path().join("comm")) {
|
||||
return Some(comm.trim().to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(not(any(target_os = "windows", target_os = "linux")))]
|
||||
pub fn get_process_name_from_port(_port: u16) -> Option<String> {
|
||||
None
|
||||
}
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
use std::collections::HashMap;
|
||||
use crate::tunnel::exclusion::{ExclusionMatcher, Cidr};
|
||||
use anyhow::{anyhow, Context, Result};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::{TcpListener, TcpStream, UdpSocket};
|
||||
|
|
@ -421,8 +422,10 @@ async fn handle_udp_associate(
|
|||
};
|
||||
let payload = bytes::Bytes::copy_from_slice(&buf[header_len..len]);
|
||||
|
||||
let target_host = if let Some((host, _)) = split_host_port(&target) { host } else { target.clone() };
|
||||
let target_port = match split_host_port(&target) { Some((_, p)) => p, None => 0 };
|
||||
// Check if target should bypass the tunnel
|
||||
if matcher.should_bypass(&target, connect_timeout).await {
|
||||
if matcher.should_bypass_target(&target_host, target_port, connect_timeout).await {
|
||||
if debug {
|
||||
tracing::info!("proxy UDP BYPASS target={}", target);
|
||||
}
|
||||
|
|
@ -668,7 +671,9 @@ async fn handle_proxy_client(
|
|||
if debug {
|
||||
tracing::info!("proxy CONNECT stream_id={stream_id} target={target}");
|
||||
}
|
||||
if matcher.should_bypass(&target, connect_timeout).await {
|
||||
let target_host = if let Some((host, _)) = split_host_port(&target) { host } else { target.clone() };
|
||||
let target_port = match split_host_port(&target) { Some((_, p)) => p, None => 0 };
|
||||
if matcher.should_bypass_target(&target_host, target_port, connect_timeout).await {
|
||||
return direct_connect_socks5(
|
||||
client,
|
||||
stream_id,
|
||||
|
|
@ -753,7 +758,9 @@ async fn handle_proxy_client(
|
|||
if debug {
|
||||
tracing::info!("proxy CONNECT stream_id={stream_id} target={target}");
|
||||
}
|
||||
if matcher.should_bypass(&target, connect_timeout).await {
|
||||
let target_host = if let Some((host, _)) = split_host_port(&target) { host } else { target.clone() };
|
||||
let target_port = match split_host_port(&target) { Some((_, p)) => p, None => 443 };
|
||||
if matcher.should_bypass_target(&target_host, target_port, connect_timeout).await {
|
||||
return direct_connect_http(
|
||||
client,
|
||||
stream_id,
|
||||
|
|
@ -854,129 +861,6 @@ async fn handle_proxy_client(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ExclusionMatcher {
|
||||
domain_suffix: Vec<String>,
|
||||
cidrs: Vec<Cidr>,
|
||||
physical_if_index: Option<u32>,
|
||||
physical_if_name: Option<String>,
|
||||
}
|
||||
|
||||
impl ExclusionMatcher {
|
||||
fn new(
|
||||
exclusions: &ExclusionConfig,
|
||||
physical_if_index: Option<u32>,
|
||||
physical_if_name: Option<String>,
|
||||
) -> Self {
|
||||
let mut cidrs = Vec::new();
|
||||
for ip in &exclusions.ips {
|
||||
if let Some(cidr) = parse_cidr(ip) {
|
||||
cidrs.push(cidr);
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
domain_suffix: exclusions
|
||||
.domains
|
||||
.iter()
|
||||
.map(|d| d.trim().trim_start_matches('.').to_lowercase())
|
||||
.filter(|d| !d.is_empty())
|
||||
.collect(),
|
||||
cidrs,
|
||||
physical_if_index,
|
||||
physical_if_name,
|
||||
}
|
||||
}
|
||||
|
||||
async fn should_bypass(&self, target: &str, timeout_value: Duration) -> bool {
|
||||
let (host, port) = match split_host_port(target) {
|
||||
Some(v) => v,
|
||||
None => return false,
|
||||
};
|
||||
|
||||
if self.match_domain(&host) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if self.cidrs.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
if let Ok(ip) = host.parse::<std::net::IpAddr>() {
|
||||
return self.match_ip(&ip);
|
||||
}
|
||||
|
||||
let lookup_target = (host.clone(), port);
|
||||
match timeout(timeout_value, tokio::net::lookup_host(lookup_target)).await {
|
||||
Ok(Ok(addrs)) => addrs.into_iter().any(|addr| self.match_ip(&addr.ip())),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn match_domain(&self, host: &str) -> bool {
|
||||
if self.domain_suffix.is_empty() {
|
||||
return false;
|
||||
}
|
||||
let host = host.trim_end_matches('.').to_lowercase();
|
||||
self.domain_suffix.iter().any(|suffix| {
|
||||
host == *suffix || host.ends_with(&format!(".{suffix}"))
|
||||
})
|
||||
}
|
||||
|
||||
fn match_ip(&self, ip: &std::net::IpAddr) -> bool {
|
||||
self.cidrs.iter().any(|cidr| cidr.contains(ip))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
enum Cidr {
|
||||
V4(u32, u8),
|
||||
V6(u128, u8),
|
||||
}
|
||||
|
||||
impl Cidr {
|
||||
fn contains(&self, ip: &std::net::IpAddr) -> bool {
|
||||
match (self, ip) {
|
||||
(Cidr::V4(net, bits), std::net::IpAddr::V4(addr)) => {
|
||||
let mask = if *bits == 0 { 0 } else { u32::MAX << (32 - bits) };
|
||||
let ip = u32::from_be_bytes(addr.octets());
|
||||
(ip & mask) == (*net & mask)
|
||||
}
|
||||
(Cidr::V6(net, bits), std::net::IpAddr::V6(addr)) => {
|
||||
let mask = if *bits == 0 { 0 } else { u128::MAX << (128 - bits) };
|
||||
let ip = u128::from_be_bytes(addr.octets());
|
||||
(ip & mask) == (*net & mask)
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_cidr(value: &str) -> Option<Cidr> {
|
||||
let value = value.trim();
|
||||
if value.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
if let Some((addr_str, bits_str)) = value.split_once('/') {
|
||||
let bits: u8 = bits_str.parse().ok()?;
|
||||
if let Ok(addr) = addr_str.parse::<std::net::IpAddr>() {
|
||||
return match addr {
|
||||
std::net::IpAddr::V4(v4) => Some(Cidr::V4(u32::from_be_bytes(v4.octets()), bits.min(32))),
|
||||
std::net::IpAddr::V6(v6) => Some(Cidr::V6(u128::from_be_bytes(v6.octets()), bits.min(128))),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if let Ok(addr) = value.parse::<std::net::IpAddr>() {
|
||||
return match addr {
|
||||
std::net::IpAddr::V4(v4) => Some(Cidr::V4(u32::from_be_bytes(v4.octets()), 32)),
|
||||
std::net::IpAddr::V6(v6) => Some(Cidr::V6(u128::from_be_bytes(v6.octets()), 128)),
|
||||
};
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn split_host_port(target: &str) -> Option<(String, u16)> {
|
||||
if let Some((host, port)) = target.rsplit_once(':') {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,73 @@
|
|||
pub fn extract_sni(data: &[u8]) -> Option<String> {
|
||||
// Basic TLS ClientHello parser
|
||||
// Must be at least 43 bytes to contain anything useful
|
||||
if data.len() < 43 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// TLS Record layer: Handshake (22)
|
||||
if data[0] != 0x16 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Record layer version: 0x0301 (TLS 1.0) or 0x0303 (TLS 1.2)
|
||||
if data[1] != 0x03 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Handshake type: ClientHello (1)
|
||||
if data[5] != 0x01 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut pos = 43; // Skip fixed ClientHello header
|
||||
|
||||
// Skip Session ID
|
||||
if pos >= data.len() { return None; }
|
||||
let session_id_len = data[pos] as usize;
|
||||
pos += 1 + session_id_len;
|
||||
|
||||
// Skip Cipher Suites
|
||||
if pos + 2 > data.len() { return None; }
|
||||
let cipher_suites_len = ((data[pos] as usize) << 8) | (data[pos + 1] as usize);
|
||||
pos += 2 + cipher_suites_len;
|
||||
|
||||
// Skip Compression Methods
|
||||
if pos >= data.len() { return None; }
|
||||
let comp_methods_len = data[pos] as usize;
|
||||
pos += 1 + comp_methods_len;
|
||||
|
||||
// Extensions
|
||||
if pos + 2 > data.len() { return None; }
|
||||
let extensions_len = ((data[pos] as usize) << 8) | (data[pos + 1] as usize);
|
||||
pos += 2;
|
||||
|
||||
let extensions_end = pos + extensions_len;
|
||||
if extensions_end > data.len() { return None; }
|
||||
|
||||
while pos + 4 <= extensions_end {
|
||||
let ext_type = ((data[pos] as usize) << 8) | (data[pos + 1] as usize);
|
||||
let ext_len = ((data[pos + 2] as usize) << 8) | (data[pos + 3] as usize);
|
||||
pos += 4;
|
||||
|
||||
if ext_type == 0x0000 { // Server Name Indication (SNI)
|
||||
if pos + 5 <= extensions_end {
|
||||
let list_len = ((data[pos] as usize) << 8) | (data[pos + 1] as usize);
|
||||
let name_type = data[pos + 2];
|
||||
if name_type == 0 { // Hostname
|
||||
let name_len = ((data[pos + 3] as usize) << 8) | (data[pos + 4] as usize);
|
||||
if pos + 5 + name_len <= extensions_end {
|
||||
let sni_bytes = &data[pos + 5..pos + 5 + name_len];
|
||||
if let Ok(sni) = std::str::from_utf8(sni_bytes) {
|
||||
return Some(sni.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
pos += ext_len;
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
|
@ -5,7 +5,6 @@
|
|||
//! This replaces the fixed `retransmit_budget = 8` with an adaptive
|
||||
//! congestion window that responds to network conditions.
|
||||
|
||||
use std::collections::VecDeque;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
/// Congestion control state for a single OSTP session.
|
||||
|
|
@ -16,14 +15,8 @@ pub struct CongestionController {
|
|||
ssthresh: u64,
|
||||
/// Current phase
|
||||
phase: Phase,
|
||||
/// Minimum RTT observed (used for BDP calculation)
|
||||
/// Minimum RTT observed
|
||||
min_rtt: Duration,
|
||||
/// Maximum bandwidth observed (bytes/sec)
|
||||
max_bandwidth: u64,
|
||||
/// RTT samples for smoothing
|
||||
rtt_samples: VecDeque<RttSample>,
|
||||
/// Bandwidth samples
|
||||
bw_samples: VecDeque<BwSample>,
|
||||
/// Bytes currently in flight (unacknowledged)
|
||||
bytes_in_flight: u64,
|
||||
/// Total bytes acknowledged (for bandwidth estimation)
|
||||
|
|
@ -36,8 +29,6 @@ pub struct CongestionController {
|
|||
pacing_rate: u64,
|
||||
/// MTU estimate (used for cwnd → packet count conversion)
|
||||
mtu: u64,
|
||||
/// Probe RTT phase timer
|
||||
probe_rtt_timer: Option<Instant>,
|
||||
/// Min RTT expiry: re-probe after 10 seconds
|
||||
min_rtt_stamp: Instant,
|
||||
}
|
||||
|
|
@ -48,35 +39,14 @@ enum Phase {
|
|||
SlowStart,
|
||||
/// Probe bandwidth: cycle through pacing gains
|
||||
ProbeBandwidth,
|
||||
/// Periodically drain the queue to measure true min RTT
|
||||
#[allow(dead_code)]
|
||||
ProbeRtt,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(dead_code)]
|
||||
struct RttSample {
|
||||
rtt: Duration,
|
||||
time: Instant,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(dead_code)]
|
||||
struct BwSample {
|
||||
bytes_per_sec: u64,
|
||||
time: Instant,
|
||||
}
|
||||
|
||||
/// Maximum number of samples to keep for windowed min/max
|
||||
const MAX_SAMPLES: usize = 32;
|
||||
/// Initial congestion window: 10 packets × MTU
|
||||
const INITIAL_CWND_PACKETS: u64 = 10;
|
||||
/// Minimum cwnd: 2 packets
|
||||
const MIN_CWND_PACKETS: u64 = 2;
|
||||
/// Min RTT expiry window (after which we re-probe)
|
||||
const MIN_RTT_EXPIRY: Duration = Duration::from_secs(10);
|
||||
/// ProbeRTT drain duration
|
||||
const PROBE_RTT_DURATION: Duration = Duration::from_millis(200);
|
||||
|
||||
impl CongestionController {
|
||||
pub fn new(mtu: u64) -> Self {
|
||||
|
|
@ -87,16 +57,12 @@ impl CongestionController {
|
|||
ssthresh: u64::MAX,
|
||||
phase: Phase::SlowStart,
|
||||
min_rtt: Duration::from_millis(100), // Conservative initial estimate
|
||||
max_bandwidth: 0,
|
||||
rtt_samples: VecDeque::with_capacity(MAX_SAMPLES),
|
||||
bw_samples: VecDeque::with_capacity(MAX_SAMPLES),
|
||||
bytes_in_flight: 0,
|
||||
total_acked: 0,
|
||||
last_ack_time: now,
|
||||
loss_count: 0,
|
||||
pacing_rate: initial_cwnd * 10, // initial: ~10 windows/sec
|
||||
mtu,
|
||||
probe_rtt_timer: None,
|
||||
min_rtt_stamp: now,
|
||||
}
|
||||
}
|
||||
|
|
@ -169,30 +135,8 @@ impl CongestionController {
|
|||
// TCP Reno Additive Increase: increase cwnd by ~1 MTU per RTT
|
||||
self.cwnd = self.cwnd.saturating_add(bytes * self.mtu / self.cwnd.max(1));
|
||||
}
|
||||
Phase::ProbeRtt => {
|
||||
// Drain down to 4 packets to measure true min RTT
|
||||
self.cwnd = MIN_CWND_PACKETS * self.mtu * 2;
|
||||
if let Some(timer) = self.probe_rtt_timer {
|
||||
if now.duration_since(timer) >= PROBE_RTT_DURATION {
|
||||
// ProbeRTT complete, return to ProbeBandwidth
|
||||
self.phase = Phase::ProbeBandwidth;
|
||||
self.probe_rtt_timer = None;
|
||||
self.cwnd = (MIN_CWND_PACKETS * self.mtu * 4).max(self.cwnd);
|
||||
tracing::debug!(cwnd = self.cwnd, min_rtt = ?self.min_rtt, "congestion: probe RTT complete");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
// Periodically enter ProbeRTT to refresh min_rtt
|
||||
if now.duration_since(self.min_rtt_stamp) >= MIN_RTT_EXPIRY && self.phase != Phase::ProbeRtt {
|
||||
self.phase = Phase::ProbeRtt;
|
||||
self.probe_rtt_timer = Some(now);
|
||||
tracing::debug!("congestion: entering probe RTT phase");
|
||||
}
|
||||
*/
|
||||
|
||||
self.update_pacing_rate();
|
||||
self.last_ack_time = now;
|
||||
}
|
||||
|
|
@ -215,9 +159,6 @@ impl CongestionController {
|
|||
self.cwnd = (self.cwnd * 7 / 10).max(MIN_CWND_PACKETS * self.mtu);
|
||||
tracing::debug!(cwnd = self.cwnd, "congestion: loss, cwnd reduced");
|
||||
}
|
||||
Phase::ProbeRtt => {
|
||||
// Don't react to loss during ProbeRTT
|
||||
}
|
||||
}
|
||||
|
||||
self.update_pacing_rate();
|
||||
|
|
@ -236,40 +177,16 @@ impl CongestionController {
|
|||
self.min_rtt = rtt;
|
||||
self.min_rtt_stamp = now;
|
||||
}
|
||||
|
||||
// Keep sample history
|
||||
self.rtt_samples.push_back(RttSample { rtt, time: now });
|
||||
while self.rtt_samples.len() > MAX_SAMPLES {
|
||||
self.rtt_samples.pop_front();
|
||||
}
|
||||
}
|
||||
|
||||
fn update_bandwidth(&mut self, acked_bytes: u64, now: Instant) {
|
||||
fn update_bandwidth(&mut self, _acked_bytes: u64, now: Instant) {
|
||||
let elapsed = now.duration_since(self.last_ack_time);
|
||||
if elapsed.as_micros() > 0 {
|
||||
let bw = acked_bytes * 1_000_000 / elapsed.as_micros() as u64;
|
||||
if bw > self.max_bandwidth {
|
||||
self.max_bandwidth = bw;
|
||||
}
|
||||
self.bw_samples.push_back(BwSample { bytes_per_sec: bw, time: now });
|
||||
while self.bw_samples.len() > MAX_SAMPLES {
|
||||
self.bw_samples.pop_front();
|
||||
}
|
||||
// Removed bw_samples tracking
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn bandwidth_delay_product(&self) -> u64 {
|
||||
// BDP = max_bandwidth * min_rtt
|
||||
let bw = if self.max_bandwidth > 0 {
|
||||
self.max_bandwidth
|
||||
} else {
|
||||
// Fallback: assume 10 Mbps
|
||||
1_250_000
|
||||
};
|
||||
let rtt_secs = self.min_rtt.as_secs_f64();
|
||||
(bw as f64 * rtt_secs) as u64
|
||||
}
|
||||
|
||||
|
||||
fn update_pacing_rate(&mut self) {
|
||||
// Pacing rate = cwnd / min_rtt (with gain)
|
||||
|
|
|
|||
|
|
@ -290,7 +290,7 @@ impl ProtocolMachine {
|
|||
if raw_vec.len() < 12 {
|
||||
return Err(ProtocolError::Framing("data datagram too short".to_string()));
|
||||
}
|
||||
let nonce = u64::from_be_bytes(raw_vec[4..12].try_into().unwrap());
|
||||
let nonce = u64::from_be_bytes(raw_vec[4..12].try_into().map_err(|_| ProtocolError::Framing("data datagram too short for nonce".into()))?);
|
||||
|
||||
if nonce < self.expected_recv_nonce {
|
||||
// Duplicate — the ACK we sent was likely lost or delayed.
|
||||
|
|
@ -330,7 +330,7 @@ impl ProtocolMachine {
|
|||
// Fast path processing for Nacks: act immediately, bypass sequence queue
|
||||
if packet.header.kind == FrameKind::Nack
|
||||
&& packet.payload.len() >= 8 {
|
||||
let req_nonce = u64::from_be_bytes(packet.payload[..8].try_into().unwrap());
|
||||
let req_nonce = u64::from_be_bytes(packet.payload[..8].try_into().map_err(|_| ProtocolError::Framing("nack payload too short".into()))?);
|
||||
if let Some(cached_frame) = self.lookup_sent_frame(req_nonce) {
|
||||
tracing::debug!("NACK received: retransmitting nonce={}", req_nonce);
|
||||
self.cc.on_loss(cached_frame.len() as u64);
|
||||
|
|
@ -733,8 +733,8 @@ fn parse_ack_ranges(payload: &[u8]) -> Result<Vec<(u64, u64)>, ProtocolError> {
|
|||
let mut ranges = Vec::with_capacity(count);
|
||||
let mut idx = 1;
|
||||
for _ in 0..count {
|
||||
let start = u64::from_be_bytes(payload[idx..idx + 8].try_into().unwrap());
|
||||
let end = u64::from_be_bytes(payload[idx + 8..idx + 16].try_into().unwrap());
|
||||
let start = u64::from_be_bytes(payload[idx..idx + 8].try_into().map_err(|_| ProtocolError::Framing("ack range start invalid".into()))?);
|
||||
let end = u64::from_be_bytes(payload[idx + 8..idx + 16].try_into().map_err(|_| ProtocolError::Framing("ack range end invalid".into()))?);
|
||||
ranges.push((start, end));
|
||||
idx += 16;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ impl RelayMessage {
|
|||
7 => {
|
||||
let payload = decode_with_len(&input[1..])?;
|
||||
if payload.len() != 8 { return Err(anyhow!("invalid ping payload len")); }
|
||||
let ts = u64::from_be_bytes(payload.try_into().unwrap());
|
||||
let ts = u64::from_be_bytes(payload.try_into().map_err(|_| anyhow!("invalid ping payload size"))?);
|
||||
Ok(RelayMessage::Ping(ts))
|
||||
}
|
||||
8 => {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
description: This file stores settings for Dart & Flutter DevTools.
|
||||
documentation: https://docs.flutter.dev/tools/devtools/extensions#configure-extension-enablement-states
|
||||
extensions:
|
||||
|
|
@ -96,6 +96,14 @@ packages:
|
|||
description: flutter
|
||||
source: sdk
|
||||
version: "0.0.0"
|
||||
json_annotation:
|
||||
dependency: transitive
|
||||
description:
|
||||
name: json_annotation
|
||||
sha256: "2a743920d81b7910627f68ee2c9ac1fc0bfee32b9fc3403587d7c6791ca12f80"
|
||||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "4.12.0"
|
||||
leak_tracker:
|
||||
dependency: transitive
|
||||
description:
|
||||
|
|
@ -144,6 +152,14 @@ packages:
|
|||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "0.13.0"
|
||||
menu_base:
|
||||
dependency: transitive
|
||||
description:
|
||||
name: menu_base
|
||||
sha256: "820368014a171bd1241030278e6c2617354f492f5c703d7b7d4570a6b8b84405"
|
||||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "0.1.1"
|
||||
meta:
|
||||
dependency: transitive
|
||||
description:
|
||||
|
|
@ -208,6 +224,46 @@ packages:
|
|||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "2.1.8"
|
||||
screen_retriever:
|
||||
dependency: transitive
|
||||
description:
|
||||
name: screen_retriever
|
||||
sha256: "570dbc8e4f70bac451e0efc9c9bb19fa2d6799a11e6ef04f946d7886d2e23d0c"
|
||||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "0.2.0"
|
||||
screen_retriever_linux:
|
||||
dependency: transitive
|
||||
description:
|
||||
name: screen_retriever_linux
|
||||
sha256: f7f8120c92ef0784e58491ab664d01efda79a922b025ff286e29aa123ea3dd18
|
||||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "0.2.0"
|
||||
screen_retriever_macos:
|
||||
dependency: transitive
|
||||
description:
|
||||
name: screen_retriever_macos
|
||||
sha256: "71f956e65c97315dd661d71f828708bd97b6d358e776f1a30d5aa7d22d78a149"
|
||||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "0.2.0"
|
||||
screen_retriever_platform_interface:
|
||||
dependency: transitive
|
||||
description:
|
||||
name: screen_retriever_platform_interface
|
||||
sha256: ee197f4581ff0d5608587819af40490748e1e39e648d7680ecf95c05197240c0
|
||||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "0.2.0"
|
||||
screen_retriever_windows:
|
||||
dependency: transitive
|
||||
description:
|
||||
name: screen_retriever_windows
|
||||
sha256: "449ee257f03ca98a57288ee526a301a430a344a161f9202b4fcc38576716fe13"
|
||||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "0.2.0"
|
||||
shared_preferences:
|
||||
dependency: "direct main"
|
||||
description:
|
||||
|
|
@ -264,6 +320,14 @@ packages:
|
|||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "2.4.1"
|
||||
shortid:
|
||||
dependency: transitive
|
||||
description:
|
||||
name: shortid
|
||||
sha256: d0b40e3dbb50497dad107e19c54ca7de0d1a274eb9b4404991e443dadb9ebedb
|
||||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "0.1.2"
|
||||
sky_engine:
|
||||
dependency: transitive
|
||||
description: flutter
|
||||
|
|
@ -317,6 +381,14 @@ packages:
|
|||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "0.7.10"
|
||||
tray_manager:
|
||||
dependency: "direct main"
|
||||
description:
|
||||
name: tray_manager
|
||||
sha256: c5fd83b0ae4d80be6eaedfad87aaefab8787b333b8ebd064b0e442a81006035b
|
||||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "0.5.2"
|
||||
vector_math:
|
||||
dependency: transitive
|
||||
description:
|
||||
|
|
@ -341,6 +413,14 @@ packages:
|
|||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "1.1.1"
|
||||
window_manager:
|
||||
dependency: "direct main"
|
||||
description:
|
||||
name: window_manager
|
||||
sha256: "7eb6d6c4164ec08e1bf978d6e733f3cebe792e2a23fb07cbca25c2872bfdbdcd"
|
||||
url: "https://pub.dev"
|
||||
source: hosted
|
||||
version: "0.5.1"
|
||||
xdg_directories:
|
||||
dependency: transitive
|
||||
description:
|
||||
|
|
|
|||
|
|
@ -36,6 +36,8 @@ dependencies:
|
|||
cupertino_icons: ^1.0.8
|
||||
shared_preferences: ^2.5.5
|
||||
mobile_scanner: ^5.0.0
|
||||
window_manager: ^0.5.1
|
||||
tray_manager: ^0.5.2
|
||||
|
||||
dev_dependencies:
|
||||
flutter_test:
|
||||
|
|
|
|||
|
|
@ -2641,7 +2641,7 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "ostp-client"
|
||||
version = "0.2.73"
|
||||
version = "0.2.79"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"base64 0.22.1",
|
||||
|
|
@ -2666,12 +2666,13 @@ dependencies = [
|
|||
"tracing",
|
||||
"tun",
|
||||
"webpki-roots 0.26.11",
|
||||
"winapi",
|
||||
"x25519-dalek",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ostp-core"
|
||||
version = "0.2.73"
|
||||
version = "0.2.79"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bytes",
|
||||
|
|
|
|||
|
|
@ -106,6 +106,7 @@ struct HelperState {
|
|||
pipe_state: Arc<Mutex<HelperPipeState>>,
|
||||
cmd_tx: tokio::sync::mpsc::Sender<String>,
|
||||
token: String,
|
||||
port: u16,
|
||||
}
|
||||
|
||||
enum TunnelHandle {
|
||||
|
|
@ -282,6 +283,37 @@ async fn get_metrics(state: tauri::State<'_, AppState>) -> Result<Option<UIMetri
|
|||
}
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
async fn reload_tunnel(state: tauri::State<'_, AppState>) -> Result<bool, String> {
|
||||
let mut guard = state.0.lock().await;
|
||||
if guard.tunnel.is_none() {
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let config_path = get_config_path();
|
||||
let config_str = std::fs::read_to_string(&config_path)
|
||||
.map_err(|e| format!("Read config error: {}", e))?;
|
||||
|
||||
match &guard.tunnel {
|
||||
Some(TunnelHandle::Helper(h)) => {
|
||||
let cmd = format!(
|
||||
"{{\"cmd\":\"reload\",\"config\":{},\"token\":\"{}\"}}\n",
|
||||
serde_json::to_string(&config_str).unwrap(),
|
||||
h.token
|
||||
);
|
||||
let _ = h.cmd_tx.send(cmd).await;
|
||||
}
|
||||
Some(TunnelHandle::InProcess(s)) => {
|
||||
// Restarting in-process tunnel is not supported without re-calling start_tunnel,
|
||||
// but we can just abort and we should really call start_tunnel again.
|
||||
// For now, return false.
|
||||
return Ok(false);
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
async fn stop_tunnel(state: tauri::State<'_, AppState>) -> Result<bool, String> {
|
||||
let mut guard = state.0.lock().await;
|
||||
|
|
@ -375,24 +407,19 @@ async fn start_tun_via_helper(
|
|||
guard: &mut AppStateInner,
|
||||
raw: &ClientConfigRaw,
|
||||
) -> Result<bool, String> {
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
// Kill any existing helper processes to prevent os error 10048 (port already in use)
|
||||
use std::os::windows::process::CommandExt;
|
||||
let _ = std::process::Command::new("taskkill")
|
||||
.args(["/F", "/IM", "ostp-tun-helper.exe"])
|
||||
.creation_flags(0x08000000)
|
||||
.output();
|
||||
}
|
||||
let port = {
|
||||
let listener = std::net::TcpListener::bind("127.0.0.1:0").map_err(|e| format!("Bind error: {}", e))?;
|
||||
listener.local_addr().unwrap().port()
|
||||
};
|
||||
|
||||
let auth_token = rand::random::<u64>().to_string();
|
||||
let helper_exe = find_helper_exe().ok_or_else(|| "ostp-tun-helper.exe not found.".to_string())?;
|
||||
launch_as_admin(&helper_exe, &auth_token).map_err(|e| format!("Failed to launch helper: {}", e))?;
|
||||
launch_as_admin(&helper_exe, &auth_token, port).map_err(|e| format!("Failed to launch helper: {}", e))?;
|
||||
tokio::time::sleep(std::time::Duration::from_millis(1500)).await;
|
||||
|
||||
let socket = tokio::time::timeout(std::time::Duration::from_secs(60), async {
|
||||
loop {
|
||||
match tokio::net::TcpStream::connect("127.0.0.1:53211").await {
|
||||
match tokio::net::TcpStream::connect(format!("127.0.0.1:{}", port)).await {
|
||||
Ok(s) => return Ok::<_, std::io::Error>(s),
|
||||
Err(_) => tokio::time::sleep(std::time::Duration::from_millis(200)).await,
|
||||
}
|
||||
|
|
@ -443,7 +470,7 @@ async fn start_tun_via_helper(
|
|||
state_for_task.lock().await.connection_state = 0;
|
||||
});
|
||||
|
||||
guard.tunnel = Some(TunnelHandle::Helper(HelperState { pipe_state, cmd_tx, token: auth_token }));
|
||||
guard.tunnel = Some(TunnelHandle::Helper(HelperState { pipe_state, cmd_tx, token: auth_token, port }));
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
|
|
@ -493,14 +520,14 @@ fn find_helper_exe() -> Option<PathBuf> {
|
|||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
fn launch_as_admin(exe: &std::path::PathBuf, token: &str) -> anyhow::Result<()> {
|
||||
fn launch_as_admin(exe: &std::path::PathBuf, token: &str, port: u16) -> anyhow::Result<()> {
|
||||
use std::ffi::OsStr;
|
||||
use std::os::windows::ffi::OsStrExt;
|
||||
use std::ptr::null_mut;
|
||||
|
||||
let exe_wstr: Vec<u16> = exe.as_os_str().encode_wide().chain(Some(0)).collect();
|
||||
let verb_wstr: Vec<u16> = OsStr::new("runas").encode_wide().chain(Some(0)).collect();
|
||||
let params_str = format!("--token {}", token);
|
||||
let params_str = format!("--token {} --port {}", token, port);
|
||||
let params_wstr: Vec<u16> = OsStr::new(¶ms_str).encode_wide().chain(Some(0)).collect();
|
||||
#[link(name = "shell32")] extern "system" { fn ShellExecuteW(h: *mut std::ffi::c_void, op: *const u16, f: *const u16, p: *const u16, d: *const u16, s: i32) -> isize; }
|
||||
|
||||
|
|
@ -514,7 +541,7 @@ fn launch_as_admin(exe: &std::path::PathBuf, token: &str) -> anyhow::Result<()>
|
|||
}
|
||||
|
||||
#[cfg(not(target_os = "windows"))]
|
||||
fn launch_as_admin(_exe: &PathBuf, _token: &str) -> Result<()> { anyhow::bail!("Windows only."); }
|
||||
fn launch_as_admin(_exe: &PathBuf, _token: &str, _port: u16) -> Result<()> { anyhow::bail!("Windows only."); }
|
||||
|
||||
#[cfg_attr(mobile, tauri::mobile_entry_point)]
|
||||
pub fn run() {
|
||||
|
|
@ -607,7 +634,7 @@ pub fn run() {
|
|||
}
|
||||
_ => {}
|
||||
})
|
||||
.invoke_handler(tauri::generate_handler![start_tunnel, stop_tunnel, get_tunnel_status, get_metrics, get_config, save_config])
|
||||
.invoke_handler(tauri::generate_handler![start_tunnel, stop_tunnel, reload_tunnel, get_tunnel_status, get_metrics, get_config, save_config])
|
||||
.run(tauri::generate_context!())
|
||||
.expect("error while running tauri application");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -319,6 +319,8 @@ async function handleSave(silent = false) {
|
|||
|
||||
rawConfig.tun = rawConfig.tun || {};
|
||||
rawConfig.tun.enable = inTun.checked;
|
||||
rawConfig.tun.wintun_path = rawConfig.tun.wintun_path || './wintun.dll';
|
||||
rawConfig.tun.ipv4_address = rawConfig.tun.ipv4_address || '10.1.0.2/24';
|
||||
rawConfig.tun.stack = 'ostp';
|
||||
// owndns: if toggle is on, always write 10.1.0.1; otherwise use the custom field
|
||||
rawConfig.tun.dns = inOwndns.checked ? '10.1.0.1' : (inDns.value.trim() || null);
|
||||
|
|
@ -477,7 +479,14 @@ window.addEventListener('DOMContentLoaded', async () => {
|
|||
// Auto-save wiring
|
||||
const formInputs = document.querySelectorAll('#settings-screen input:not(#in-import-url), #settings-screen textarea, #settings-screen select');
|
||||
formInputs.forEach(el => {
|
||||
el.addEventListener('input', scheduleAutoSave);
|
||||
el.addEventListener('input', () => {
|
||||
scheduleAutoSave();
|
||||
if (appState === 'connected') {
|
||||
if (window.__TAURI__ && window.__TAURI__.invoke) {
|
||||
window.__TAURI__.invoke('reload_tunnel');
|
||||
}
|
||||
}
|
||||
});
|
||||
el.addEventListener('change', scheduleAutoSave);
|
||||
});
|
||||
|
||||
|
|
|
|||
|
|
@ -32,12 +32,14 @@ class OstpClientSdk private constructor(private val context: Context) {
|
|||
|
||||
// ── Native JNI bindings ───────────────────────────────────────────────────
|
||||
|
||||
private external fun nativeStartClient(configJson: String): Boolean
|
||||
private external fun nativeStopClient(): Boolean
|
||||
private external fun nativeGetMetrics(): String
|
||||
private external fun nativeGetLogs(): String
|
||||
private external fun startClient(configJson: String, fd: Int, t2sBinPath: String, localProxy: String): Boolean
|
||||
private external fun stopClient(): Boolean
|
||||
private external fun getMetrics(): String
|
||||
private external fun getLogs(): String
|
||||
private external fun addLog(logMsg: String)
|
||||
private external fun notifyNetworkChanged()
|
||||
|
||||
|
||||
// ── Public data models ────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
|
|
@ -175,7 +177,8 @@ class OstpClientSdk private constructor(private val context: Context) {
|
|||
_state.value = TunnelState.Connecting
|
||||
|
||||
val json = config.toNativeJson()
|
||||
val ok = nativeStartClient(json)
|
||||
// Default values for fd, t2sBinPath, localProxy for proxy mode
|
||||
val ok = startClient(json, -1, "", config.proxyBind)
|
||||
if (!ok) {
|
||||
_state.value = TunnelState.Failed("Native layer rejected config")
|
||||
started.set(false)
|
||||
|
|
@ -197,7 +200,7 @@ class OstpClientSdk private constructor(private val context: Context) {
|
|||
|
||||
pollingJob?.cancel()
|
||||
networkCallbackJob?.cancel()
|
||||
nativeStopClient()
|
||||
stopClient()
|
||||
unregisterNetworkCallback()
|
||||
_state.value = TunnelState.Idle
|
||||
emitLog("OSTP SDK stopped")
|
||||
|
|
@ -209,7 +212,7 @@ class OstpClientSdk private constructor(private val context: Context) {
|
|||
*/
|
||||
fun drainLogs(): List<String> {
|
||||
return try {
|
||||
val array = JSONArray(nativeGetLogs())
|
||||
val array = JSONArray(getLogs())
|
||||
(0 until array.length()).map { array.getString(it) }
|
||||
} catch (_: Exception) {
|
||||
emptyList()
|
||||
|
|
@ -218,7 +221,7 @@ class OstpClientSdk private constructor(private val context: Context) {
|
|||
|
||||
/** Read the latest [Metrics] snapshot. Returns zeroed metrics if tunnel is idle. */
|
||||
fun getMetrics(): Metrics {
|
||||
return parseMetrics(nativeGetMetrics())
|
||||
return parseMetrics(getMetrics())
|
||||
}
|
||||
|
||||
// ── Internal helpers ──────────────────────────────────────────────────────
|
||||
|
|
@ -247,7 +250,7 @@ class OstpClientSdk private constructor(private val context: Context) {
|
|||
}
|
||||
|
||||
// Update state based on metrics availability
|
||||
val metrics = parseMetrics(nativeGetMetrics())
|
||||
val metrics = parseMetrics(getMetrics())
|
||||
if (wasConnected) {
|
||||
_state.value = TunnelState.Connected(metrics)
|
||||
}
|
||||
|
|
@ -320,6 +323,33 @@ class OstpClientSdk private constructor(private val context: Context) {
|
|||
@Volatile
|
||||
private var instance: OstpClientSdk? = null
|
||||
|
||||
@JvmStatic
|
||||
fun protectSocket(fd: Int): Boolean {
|
||||
var retries = 5
|
||||
while (retries > 0) {
|
||||
// We use reflection or explicit class to get the VpnService instance
|
||||
try {
|
||||
val serviceClass = Class.forName("com.ospab.ostp_client.OstpVpnService")
|
||||
val instanceField = serviceClass.getDeclaredField("instance")
|
||||
instanceField.isAccessible = true
|
||||
val service = instanceField.get(null)
|
||||
if (service != null) {
|
||||
val protectMethod = serviceClass.getMethod("protect", Int::class.javaPrimitiveType)
|
||||
val res = protectMethod.invoke(service, fd) as Boolean
|
||||
android.util.Log.i("OstpClientSdk", "VpnService.protect(socketFd=$fd) -> success=$res")
|
||||
return res
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
android.util.Log.w("OstpClientSdk", "Error accessing VpnService via reflection: \${e.message}")
|
||||
}
|
||||
android.util.Log.w("OstpClientSdk", "VpnService instance not available! Retrying... (\$retries left)")
|
||||
Thread.sleep(200)
|
||||
retries--
|
||||
}
|
||||
android.util.Log.e("OstpClientSdk", "VpnService instance is null! Cannot protect socketFd=\$fd")
|
||||
return false
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the singleton SDK instance.
|
||||
* Must be called with an Application context to avoid memory leaks.
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ use jni::sys::{jboolean, jstring};
|
|||
use jni::JNIEnv;
|
||||
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::{atomic::Ordering, Arc, Mutex};
|
||||
use std::sync::{atomic::Ordering, Arc, Mutex, RwLock};
|
||||
use tokio::runtime::Runtime;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use ostp_client::bridge::{Bridge, BridgeMetrics};
|
||||
|
|
@ -80,13 +80,13 @@ impl SdkState {
|
|||
}
|
||||
}
|
||||
|
||||
static STATE: Mutex<SdkState> = Mutex::new(SdkState::new());
|
||||
static LOGS: Mutex<VecDeque<String>> = Mutex::new(VecDeque::new());
|
||||
static JVM: Mutex<Option<jni::JavaVM>> = Mutex::new(None);
|
||||
static CLASS_REF: Mutex<Option<jni::objects::GlobalRef>> = Mutex::new(None);
|
||||
static STATE: RwLock<SdkState> = RwLock::new(SdkState::new());
|
||||
static LOGS: RwLock<VecDeque<String>> = RwLock::new(VecDeque::new());
|
||||
static JVM: RwLock<Option<jni::JavaVM>> = RwLock::new(None);
|
||||
static CLASS_REF: RwLock<Option<jni::objects::GlobalRef>> = RwLock::new(None);
|
||||
|
||||
fn add_log(text: String) {
|
||||
if let Ok(mut guard) = LOGS.lock() {
|
||||
if let Ok(mut guard) = LOGS.write() {
|
||||
if guard.len() >= 1000 {
|
||||
guard.pop_front();
|
||||
}
|
||||
|
|
@ -95,7 +95,7 @@ fn add_log(text: String) {
|
|||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "system" fn Java_net_ostp_client_OstpClientSdk_startClient(
|
||||
pub extern "system" fn Java_net_ostp_client_OstpClientSdk_nativeStartClient(
|
||||
mut env: JNIEnv,
|
||||
_class: JClass,
|
||||
config_json: JString,
|
||||
|
|
@ -103,7 +103,7 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_startClient(
|
|||
t2s_bin_path: JString,
|
||||
local_proxy: JString,
|
||||
) -> jboolean {
|
||||
let mut state = match STATE.lock() {
|
||||
let mut state = match STATE.write() {
|
||||
Ok(s) => s,
|
||||
Err(_) => return jni::sys::JNI_FALSE,
|
||||
};
|
||||
|
|
@ -116,25 +116,25 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_startClient(
|
|||
init_tracing();
|
||||
|
||||
if let Ok(jvm) = env.get_java_vm() {
|
||||
if let Ok(mut guard) = JVM.lock() {
|
||||
if let Ok(mut guard) = JVM.write() {
|
||||
*guard = Some(jvm);
|
||||
}
|
||||
}
|
||||
|
||||
if let Ok(cls) = env.find_class("net/ostp/client/OstpClientSdk") {
|
||||
if let Ok(global_cls) = env.new_global_ref(cls) {
|
||||
if let Ok(mut guard) = CLASS_REF.lock() {
|
||||
if let Ok(mut guard) = CLASS_REF.write() {
|
||||
*guard = Some(global_cls);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ostp_client::bridge::set_socket_protector(|fd| {
|
||||
let jvm_guard = match JVM.lock() {
|
||||
let jvm_guard = match JVM.read() {
|
||||
Ok(g) => g,
|
||||
Err(_) => return false,
|
||||
};
|
||||
let class_guard = match CLASS_REF.lock() {
|
||||
let class_guard = match CLASS_REF.read() {
|
||||
Ok(g) => g,
|
||||
Err(_) => return false,
|
||||
};
|
||||
|
|
@ -346,12 +346,24 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_startClient(
|
|||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "system" fn Java_net_ostp_client_OstpClientSdk_stopClient(
|
||||
pub extern "system" fn Java_net_ostp_client_OstpClientSdk_startClient(
|
||||
env: JNIEnv,
|
||||
class: JClass,
|
||||
config_json: JString,
|
||||
fd: jni::sys::jint,
|
||||
t2s_bin_path: JString,
|
||||
local_proxy: JString,
|
||||
) -> jboolean {
|
||||
Java_net_ostp_client_OstpClientSdk_nativeStartClient(env, class, config_json, fd, t2s_bin_path, local_proxy)
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "system" fn Java_net_ostp_client_OstpClientSdk_nativeStopClient(
|
||||
_env: JNIEnv,
|
||||
_class: JClass,
|
||||
) -> jboolean {
|
||||
let (tun_child, shutdown_tx, runtime) = {
|
||||
let mut state = match STATE.lock() {
|
||||
let mut state = match STATE.write() {
|
||||
Ok(s) => s,
|
||||
Err(_) => return jni::sys::JNI_FALSE,
|
||||
};
|
||||
|
|
@ -381,11 +393,19 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_stopClient(
|
|||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "system" fn Java_net_ostp_client_OstpClientSdk_getMetrics(
|
||||
pub extern "system" fn Java_net_ostp_client_OstpClientSdk_stopClient(
|
||||
env: JNIEnv,
|
||||
class: JClass,
|
||||
) -> jboolean {
|
||||
Java_net_ostp_client_OstpClientSdk_nativeStopClient(env, class)
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "system" fn Java_net_ostp_client_OstpClientSdk_nativeGetMetrics(
|
||||
env: JNIEnv,
|
||||
_class: JClass,
|
||||
) -> jstring {
|
||||
let state = match STATE.lock() {
|
||||
let state = match STATE.read() {
|
||||
Ok(s) => s,
|
||||
Err(_) => return match env.new_string("{}") {
|
||||
Ok(s) => s.into_raw(),
|
||||
|
|
@ -415,11 +435,19 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_getMetrics(
|
|||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "system" fn Java_net_ostp_client_OstpClientSdk_getLogs(
|
||||
pub extern "system" fn Java_net_ostp_client_OstpClientSdk_getMetrics(
|
||||
env: JNIEnv,
|
||||
class: JClass,
|
||||
) -> jstring {
|
||||
Java_net_ostp_client_OstpClientSdk_nativeGetMetrics(env, class)
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "system" fn Java_net_ostp_client_OstpClientSdk_nativeGetLogs(
|
||||
env: JNIEnv,
|
||||
_class: JClass,
|
||||
) -> jstring {
|
||||
let logs_vec: Vec<String> = match LOGS.lock() {
|
||||
let logs_vec: Vec<String> = match LOGS.write() {
|
||||
Ok(mut guard) => guard.drain(..).collect(),
|
||||
Err(_) => Vec::new(),
|
||||
};
|
||||
|
|
@ -435,6 +463,14 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_getLogs(
|
|||
}
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "system" fn Java_net_ostp_client_OstpClientSdk_getLogs(
|
||||
env: JNIEnv,
|
||||
class: JClass,
|
||||
) -> jstring {
|
||||
Java_net_ostp_client_OstpClientSdk_nativeGetLogs(env, class)
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "system" fn Java_net_ostp_client_OstpClientSdk_addLog(
|
||||
mut env: JNIEnv,
|
||||
|
|
@ -454,7 +490,7 @@ pub extern "system" fn Java_net_ostp_client_OstpClientSdk_notifyNetworkChanged(
|
|||
_env: JNIEnv,
|
||||
_class: JClass,
|
||||
) {
|
||||
let state = match STATE.lock() {
|
||||
let state = match STATE.read() {
|
||||
Ok(s) => s,
|
||||
Err(_) => return,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -310,7 +310,7 @@ fn check_token(state: &ApiState, headers: &axum::http::HeaderMap) -> bool {
|
|||
if let Some(value) = headers.get("authorization") {
|
||||
if let Ok(val) = value.to_str() {
|
||||
if let Some(token) = val.strip_prefix("Bearer ") {
|
||||
let current_session = state.session_token.read().unwrap().clone();
|
||||
let current_session = state.session_token.read().unwrap_or_else(|e| e.into_inner()).clone();
|
||||
if let Some(session) = current_session {
|
||||
if token == session {
|
||||
allowed = true;
|
||||
|
|
@ -353,7 +353,7 @@ async fn handle_login(
|
|||
|
||||
if hash_hex == state.password_hash {
|
||||
let token = uuid::Uuid::new_v4().to_string();
|
||||
*state.session_token.write().unwrap() = Some(token.clone());
|
||||
*state.session_token.write().unwrap_or_else(|e| e.into_inner()) = Some(token.clone());
|
||||
(StatusCode::OK, ApiResponse::success(LoginResponse { token }))
|
||||
} else {
|
||||
api_unauthorized::<LoginResponse>()
|
||||
|
|
@ -377,7 +377,7 @@ fn save_config_keys(state: &ApiState) -> Result<(), String> {
|
|||
let mut json_val: serde_json::Value = serde_json::from_str(&content_str)
|
||||
.map_err(|e| format!("failed to parse config JSON: {}", e))?;
|
||||
|
||||
let keys = state.access_keys.read().unwrap();
|
||||
let keys = state.access_keys.read().unwrap_or_else(|e| e.into_inner());
|
||||
let mut access_keys_json = Vec::new();
|
||||
for (k, m) in keys.iter() {
|
||||
if m.name.is_none() && m.limit_bytes.is_none() {
|
||||
|
|
@ -511,8 +511,8 @@ async fn handle_status(
|
|||
return api_unauthorized::<ServerStatus>();
|
||||
}
|
||||
|
||||
let keys = state.access_keys.read().unwrap();
|
||||
let stats = state.user_stats.read().unwrap();
|
||||
let keys = state.access_keys.read().unwrap_or_else(|e| e.into_inner());
|
||||
let stats = state.user_stats.read().unwrap_or_else(|e| e.into_inner());
|
||||
let online = stats.values()
|
||||
.filter(|us| {
|
||||
let total = us.bytes_up.load(Ordering::Relaxed) + us.bytes_down.load(Ordering::Relaxed);
|
||||
|
|
@ -538,8 +538,8 @@ async fn handle_list_users(
|
|||
return api_unauthorized::<Vec<UserStatsSnapshot>>();
|
||||
}
|
||||
|
||||
let keys = state.access_keys.read().unwrap();
|
||||
let stats = state.user_stats.read().unwrap();
|
||||
let keys = state.access_keys.read().unwrap_or_else(|e| e.into_inner());
|
||||
let stats = state.user_stats.read().unwrap_or_else(|e| e.into_inner());
|
||||
|
||||
let mut users: Vec<UserStatsSnapshot> = keys.iter().map(|(key, meta)| {
|
||||
if let Some(us) = stats.get(key) {
|
||||
|
|
@ -579,13 +579,13 @@ async fn handle_get_user(
|
|||
return api_unauthorized::<UserStatsSnapshot>();
|
||||
}
|
||||
|
||||
let keys = state.access_keys.read().unwrap();
|
||||
let keys = state.access_keys.read().unwrap_or_else(|e| e.into_inner());
|
||||
let meta = match keys.get(&key) {
|
||||
Some(m) => m.clone(),
|
||||
None => return api_error("user not found"),
|
||||
};
|
||||
|
||||
let stats = state.user_stats.read().unwrap();
|
||||
let stats = state.user_stats.read().unwrap_or_else(|e| e.into_inner());
|
||||
let snapshot = if let Some(us) = stats.get(&key) {
|
||||
UserStatsSnapshot {
|
||||
access_key: key.clone(),
|
||||
|
|
@ -628,11 +628,11 @@ async fn handle_create_user(
|
|||
});
|
||||
|
||||
{
|
||||
let mut keys = state.access_keys.write().unwrap();
|
||||
let mut keys = state.access_keys.write().unwrap_or_else(|e| e.into_inner());
|
||||
keys.insert(key.clone(), UserMeta { name: body.name.clone(), limit_bytes: body.limit_bytes });
|
||||
}
|
||||
|
||||
let mut stats = state.user_stats.write().unwrap();
|
||||
let mut stats = state.user_stats.write().unwrap_or_else(|e| e.into_inner());
|
||||
stats.insert(key.clone(), Arc::new(UserStats::new(body.limit_bytes)));
|
||||
drop(stats);
|
||||
|
||||
|
|
@ -655,14 +655,14 @@ async fn delete_user(
|
|||
}
|
||||
|
||||
{
|
||||
let mut keys = state.access_keys.write().unwrap();
|
||||
let mut keys = state.access_keys.write().unwrap_or_else(|e| e.into_inner());
|
||||
if keys.remove(&key).is_none() {
|
||||
return api_error::<String>("User not found");
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
let mut stats = state.user_stats.write().unwrap();
|
||||
let mut stats = state.user_stats.write().unwrap_or_else(|e| e.into_inner());
|
||||
stats.remove(&key);
|
||||
}
|
||||
|
||||
|
|
@ -685,7 +685,7 @@ async fn update_user(
|
|||
}
|
||||
|
||||
{
|
||||
let mut keys = state.access_keys.write().unwrap();
|
||||
let mut keys = state.access_keys.write().unwrap_or_else(|e| e.into_inner());
|
||||
if let Some(meta) = keys.get_mut(&key) {
|
||||
meta.name = body.name.clone();
|
||||
meta.limit_bytes = body.limit_bytes;
|
||||
|
|
@ -695,7 +695,7 @@ async fn update_user(
|
|||
}
|
||||
|
||||
{
|
||||
let mut stats = state.user_stats.write().unwrap();
|
||||
let mut stats = state.user_stats.write().unwrap_or_else(|e| e.into_inner());
|
||||
let entry = stats.entry(key.clone())
|
||||
.or_insert_with(|| Arc::new(UserStats::new(body.limit_bytes)));
|
||||
|
||||
|
|
@ -727,7 +727,7 @@ async fn handle_set_limit(
|
|||
}
|
||||
|
||||
{
|
||||
let mut keys = state.access_keys.write().unwrap();
|
||||
let mut keys = state.access_keys.write().unwrap_or_else(|e| e.into_inner());
|
||||
if let Some(meta) = keys.get_mut(&key) {
|
||||
meta.limit_bytes = body.limit_bytes;
|
||||
} else {
|
||||
|
|
@ -735,7 +735,7 @@ async fn handle_set_limit(
|
|||
}
|
||||
}
|
||||
|
||||
let mut stats = state.user_stats.write().unwrap();
|
||||
let mut stats = state.user_stats.write().unwrap_or_else(|e| e.into_inner());
|
||||
let entry = stats.entry(key.clone())
|
||||
.or_insert_with(|| Arc::new(UserStats::new(body.limit_bytes)));
|
||||
|
||||
|
|
@ -765,7 +765,7 @@ async fn handle_reset_stats(
|
|||
return api_unauthorized::<bool>();
|
||||
}
|
||||
|
||||
let mut stats = state.user_stats.write().unwrap();
|
||||
let mut stats = state.user_stats.write().unwrap_or_else(|e| e.into_inner());
|
||||
if let Some(us) = stats.get(&key) {
|
||||
let limit = us.limit_bytes;
|
||||
stats.insert(key.clone(), Arc::new(UserStats::new(limit)));
|
||||
|
|
@ -793,7 +793,7 @@ async fn handle_subscribe(
|
|||
|
||||
// Validate that the key exists in a tightly scoped block to drop the guard
|
||||
let key_exists = {
|
||||
let keys = state.access_keys.read().unwrap();
|
||||
let keys = state.access_keys.read().unwrap_or_else(|e| e.into_inner());
|
||||
keys.contains_key(&key)
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -86,7 +86,7 @@ pub struct Dispatcher {
|
|||
impl Dispatcher {
|
||||
pub fn new(machine_config: ProtocolConfig, access_keys: Arc<RwLock<HashMap<String, crate::api::UserMeta>>>) -> Self {
|
||||
let mut initial_stats = HashMap::new();
|
||||
for (key, meta) in access_keys.read().unwrap().iter() {
|
||||
for (key, meta) in access_keys.read().unwrap_or_else(|e| e.into_inner()).iter() {
|
||||
initial_stats.insert(key.clone(), Arc::new(UserStats::new(meta.limit_bytes)));
|
||||
}
|
||||
Self {
|
||||
|
|
@ -108,7 +108,7 @@ impl Dispatcher {
|
|||
|
||||
/// Snapshot all user stats for API responses.
|
||||
pub fn snapshot_all_users(&self) -> Vec<UserStatsSnapshot> {
|
||||
let stats = self.user_stats.read().unwrap();
|
||||
let stats = self.user_stats.read().unwrap_or_else(|e| e.into_inner());
|
||||
let online_keys: std::collections::HashSet<String> = self.peer_machines.values()
|
||||
.map(|ps| ps.access_key.clone())
|
||||
.collect();
|
||||
|
|
@ -125,15 +125,15 @@ impl Dispatcher {
|
|||
|
||||
/// Get or create stats entry for a user key.
|
||||
fn get_or_create_user_stats(&self, key: &str) -> Arc<UserStats> {
|
||||
let stats = self.user_stats.read().unwrap();
|
||||
let stats = self.user_stats.read().unwrap_or_else(|e| e.into_inner());
|
||||
if let Some(existing) = stats.get(key) {
|
||||
return existing.clone();
|
||||
}
|
||||
drop(stats);
|
||||
|
||||
let limit_bytes = self.access_keys.read().unwrap().get(key).and_then(|m| m.limit_bytes);
|
||||
let limit_bytes = self.access_keys.read().unwrap_or_else(|e| e.into_inner()).get(key).and_then(|m| m.limit_bytes);
|
||||
|
||||
let mut stats = self.user_stats.write().unwrap();
|
||||
let mut stats = self.user_stats.write().unwrap_or_else(|e| e.into_inner());
|
||||
stats.entry(key.to_string())
|
||||
.or_insert_with(|| Arc::new(UserStats::new(limit_bytes)))
|
||||
.clone()
|
||||
|
|
@ -141,7 +141,7 @@ impl Dispatcher {
|
|||
|
||||
/// Set traffic limit for a user.
|
||||
pub fn set_user_limit(&self, key: &str, limit: Option<u64>) {
|
||||
let mut stats = self.user_stats.write().unwrap();
|
||||
let mut stats = self.user_stats.write().unwrap_or_else(|e| e.into_inner());
|
||||
let entry = stats.entry(key.to_string())
|
||||
.or_insert_with(|| Arc::new(UserStats::new(limit)));
|
||||
// Replace the entry with new limit (stats reset)
|
||||
|
|
@ -212,7 +212,7 @@ impl Dispatcher {
|
|||
let key_opt = self.peer_machines.get(&session_id).map(|ps| ps.access_key.clone());
|
||||
if let Some(access_key) = key_opt {
|
||||
// Check if key is still valid and not over limit
|
||||
let key_valid = self.access_keys.read().unwrap().contains_key(&access_key);
|
||||
let key_valid = self.access_keys.read().unwrap_or_else(|e| e.into_inner()).contains_key(&access_key);
|
||||
let user_stats = self.get_or_create_user_stats(&access_key);
|
||||
if !key_valid || user_stats.is_over_limit() {
|
||||
tracing::info!("Dropping session {} for key {} (valid={}, over_limit={})",
|
||||
|
|
@ -280,7 +280,7 @@ impl Dispatcher {
|
|||
}
|
||||
|
||||
// Not an existing session — try each registered access key's derived obfuscation key
|
||||
let keys_snapshot: Vec<String> = self.access_keys.read().unwrap().keys().cloned().collect();
|
||||
let keys_snapshot: Vec<String> = self.access_keys.read().unwrap_or_else(|e| e.into_inner()).keys().cloned().collect();
|
||||
|
||||
for candidate_key in keys_snapshot {
|
||||
let secrets = ostp_core::crypto::derive_all_secrets(candidate_key.as_bytes());
|
||||
|
|
@ -430,7 +430,7 @@ impl Dispatcher {
|
|||
|
||||
// Gather expired or invalid sessions
|
||||
for (&sid, peer_state) in &self.peer_machines {
|
||||
let key_valid = self.access_keys.read().unwrap().contains_key(&peer_state.access_key);
|
||||
let key_valid = self.access_keys.read().unwrap_or_else(|e| e.into_inner()).contains_key(&peer_state.access_key);
|
||||
let user_stats = self.get_or_create_user_stats(&peer_state.access_key);
|
||||
if now.duration_since(peer_state.last_seen) > timeout_dur || !key_valid || user_stats.is_over_limit() {
|
||||
expired.push(sid);
|
||||
|
|
@ -441,7 +441,7 @@ impl Dispatcher {
|
|||
for sid in &expired {
|
||||
let peer_state_opt = self.peer_machines.get(sid);
|
||||
let reason = if let Some(ps) = peer_state_opt {
|
||||
let key_valid = self.access_keys.read().unwrap().contains_key(&ps.access_key);
|
||||
let key_valid = self.access_keys.read().unwrap_or_else(|e| e.into_inner()).contains_key(&ps.access_key);
|
||||
let user_stats = self.get_or_create_user_stats(&ps.access_key);
|
||||
if now.duration_since(ps.last_seen) > timeout_dur {
|
||||
"inactive >5min"
|
||||
|
|
@ -504,15 +504,15 @@ fn get_or_create_stats(
|
|||
key: &str,
|
||||
) -> Arc<UserStats> {
|
||||
{
|
||||
let stats = user_stats.read().unwrap();
|
||||
let stats = user_stats.read().unwrap_or_else(|e| e.into_inner());
|
||||
if let Some(existing) = stats.get(key) {
|
||||
return existing.clone();
|
||||
}
|
||||
}
|
||||
|
||||
let limit_bytes = access_keys.read().unwrap().get(key).and_then(|m| m.limit_bytes);
|
||||
let limit_bytes = access_keys.read().unwrap_or_else(|e| e.into_inner()).get(key).and_then(|m| m.limit_bytes);
|
||||
|
||||
let mut stats = user_stats.write().unwrap();
|
||||
let mut stats = user_stats.write().unwrap_or_else(|e| e.into_inner());
|
||||
stats.entry(key.to_string())
|
||||
.or_insert_with(|| Arc::new(UserStats::new(limit_bytes)))
|
||||
.clone()
|
||||
|
|
|
|||
|
|
@ -195,11 +195,11 @@ pub async fn run_server(
|
|||
}
|
||||
|
||||
// 1. Update shared_keys
|
||||
let mut keys_lock = shared_keys_clone.write().unwrap();
|
||||
let mut keys_lock = shared_keys_clone.write().unwrap_or_else(|e| e.into_inner());
|
||||
*keys_lock = new_keys.clone();
|
||||
|
||||
// 2. Synchronize user_stats limits & cleanup deleted keys
|
||||
let mut stats_lock = user_stats_clone.write().unwrap();
|
||||
let mut stats_lock = user_stats_clone.write().unwrap_or_else(|e| e.into_inner());
|
||||
stats_lock.retain(|k, _| new_keys.contains_key(k));
|
||||
|
||||
for (k, meta) in &new_keys {
|
||||
|
|
@ -308,7 +308,7 @@ pub async fn run_server(
|
|||
}
|
||||
});
|
||||
|
||||
let key_count = shared_keys.read().unwrap().len();
|
||||
let key_count = shared_keys.read().unwrap_or_else(|e| e.into_inner()).len();
|
||||
tracing::info!(listeners = bind_addrs.len(), keys = key_count, "server started");
|
||||
tracing::info!("ARQ config: max_reorder=16384, reorder_buf=8192, sent_history=32768, rto=100ms");
|
||||
let reality_config_arc = reality_config.map(std::sync::Arc::new);
|
||||
|
|
@ -434,7 +434,7 @@ async fn run_server_loop(
|
|||
|
||||
if debug {
|
||||
let _ = ui_event_tx.send(UiEvent::Log("Server loop started".to_string()));
|
||||
let _ = ui_event_tx.send(UiEvent::KeyCount(shared_keys.read().unwrap().len()));
|
||||
let _ = ui_event_tx.send(UiEvent::KeyCount(shared_keys.read().unwrap_or_else(|e| e.into_inner()).len()));
|
||||
}
|
||||
|
||||
let mut retransmit_tick = interval(Duration::from_millis(10));
|
||||
|
|
@ -448,7 +448,7 @@ async fn run_server_loop(
|
|||
match cmd {
|
||||
Some(UiCommand::CreateClientKey) => {
|
||||
let key = format!("ostp_key_{}", std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs());
|
||||
shared_keys.write().unwrap().insert(key.clone(), crate::api::UserMeta { name: None, limit_bytes: None });
|
||||
shared_keys.write().unwrap_or_else(|e| e.into_inner()).insert(key.clone(), crate::api::UserMeta { name: None, limit_bytes: None });
|
||||
let _ = ui_event_tx.send(UiEvent::KeyCreated { key });
|
||||
}
|
||||
Some(UiCommand::Shutdown) | None => {
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ pub async fn run_relay_node(cfg: RelayConfig) -> Result<()> {
|
|||
if let Err(e) = sync_keys(&cfg, &shared_keys).await {
|
||||
tracing::warn!("Relay: initial key sync failed: {}. Will retry.", e);
|
||||
} else {
|
||||
let count = shared_keys.read().unwrap().len();
|
||||
let count = shared_keys.read().unwrap_or_else(|e| e.into_inner()).len();
|
||||
tracing::info!("Relay: synced {} access key(s) from upstream API", count);
|
||||
}
|
||||
|
||||
|
|
@ -190,7 +190,7 @@ async fn run_udp_relay(cfg: RelayConfig, shared_keys: SharedKeys) -> Result<()>
|
|||
|
||||
let ts_bytes: [u8; 8] = packet[0..8].try_into().unwrap();
|
||||
let provided_mac = &packet[8..40];
|
||||
let keys_guard = keys.read().unwrap();
|
||||
let keys_guard = keys.read().unwrap_or_else(|e| e.into_inner());
|
||||
|
||||
if !verify_hmac(&ts_bytes, provided_mac, &keys_guard) {
|
||||
tracing::debug!("Relay UDP: unauthorized probe from {}, dropped", peer);
|
||||
|
|
@ -369,7 +369,7 @@ async fn handle_tcp_client(
|
|||
|
||||
// Проверяем по синхронизированным ключам
|
||||
let authorized = {
|
||||
let keys = shared_keys.read().unwrap();
|
||||
let keys = shared_keys.read().unwrap_or_else(|e| e.into_inner());
|
||||
verify_hmac(&ts_bytes, provided_mac, &keys)
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ use anyhow::Result;
|
|||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::fs::OpenOptions;
|
||||
use std::io::Write as _;
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::sync::{watch, Mutex};
|
||||
|
|
@ -13,21 +12,25 @@ use tokio::net::TcpListener;
|
|||
use portable_atomic::Ordering;
|
||||
|
||||
fn log_to_file(msg: &str) {
|
||||
let path = std::env::current_exe()
|
||||
.ok()
|
||||
.and_then(|p| p.parent().map(|d| d.join("ostp-helper.log")))
|
||||
.unwrap_or_else(|| std::path::PathBuf::from("ostp-helper.log"));
|
||||
if let Ok(mut file) = OpenOptions::new().create(true).append(true).open(path) {
|
||||
let _ = writeln!(file, "[{}] {}", chrono::Local::now().format("%Y-%m-%d %H:%M:%S"), msg);
|
||||
}
|
||||
let msg = msg.to_string();
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let path = std::env::current_exe()
|
||||
.ok()
|
||||
.and_then(|p| p.parent().map(|d| d.join("ostp-helper.log")))
|
||||
.unwrap_or_else(|| std::path::PathBuf::from("ostp-helper.log"));
|
||||
if let Ok(mut file) = std::fs::OpenOptions::new().create(true).append(true).open(path) {
|
||||
let _ = writeln!(file, "[{}] {}", chrono::Local::now().format("%Y-%m-%d %H:%M:%S"), msg);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
const BIND_ADDR: &str = "127.0.0.1:53211";
|
||||
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(tag = "cmd", rename_all = "lowercase")]
|
||||
enum GuiCmd {
|
||||
Start { config: String, token: String },
|
||||
Reload { config: String, token: String },
|
||||
Stop { token: String },
|
||||
}
|
||||
|
||||
|
|
@ -55,10 +58,13 @@ async fn main() -> Result<()> {
|
|||
}
|
||||
|
||||
let mut expected_token = String::new();
|
||||
let mut port = 53211u16;
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
for i in 1..args.len() {
|
||||
if args[i] == "--token" && i + 1 < args.len() {
|
||||
expected_token = args[i + 1].clone();
|
||||
} else if args[i] == "--port" && i + 1 < args.len() {
|
||||
port = args[i + 1].parse().unwrap_or(53211);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -69,21 +75,22 @@ async fn main() -> Result<()> {
|
|||
return Err(anyhow::anyhow!("--token argument is required"));
|
||||
}
|
||||
|
||||
if let Err(e) = run_server(expected_token).await {
|
||||
if let Err(e) = run_server(expected_token, port).await {
|
||||
log_to_file(&format!("Fatal error: {}", e));
|
||||
}
|
||||
log_to_file("Helper exiting");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_server(expected_token: String) -> Result<()> {
|
||||
async fn run_server(expected_token: String, port: u16) -> Result<()> {
|
||||
let state = Arc::new(Mutex::new(TunnelState {
|
||||
shutdown_tx: None,
|
||||
metrics: None,
|
||||
}));
|
||||
|
||||
log_to_file(&format!("Attempting to bind to {}", BIND_ADDR));
|
||||
let listener = TcpListener::bind(BIND_ADDR).await.map_err(|e| {
|
||||
let bind_addr = format!("127.0.0.1:{}", port);
|
||||
log_to_file(&format!("Attempting to bind to {}", bind_addr));
|
||||
let listener = TcpListener::bind(&bind_addr).await.map_err(|e| {
|
||||
log_to_file(&format!("Bind failed: {}", e));
|
||||
e
|
||||
})?;
|
||||
|
|
@ -182,9 +189,10 @@ async fn run_server(expected_token: String) -> Result<()> {
|
|||
|
||||
let metrics_for_runner = metrics.clone();
|
||||
let writer_for_err = writer.clone();
|
||||
let shutdown_rx_for_core = shutdown_rx.clone();
|
||||
tokio::spawn(async move {
|
||||
log_to_file("Starting tunnel core...");
|
||||
match ostp_client::runner::run_client_core(cfg, metrics_for_runner, shutdown_rx).await {
|
||||
match ostp_client::runner::run_client_core(cfg, metrics_for_runner, shutdown_rx_for_core).await {
|
||||
Ok(_) => { log_to_file("Tunnel core stopped normally"); }
|
||||
Err(e) => {
|
||||
log_to_file(&format!("Tunnel core error: {}", e));
|
||||
|
|
@ -197,10 +205,17 @@ async fn run_server(expected_token: String) -> Result<()> {
|
|||
|
||||
let writer_tick = writer.clone();
|
||||
let metrics_tick = metrics.clone();
|
||||
let mut shutdown_rx_tick = shutdown_rx.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut last_state = 99u8;
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
tokio::select! {
|
||||
_ = tokio::time::sleep(Duration::from_secs(1)) => {}
|
||||
_ = shutdown_rx_tick.changed() => {
|
||||
if *shutdown_rx_tick.borrow() { break; }
|
||||
}
|
||||
}
|
||||
|
||||
let cs = metrics_tick.connection_state.load(Ordering::Relaxed);
|
||||
let sent = metrics_tick.bytes_sent.load(Ordering::Relaxed);
|
||||
let recv = metrics_tick.bytes_recv.load(Ordering::Relaxed);
|
||||
|
|
@ -221,6 +236,95 @@ async fn run_server(expected_token: String) -> Result<()> {
|
|||
|
||||
send_msg(HelperMsg::Status { value: 1 });
|
||||
}
|
||||
GuiCmd::Reload { config, token } => {
|
||||
if token != expected_token {
|
||||
send_msg(HelperMsg::Error { message: "Invalid authorization token".to_string() });
|
||||
continue;
|
||||
}
|
||||
log_to_file("Received RELOAD command");
|
||||
|
||||
// Signal shutdown to current core
|
||||
{
|
||||
let mut st = state.lock().await;
|
||||
if let Some(tx) = st.shutdown_tx.take() {
|
||||
let _ = tx.send(true);
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(500)).await; // give it time to shutdown cleanly
|
||||
}
|
||||
|
||||
let cfg: ostp_client::config::ClientConfig = match serde_json::from_str(&config) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
send_msg(HelperMsg::Error { message: format!("Config parse error during reload: {}", e) });
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let metrics = Arc::new(ostp_client::bridge::BridgeMetrics {
|
||||
bytes_sent: portable_atomic::AtomicU64::new(0),
|
||||
bytes_recv: portable_atomic::AtomicU64::new(0),
|
||||
connection_state: portable_atomic::AtomicU8::new(0),
|
||||
rtt_ms: portable_atomic::AtomicU32::new(0),
|
||||
});
|
||||
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
|
||||
{
|
||||
let mut st = state.lock().await;
|
||||
st.shutdown_tx = Some(shutdown_tx);
|
||||
st.metrics = Some(metrics.clone());
|
||||
}
|
||||
|
||||
let metrics_for_runner = metrics.clone();
|
||||
let writer_for_err = writer.clone();
|
||||
let shutdown_rx_for_core = shutdown_rx.clone();
|
||||
tokio::spawn(async move {
|
||||
log_to_file("Restarting tunnel core for reload...");
|
||||
match ostp_client::runner::run_client_core(cfg, metrics_for_runner, shutdown_rx_for_core).await {
|
||||
Ok(_) => { log_to_file("Reloaded core stopped normally"); }
|
||||
Err(e) => {
|
||||
let json = serde_json::to_string(&HelperMsg::Error { message: e.to_string() }).unwrap_or_default();
|
||||
let mut w = writer_for_err.lock().await;
|
||||
let _ = w.write_all(format!("{}\n", json).as_bytes()).await;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Status tick loop is already running and using old metrics?
|
||||
// Wait! We re-created metrics, so the old tick loop will continue reporting old metrics (which are disconnected)!
|
||||
// We should probably share the tick loop or spawn a new one and let the old one die.
|
||||
// It's easier if `metrics` in state is a generic watcher, but since we re-spawned it:
|
||||
let writer_tick = writer.clone();
|
||||
let metrics_tick = metrics.clone();
|
||||
let mut shutdown_rx_tick = shutdown_rx.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut last_state = 99u8;
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = tokio::time::sleep(Duration::from_secs(1)) => {}
|
||||
_ = shutdown_rx_tick.changed() => {
|
||||
if *shutdown_rx_tick.borrow() { break; }
|
||||
}
|
||||
}
|
||||
let cs = metrics_tick.connection_state.load(Ordering::Relaxed);
|
||||
let sent = metrics_tick.bytes_sent.load(Ordering::Relaxed);
|
||||
let recv = metrics_tick.bytes_recv.load(Ordering::Relaxed);
|
||||
let rtt = metrics_tick.rtt_ms.load(Ordering::Relaxed);
|
||||
|
||||
let mut w = writer_tick.lock().await;
|
||||
if cs != last_state {
|
||||
last_state = cs;
|
||||
let json = serde_json::to_string(&HelperMsg::Status { value: cs }).unwrap_or_default();
|
||||
if w.write_all(format!("{}\n", json).as_bytes()).await.is_err() { break; }
|
||||
}
|
||||
let json = serde_json::to_string(&HelperMsg::Metrics { bytes_sent: sent, bytes_recv: recv, rtt_ms: rtt }).unwrap_or_default();
|
||||
if w.write_all(format!("{}\n", json).as_bytes()).await.is_err() { break; }
|
||||
drop(w);
|
||||
}
|
||||
});
|
||||
|
||||
send_msg(HelperMsg::Status { value: 1 });
|
||||
}
|
||||
GuiCmd::Stop { token } => {
|
||||
if token != expected_token {
|
||||
log_to_file("Received STOP command with invalid token");
|
||||
|
|
|
|||
|
|
@ -0,0 +1,658 @@
|
|||
import sys
|
||||
import re
|
||||
|
||||
with open("d:/ospab-projects/ostp/ostp-client/src/bridge.rs", "r", encoding="utf-8") as f:
|
||||
code = f.read()
|
||||
|
||||
start_idx = code.find(" pub async fn run(")
|
||||
end_idx = -1
|
||||
brace_count = 0
|
||||
in_run = False
|
||||
for i in range(start_idx, len(code)):
|
||||
if code[i] == '{':
|
||||
in_run = True
|
||||
brace_count += 1
|
||||
elif code[i] == '}':
|
||||
if in_run:
|
||||
brace_count -= 1
|
||||
if brace_count == 0:
|
||||
end_idx = i + 1
|
||||
break
|
||||
|
||||
prefix = code[:start_idx]
|
||||
suffix = code[end_idx:]
|
||||
|
||||
# Define the new run function and helpers
|
||||
new_run_and_helpers = """
|
||||
pub async fn run(
|
||||
mut self,
|
||||
tx: mpsc::Sender<UiEvent>,
|
||||
mut bridge_rx: mpsc::Receiver<BridgeCommand>,
|
||||
mut shutdown: watch::Receiver<bool>,
|
||||
mut proxy_rx: mpsc::Receiver<ProxyEvent>,
|
||||
proxy_tx: mpsc::UnboundedSender<(u16, ProxyToClientMsg)>,
|
||||
) -> Result<()> {
|
||||
let mut metrics_tick = interval(Duration::from_millis(500));
|
||||
let mut keepalive_tick = tokio::time::interval(Duration::from_secs(self.keepalive_interval_sec.max(1)));
|
||||
let mut retransmit_tick = tokio::time::interval(Duration::from_millis(10));
|
||||
let init_msg = if self.mode == "tun" {
|
||||
"Bridge initialized (TUN mode)".to_string()
|
||||
} else {
|
||||
"Bridge initialized (proxy mode)".to_string()
|
||||
};
|
||||
tx.send(UiEvent::Log(init_msg)).await.ok();
|
||||
|
||||
let mut sessions_opt: Option<Vec<SessionState>> = None;
|
||||
let mut udp_rx_opt: Option<mpsc::Receiver<(usize, Bytes)>> = None;
|
||||
let mut proxy_guard: Option<crate::sysproxy::SystemProxyGuard> = None;
|
||||
let mut stream_map: std::collections::HashMap<u16, usize> = std::collections::HashMap::new();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
biased;
|
||||
_ = shutdown.changed() => {
|
||||
if *shutdown.borrow() {
|
||||
self.running = false;
|
||||
self.metrics.connection_state.store(0, Ordering::Relaxed);
|
||||
proxy_guard = None;
|
||||
sessions_opt = None;
|
||||
udp_rx_opt = None;
|
||||
stream_map.clear();
|
||||
self.reset_proxy_streams(&tx, &proxy_tx, "manual stop");
|
||||
break;
|
||||
}
|
||||
}
|
||||
udp_msg = async {
|
||||
match udp_rx_opt.as_mut() {
|
||||
Some(rx) => rx.recv().await,
|
||||
None => std::future::pending().await,
|
||||
}
|
||||
}, if self.running => {
|
||||
self.handle_inbound_udp(udp_msg, &mut sessions_opt, &mut udp_rx_opt, &mut proxy_guard, &mut stream_map, &tx, &proxy_tx).await;
|
||||
}
|
||||
cmd = bridge_rx.recv() => {
|
||||
if !self.handle_bridge_cmd(cmd, &mut sessions_opt, &mut udp_rx_opt, &mut proxy_guard, &mut stream_map, &tx, &proxy_tx).await {
|
||||
break;
|
||||
}
|
||||
}
|
||||
_ = metrics_tick.tick() => {
|
||||
if self.running {
|
||||
self.emit_metrics(&tx).await;
|
||||
}
|
||||
}
|
||||
_ = keepalive_tick.tick() => {
|
||||
if self.running {
|
||||
self.handle_keepalive(&mut sessions_opt, &mut udp_rx_opt, &mut proxy_guard, &mut stream_map, &tx, &proxy_tx, &mut proxy_rx).await;
|
||||
}
|
||||
}
|
||||
_ = retransmit_tick.tick() => {
|
||||
if self.running {
|
||||
self.handle_retransmit(&mut sessions_opt, &mut udp_rx_opt, &mut proxy_guard, &mut stream_map, &tx, &proxy_tx).await;
|
||||
}
|
||||
}
|
||||
proxy_ev = proxy_rx.recv(), if self.running && sessions_opt.as_ref().map(|s| {
|
||||
s.iter().any(|ses| ses.machine.in_flight_count() < ses.machine.cwnd_packets().clamp(16, 16384))
|
||||
}).unwrap_or(true) => {
|
||||
self.handle_proxy_event(proxy_ev, &mut sessions_opt, &mut stream_map, &tx, &proxy_tx).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tx.send(UiEvent::Log("Bridge stopped".to_string())).await.ok();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_inbound_udp(
|
||||
&mut self,
|
||||
udp_msg: Option<(usize, Bytes)>,
|
||||
sessions_opt: &mut Option<Vec<SessionState>>,
|
||||
udp_rx_opt: &mut Option<mpsc::Receiver<(usize, Bytes)>>,
|
||||
proxy_guard: &mut Option<crate::sysproxy::SystemProxyGuard>,
|
||||
stream_map: &mut std::collections::HashMap<u16, usize>,
|
||||
tx: &mpsc::Sender<UiEvent>,
|
||||
proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>,
|
||||
) {
|
||||
match udp_msg {
|
||||
Some((session_index, inbound)) => {
|
||||
self.metrics.bytes_recv.fetch_add(inbound.len() as u64, Ordering::Relaxed);
|
||||
self.last_valid_recv = Instant::now();
|
||||
if let Some(sessions) = sessions_opt.as_mut() {
|
||||
if session_index < sessions.len() {
|
||||
let session = &mut sessions[session_index];
|
||||
let initial_action = match session.machine.on_event(OstpEvent::Inbound(inbound)) {
|
||||
Ok(a) => a,
|
||||
Err(e) => {
|
||||
let _ = tx.send(UiEvent::Log(format!("Protocol decrypt error: {e}"))).await;
|
||||
tracing::warn!("Inbound protocol error (session {}): {}", session_index, e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let mut actions_queue = std::collections::VecDeque::new();
|
||||
actions_queue.push_back(initial_action);
|
||||
|
||||
while let Some(current_action) = actions_queue.pop_front() {
|
||||
match current_action {
|
||||
ProtocolAction::Multiple(nested) => {
|
||||
for a in nested {
|
||||
actions_queue.push_back(a);
|
||||
}
|
||||
}
|
||||
ProtocolAction::DeliverApp(stream_id, dec_payload) => {
|
||||
match RelayMessage::decode(&dec_payload) {
|
||||
Ok(relay_msg) => {
|
||||
match relay_msg {
|
||||
RelayMessage::ConnectOk => {
|
||||
let _ = tx.send(UiEvent::Log(format!("Relay CONNECT OK stream_id={stream_id}"))).await;
|
||||
let _ = proxy_tx.send((stream_id, ProxyToClientMsg::ConnectOk));
|
||||
}
|
||||
RelayMessage::Data(data) => {
|
||||
let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Data(Bytes::from(data))));
|
||||
}
|
||||
RelayMessage::Close => {
|
||||
let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Close));
|
||||
}
|
||||
RelayMessage::Error(msg) => {
|
||||
let _ = tx.send(UiEvent::Log(format!("Relay error for stream {stream_id}: {msg}"))).await;
|
||||
let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error(msg)));
|
||||
}
|
||||
RelayMessage::Pong(ts) => {
|
||||
let now = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64;
|
||||
self.last_rtt_ms = now.saturating_sub(ts) as f64;
|
||||
self.metrics.rtt_ms.store(self.last_rtt_ms as u32, Ordering::Relaxed);
|
||||
}
|
||||
RelayMessage::UdpAssociate => {}
|
||||
RelayMessage::UdpData(target, data) => {
|
||||
let _ = proxy_tx.send((stream_id, ProxyToClientMsg::UdpData(target, Bytes::from(data))));
|
||||
}
|
||||
RelayMessage::KeepAlive | RelayMessage::Ping(_) | RelayMessage::Connect(_) => {}
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
let _ = tx.send(UiEvent::Log(format!("Relay decode error for stream {stream_id}: {err}"))).await;
|
||||
let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error("relay decode failed".to_string())));
|
||||
}
|
||||
}
|
||||
}
|
||||
ProtocolAction::SendDatagram(frame) => {
|
||||
let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await;
|
||||
self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let _ = tx.send(UiEvent::Log("UDP channel closed, resetting connection".to_string())).await;
|
||||
self.running = false;
|
||||
crate::sysproxy::disable_system_proxy();
|
||||
*sessions_opt = None;
|
||||
*udp_rx_opt = None;
|
||||
stream_map.clear();
|
||||
self.reset_proxy_streams(&tx, &proxy_tx, "udp reader closed");
|
||||
let _ = tx.send(UiEvent::TunnelStopped).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_bridge_cmd(
|
||||
&mut self,
|
||||
cmd: Option<BridgeCommand>,
|
||||
sessions_opt: &mut Option<Vec<SessionState>>,
|
||||
udp_rx_opt: &mut Option<mpsc::Receiver<(usize, Bytes)>>,
|
||||
proxy_guard: &mut Option<crate::sysproxy::SystemProxyGuard>,
|
||||
stream_map: &mut std::collections::HashMap<u16, usize>,
|
||||
tx: &mpsc::Sender<UiEvent>,
|
||||
proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>,
|
||||
) -> bool {
|
||||
match cmd {
|
||||
Some(BridgeCommand::ToggleTunnel) => {
|
||||
if self.running {
|
||||
self.running = false;
|
||||
self.metrics.connection_state.store(0, Ordering::Relaxed);
|
||||
*proxy_guard = None;
|
||||
*sessions_opt = None;
|
||||
*udp_rx_opt = None;
|
||||
stream_map.clear();
|
||||
self.reset_proxy_streams(&tx, &proxy_tx, "manual stop");
|
||||
tx.send(UiEvent::TunnelStopped).await.ok();
|
||||
let stop_msg = if self.mode == "tun" { "TUN tunnel stopped" } else { "Bridge stopped" };
|
||||
tx.send(UiEvent::Log(stop_msg.to_string())).await.ok();
|
||||
} else {
|
||||
tx.send(UiEvent::Log("Connecting to remote server...".to_string())).await.ok();
|
||||
tx.send(UiEvent::Metrics { status: ConnectionStatus::Handshaking, rtt_ms: 0.0, throughput_bps: 0 }).await.ok();
|
||||
self.metrics.connection_state.store(1, Ordering::Relaxed);
|
||||
|
||||
let session_count = if self.mux_enabled { self.mux_sessions.max(1) } else { 1 };
|
||||
let (udp_tx, udp_rx) = mpsc::channel(100000);
|
||||
let mut sessions = Vec::with_capacity(session_count);
|
||||
let mut rtt_sum = 0.0;
|
||||
let mut successful_sessions = 0;
|
||||
|
||||
for idx in 0..session_count {
|
||||
let session_id: u32 = rand::thread_rng().gen();
|
||||
match self.perform_handshake_with_id(&tx, session_id).await {
|
||||
Ok((sock, mach, rtt)) => {
|
||||
let session_index = sessions.len();
|
||||
let socket_clone = sock.clone();
|
||||
let udp_tx_clone = udp_tx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut buf = vec![0_u8; 65535];
|
||||
loop {
|
||||
match socket_clone.recv(&mut buf).await {
|
||||
Ok(n) => {
|
||||
let inbound = Bytes::copy_from_slice(&buf[..n]);
|
||||
if udp_tx_clone.send((session_index, inbound)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("UDP socket recv error (session {}): {}", session_index, e);
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
sessions.push(SessionState { socket: sock, machine: mach });
|
||||
rtt_sum += rtt;
|
||||
successful_sessions += 1;
|
||||
}
|
||||
Err(err) => {
|
||||
tx.send(UiEvent::Log(format!("Multiplex session {}/{} handshake failed: {}. Continuing with remaining sessions...", idx + 1, session_count, err))).await.ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sessions.is_empty() {
|
||||
*proxy_guard = None;
|
||||
tx.send(UiEvent::Log("All multiplexed handshake attempts failed. Connection aborted.".to_string())).await.ok();
|
||||
tx.send(UiEvent::TunnelStopped).await.ok();
|
||||
self.metrics.connection_state.store(0, Ordering::Relaxed);
|
||||
return True;
|
||||
}
|
||||
|
||||
*udp_rx_opt = Some(udp_rx);
|
||||
*sessions_opt = Some(sessions);
|
||||
self.last_rtt_ms = rtt_sum / successful_sessions as f64;
|
||||
self.running = true;
|
||||
self.last_sample_at = Instant::now();
|
||||
self.last_valid_recv = Instant::now();
|
||||
|
||||
let sys_proxy_addr = self.proxy_addr.replace("0.0.0.0:", "127.0.0.1:");
|
||||
*proxy_guard = Some(crate::sysproxy::SystemProxyGuard::enable(&sys_proxy_addr));
|
||||
|
||||
tx.send(UiEvent::Metrics {
|
||||
status: ConnectionStatus::Established,
|
||||
rtt_ms: self.last_rtt_ms,
|
||||
throughput_bps: 0,
|
||||
}).await.ok();
|
||||
self.metrics.connection_state.store(2, Ordering::Relaxed);
|
||||
let start_msg = if self.mode == "tun" { "TUN tunnel established" } else { "Connection established" };
|
||||
tx.send(UiEvent::Log(start_msg.to_string())).await.ok();
|
||||
|
||||
for session in sessions_opt.as_mut().unwrap().iter_mut() {
|
||||
let ts = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64;
|
||||
let ping_payload = Bytes::from(RelayMessage::Ping(ts).encode());
|
||||
if let Ok(ProtocolAction::SendDatagram(frame)) = session.machine.on_event(OstpEvent::Outbound(0, ping_payload)) {
|
||||
let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp").await;
|
||||
self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(BridgeCommand::NextProfile) => {
|
||||
self.profile = next_profile(self.profile);
|
||||
tx.send(UiEvent::ProfileChanged(self.profile)).await.ok();
|
||||
tx.send(UiEvent::Log(format!("Obfuscation profile switched to {:?}", self.profile))).await.ok();
|
||||
}
|
||||
Some(BridgeCommand::NetworkChanged) => {
|
||||
if self.running {
|
||||
let _ = tx.send(UiEvent::Log("Network changed — starting immediate reconnect".to_string())).await;
|
||||
self.metrics.connection_state.store(1, Ordering::Relaxed);
|
||||
self.last_valid_recv = Instant::now() - Duration::from_secs(100);
|
||||
|
||||
let session_count = if self.mux_enabled { self.mux_sessions.max(1) } else { 1 };
|
||||
let (udp_tx, udp_rx) = mpsc::channel(100000);
|
||||
let mut new_sessions = Vec::with_capacity(session_count);
|
||||
let mut successful_sessions = 0;
|
||||
let mut rtt_sum = 0.0;
|
||||
|
||||
for idx in 0..session_count {
|
||||
let session_id: u32 = rand::thread_rng().gen();
|
||||
match self.perform_handshake_with_id(&tx, session_id).await {
|
||||
Ok((sock, mach, rtt)) => {
|
||||
let session_index = new_sessions.len();
|
||||
let socket_clone = sock.clone();
|
||||
let udp_tx_clone = udp_tx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut buf = vec![0_u8; 65535];
|
||||
loop {
|
||||
match socket_clone.recv(&mut buf).await {
|
||||
Ok(n) => {
|
||||
let inbound = Bytes::copy_from_slice(&buf[..n]);
|
||||
if udp_tx_clone.send((session_index, inbound)).await.is_err() { break; }
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("UDP recv error (network-change session {}): {}", session_index, e);
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
new_sessions.push(SessionState { socket: sock, machine: mach });
|
||||
rtt_sum += rtt;
|
||||
successful_sessions += 1;
|
||||
}
|
||||
Err(err) => {
|
||||
let _ = tx.send(UiEvent::Log(format!("NetworkChanged reconnect session {}/{} failed: {}", idx + 1, session_count, err))).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !new_sessions.is_empty() {
|
||||
*sessions_opt = Some(new_sessions);
|
||||
*udp_rx_opt = Some(udp_rx);
|
||||
self.last_rtt_ms = rtt_sum / successful_sessions as f64;
|
||||
self.last_valid_recv = Instant::now();
|
||||
stream_map.clear();
|
||||
self.reset_proxy_streams(&tx, &proxy_tx, "network changed");
|
||||
self.metrics.connection_state.store(2, Ordering::Relaxed);
|
||||
let _ = tx.send(UiEvent::Log("NetworkChanged reconnect successful!".to_string())).await;
|
||||
} else {
|
||||
let _ = tx.send(UiEvent::Log("NetworkChanged reconnect failed — will retry on keepalive tick".to_string())).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(BridgeCommand::ReloadConfig) => {
|
||||
match ClientConfig::reload_from_json_near_binary() {
|
||||
Ok(cfg) => {
|
||||
self.apply_runtime_config(&cfg);
|
||||
tx.send(UiEvent::Log("Runtime config reloaded".to_string())).await.ok();
|
||||
if self.running {
|
||||
self.running = false;
|
||||
self.metrics.connection_state.store(0, Ordering::Relaxed);
|
||||
*proxy_guard = None;
|
||||
*sessions_opt = None;
|
||||
stream_map.clear();
|
||||
self.reset_proxy_streams(&tx, &proxy_tx, "config reload");
|
||||
let _ = tx.send(UiEvent::TunnelStopped).await;
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
let _ = tx.send(UiEvent::Log(format!("Config reload failed: {err}"))).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(BridgeCommand::Shutdown) | None => {
|
||||
self.running = false;
|
||||
*proxy_guard = None;
|
||||
return False;
|
||||
}
|
||||
}
|
||||
True
|
||||
}
|
||||
|
||||
async fn handle_keepalive(
|
||||
&mut self,
|
||||
sessions_opt: &mut Option<Vec<SessionState>>,
|
||||
udp_rx_opt: &mut Option<mpsc::Receiver<(usize, Bytes)>>,
|
||||
proxy_guard: &mut Option<crate::sysproxy::SystemProxyGuard>,
|
||||
stream_map: &mut std::collections::HashMap<u16, usize>,
|
||||
tx: &mpsc::Sender<UiEvent>,
|
||||
proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>,
|
||||
proxy_rx: &mut mpsc::Receiver<ProxyEvent>,
|
||||
) {
|
||||
if self.last_valid_recv.elapsed().as_secs() > 25 {
|
||||
let elapsed = self.last_valid_recv.elapsed().as_secs();
|
||||
if elapsed > 180 {
|
||||
let _ = tx.send(UiEvent::Log("Connection permanently lost (3-minute hard timeout). Stopping tunnel.".into())).await;
|
||||
self.running = false;
|
||||
*proxy_guard = None;
|
||||
*sessions_opt = None;
|
||||
stream_map.clear();
|
||||
self.reset_proxy_streams(&tx, &proxy_tx, "keepalive hard timeout");
|
||||
let _ = tx.send(UiEvent::TunnelStopped).await;
|
||||
self.metrics.connection_state.store(0, Ordering::Relaxed);
|
||||
return;
|
||||
}
|
||||
|
||||
let _ = tx.send(UiEvent::Log(format!("Connection stall detected ({}s silence). Attempting background reconnect...", elapsed))).await;
|
||||
self.metrics.connection_state.store(1, Ordering::Relaxed);
|
||||
|
||||
let session_count = if self.mux_enabled { self.mux_sessions.max(1) } else { 1 };
|
||||
let (udp_tx, udp_rx) = mpsc::channel(100000);
|
||||
let mut new_sessions = Vec::with_capacity(session_count);
|
||||
let mut successful_sessions = 0;
|
||||
let mut rtt_sum = 0.0;
|
||||
|
||||
for idx in 0..session_count {
|
||||
let session_id: u32 = rand::thread_rng().gen();
|
||||
match self.perform_handshake_with_id(&tx, session_id).await {
|
||||
Ok((sock, mach, rtt)) => {
|
||||
let session_index = new_sessions.len();
|
||||
let socket_clone = sock.clone();
|
||||
let udp_tx_clone = udp_tx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut buf = vec![0_u8; 65535];
|
||||
loop {
|
||||
match socket_clone.recv(&mut buf).await {
|
||||
Ok(n) => {
|
||||
let inbound = Bytes::copy_from_slice(&buf[..n]);
|
||||
if udp_tx_clone.send((session_index, inbound)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("UDP socket recv error (reconnect session {}): {}", session_index, e);
|
||||
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
new_sessions.push(SessionState { socket: sock, machine: mach });
|
||||
rtt_sum += rtt;
|
||||
successful_sessions += 1;
|
||||
}
|
||||
Err(err) => {
|
||||
let _ = tx.send(UiEvent::Log(format!("Background reconnect session {}/{} failed: {}", idx + 1, session_count, err))).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !new_sessions.is_empty() {
|
||||
*sessions_opt = Some(new_sessions);
|
||||
*udp_rx_opt = Some(udp_rx);
|
||||
self.last_rtt_ms = rtt_sum / successful_sessions as f64;
|
||||
self.last_valid_recv = Instant::now();
|
||||
self.metrics.connection_state.store(2, Ordering::Relaxed);
|
||||
let _ = tx.send(UiEvent::Log("Background reconnect successful! Connection restored.".into())).await;
|
||||
|
||||
for session in sessions_opt.as_mut().unwrap().iter_mut() {
|
||||
let ts = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64;
|
||||
let ping_payload = Bytes::from(RelayMessage::Ping(ts).encode());
|
||||
if let Ok(ProtocolAction::SendDatagram(frame)) = session.machine.on_event(OstpEvent::Outbound(0, ping_payload)) {
|
||||
let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp").await;
|
||||
self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
stream_map.clear();
|
||||
self.reset_proxy_streams(&tx, &proxy_tx, "background reconnect");
|
||||
|
||||
let mut flushed = 0;
|
||||
while let Ok(stale) = proxy_rx.try_recv() {
|
||||
if let ProxyEvent::NewStream { stream_id, .. } = stale {
|
||||
let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error("connection reset".into())));
|
||||
}
|
||||
flushed += 1;
|
||||
}
|
||||
if flushed > 0 {
|
||||
let _ = tx.send(UiEvent::Log(format!("Flushed {} stale proxy messages to prevent UDP burst", flushed))).await;
|
||||
}
|
||||
} else {
|
||||
let _ = tx.send(UiEvent::Log("Background reconnect failed. Will retry on next tick...".into())).await;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(sessions) = sessions_opt.as_mut() {
|
||||
for session in sessions.iter_mut() {
|
||||
let ts = SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as u64;
|
||||
let ping_payload = Bytes::from(RelayMessage::Ping(ts).encode());
|
||||
if let Ok(ProtocolAction::SendDatagram(frame)) = session.machine.on_event(OstpEvent::Outbound(0, ping_payload)) {
|
||||
let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await;
|
||||
self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
let ka_payload = Bytes::from(RelayMessage::KeepAlive.encode());
|
||||
if let Ok(ProtocolAction::SendDatagram(frame)) = session.machine.on_event(OstpEvent::Outbound(0, ka_payload)) {
|
||||
let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await;
|
||||
self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_retransmit(
|
||||
&mut self,
|
||||
sessions_opt: &mut Option<Vec<SessionState>>,
|
||||
udp_rx_opt: &mut Option<mpsc::Receiver<(usize, Bytes)>>,
|
||||
proxy_guard: &mut Option<crate::sysproxy::SystemProxyGuard>,
|
||||
stream_map: &mut std::collections::HashMap<u16, usize>,
|
||||
tx: &mpsc::Sender<UiEvent>,
|
||||
proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>,
|
||||
) {
|
||||
let mut fatal_err = None;
|
||||
if let Some(sessions) = sessions_opt.as_mut() {
|
||||
for session in sessions.iter_mut() {
|
||||
match session.machine.on_event(OstpEvent::Tick) {
|
||||
Ok(action) => {
|
||||
let mut queue = vec![action];
|
||||
while let Some(current_action) = queue.pop() {
|
||||
match current_action {
|
||||
ProtocolAction::Multiple(nested) => {
|
||||
for a in nested {
|
||||
queue.push(a);
|
||||
}
|
||||
}
|
||||
ProtocolAction::SendDatagram(frame) => {
|
||||
let _ = send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await;
|
||||
self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
fatal_err = Some(e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(e) = fatal_err {
|
||||
let _ = tx.send(UiEvent::Log(format!("Protocol tick fatal error: {e}"))).await;
|
||||
self.running = false;
|
||||
*proxy_guard = None;
|
||||
*sessions_opt = None;
|
||||
*udp_rx_opt = None;
|
||||
stream_map.clear();
|
||||
self.reset_proxy_streams(&tx, &proxy_tx, "protocol fatal error");
|
||||
let _ = tx.send(UiEvent::TunnelStopped).await;
|
||||
self.metrics.connection_state.store(0, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_proxy_event(
|
||||
&mut self,
|
||||
proxy_ev: Option<ProxyEvent>,
|
||||
sessions_opt: &mut Option<Vec<SessionState>>,
|
||||
stream_map: &mut std::collections::HashMap<u16, usize>,
|
||||
tx: &mpsc::Sender<UiEvent>,
|
||||
proxy_tx: &mpsc::UnboundedSender<(u16, ProxyToClientMsg)>,
|
||||
) {
|
||||
if let Some(ev) = proxy_ev {
|
||||
if let Some(sessions) = sessions_opt.as_mut() {
|
||||
if sessions.is_empty() {
|
||||
if let ProxyEvent::NewStream { stream_id, .. } = ev {
|
||||
let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error("tunnel stopped".into())));
|
||||
}
|
||||
return;
|
||||
}
|
||||
let (stream_id, relay_msg, is_close) = match ev {
|
||||
ProxyEvent::NewStream { stream_id, target } => {
|
||||
let _ = tx.send(UiEvent::Log(format!("Proxy CONNECT stream_id={stream_id} target={target}"))).await;
|
||||
(stream_id, RelayMessage::Connect(target), false)
|
||||
}
|
||||
ProxyEvent::UdpAssociate { stream_id } => {
|
||||
let _ = tx.send(UiEvent::Log(format!("Proxy UDP ASSOCIATE stream_id={stream_id}"))).await;
|
||||
(stream_id, RelayMessage::UdpAssociate, false)
|
||||
}
|
||||
ProxyEvent::UdpData { stream_id, target, payload } => {
|
||||
(stream_id, RelayMessage::UdpData(target, payload.to_vec()), false)
|
||||
}
|
||||
ProxyEvent::Data { stream_id, payload } => (stream_id, RelayMessage::Data(payload.to_vec()), false),
|
||||
ProxyEvent::Close { stream_id } => {
|
||||
let _ = tx.send(UiEvent::Log(format!("Proxy CLOSE stream_id={stream_id}"))).await;
|
||||
(stream_id, RelayMessage::Close, true)
|
||||
}
|
||||
};
|
||||
let len = sessions.len();
|
||||
let session_index = *stream_map.entry(stream_id).or_insert_with(|| {
|
||||
rand::thread_rng().gen_range(0..len)
|
||||
});
|
||||
if is_close {
|
||||
stream_map.remove(&stream_id);
|
||||
}
|
||||
let session = &mut sessions[session_index];
|
||||
let out_payload = Bytes::from(relay_msg.encode());
|
||||
match session.machine.on_event(OstpEvent::Outbound(stream_id, out_payload)) {
|
||||
Ok(ProtocolAction::SendDatagram(frame)) => {
|
||||
if send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await.is_ok() {
|
||||
self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed);
|
||||
tracing::trace!("Outbound datagram sent stream_id={stream_id} bytes={}", frame.len());
|
||||
}
|
||||
}
|
||||
Ok(ProtocolAction::Multiple(list)) => {
|
||||
let mut sent = 0usize;
|
||||
for item in list {
|
||||
if let ProtocolAction::SendDatagram(frame) = item {
|
||||
if send_datagram(&session.socket, &frame, self.transport_mode == "udp" ).await.is_ok() {
|
||||
self.metrics.bytes_sent.fetch_add(frame.len() as u64, Ordering::Relaxed);
|
||||
sent += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
tracing::trace!("Outbound datagram batch stream_id={stream_id} sent={sent}");
|
||||
}
|
||||
Ok(ProtocolAction::Noop) => {
|
||||
tracing::trace!("Outbound datagram noop stream_id={stream_id}");
|
||||
}
|
||||
Ok(_) => {
|
||||
tracing::trace!("Outbound datagram unexpected action stream_id={stream_id}");
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Protocol error packing outbound stream_id={}: {}", stream_id, e);
|
||||
let _ = tx.send(UiEvent::Log(format!("Protocol error packing TCP: {e}"))).await;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if let ProxyEvent::NewStream { stream_id, .. } = ev {
|
||||
let _ = proxy_tx.send((stream_id, ProxyToClientMsg::Error("tunnel stopped".into())));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
with open("d:/ospab-projects/ostp/ostp-client/src/bridge.rs", "w", encoding="utf-8") as f:
|
||||
f.write(prefix + new_run_and_helpers + suffix)
|
||||
|
||||
print("Done")
|
||||
Loading…
Reference in New Issue