mirror of https://github.com/ospab/ostp.git
346 lines
14 KiB
Rust
346 lines
14 KiB
Rust
// ostp-tun-helper/src/main.rs
|
|
#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")]
|
|
|
|
use anyhow::Result;
|
|
use serde::{Deserialize, Serialize};
|
|
use std::sync::Arc;
|
|
use std::time::Duration;
|
|
use std::io::Write as _;
|
|
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
|
use tokio::sync::{watch, Mutex};
|
|
use tokio::net::TcpListener;
|
|
use portable_atomic::Ordering;
|
|
|
|
fn log_to_file(msg: &str) {
|
|
let msg = msg.to_string();
|
|
tokio::task::spawn_blocking(move || {
|
|
let path = std::env::current_exe()
|
|
.ok()
|
|
.and_then(|p| p.parent().map(|d| d.join("ostp-helper.log")))
|
|
.unwrap_or_else(|| std::path::PathBuf::from("ostp-helper.log"));
|
|
if let Ok(mut file) = std::fs::OpenOptions::new().create(true).append(true).open(path) {
|
|
let _ = writeln!(file, "[{}] {}", chrono::Local::now().format("%Y-%m-%d %H:%M:%S"), msg);
|
|
}
|
|
});
|
|
}
|
|
|
|
|
|
|
|
#[derive(Deserialize)]
|
|
#[serde(tag = "cmd", rename_all = "lowercase")]
|
|
enum GuiCmd {
|
|
Start { config: String, token: String },
|
|
Reload { config: String, token: String },
|
|
Stop { token: String },
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
#[serde(tag = "type", rename_all = "lowercase")]
|
|
#[allow(dead_code)]
|
|
enum HelperMsg {
|
|
Status { value: u8 },
|
|
Log { message: String },
|
|
Metrics { bytes_sent: u64, bytes_recv: u64, rtt_ms: u32 },
|
|
Error { message: String },
|
|
}
|
|
|
|
struct TunnelState {
|
|
shutdown_tx: Option<watch::Sender<bool>>,
|
|
metrics: Option<Arc<ostp_client::bridge::BridgeMetrics>>,
|
|
}
|
|
|
|
#[tokio::main]
|
|
async fn main() -> Result<()> {
|
|
if let Ok(exe) = std::env::current_exe() {
|
|
if let Some(dir) = exe.parent() {
|
|
let _ = std::env::set_current_dir(dir);
|
|
}
|
|
}
|
|
|
|
let mut expected_token = String::new();
|
|
let mut port = 53211u16;
|
|
let args: Vec<String> = std::env::args().collect();
|
|
for i in 1..args.len() {
|
|
if args[i] == "--token" && i + 1 < args.len() {
|
|
expected_token = args[i + 1].clone();
|
|
} else if args[i] == "--port" && i + 1 < args.len() {
|
|
port = args[i + 1].parse().unwrap_or(53211);
|
|
}
|
|
}
|
|
|
|
log_to_file("Helper started (TCP mode)");
|
|
|
|
if expected_token.is_empty() {
|
|
log_to_file("FATAL: --token argument is required for security. Unauthorized access denied.");
|
|
return Err(anyhow::anyhow!("--token argument is required"));
|
|
}
|
|
|
|
if let Err(e) = run_server(expected_token, port).await {
|
|
log_to_file(&format!("Fatal error: {}", e));
|
|
}
|
|
log_to_file("Helper exiting");
|
|
Ok(())
|
|
}
|
|
|
|
async fn run_server(expected_token: String, port: u16) -> Result<()> {
|
|
let state = Arc::new(Mutex::new(TunnelState {
|
|
shutdown_tx: None,
|
|
metrics: None,
|
|
}));
|
|
|
|
let bind_addr = format!("127.0.0.1:{}", port);
|
|
log_to_file(&format!("Attempting to bind to {}", bind_addr));
|
|
let listener = TcpListener::bind(&bind_addr).await.map_err(|e| {
|
|
log_to_file(&format!("Bind failed: {}", e));
|
|
e
|
|
})?;
|
|
log_to_file("Listening successfully");
|
|
|
|
// Wait for GUI to connect (60 second timeout)
|
|
let (socket, _) = match tokio::time::timeout(Duration::from_secs(60), listener.accept()).await {
|
|
Ok(Ok(s)) => s,
|
|
_ => {
|
|
log_to_file("No connection from GUI within 60s, exiting");
|
|
return Ok(());
|
|
}
|
|
};
|
|
|
|
log_to_file("GUI connected via TCP");
|
|
|
|
let (reader_half, writer_half) = tokio::io::split(socket);
|
|
let writer = Arc::new(Mutex::new(writer_half));
|
|
let mut reader = BufReader::new(reader_half);
|
|
|
|
let send_msg = {
|
|
let writer = writer.clone();
|
|
move |msg: HelperMsg| {
|
|
let writer = writer.clone();
|
|
let json = serde_json::to_string(&msg).unwrap_or_default();
|
|
tokio::spawn(async move {
|
|
let mut w = writer.lock().await;
|
|
let _ = w.write_all(format!("{}\n", json).as_bytes()).await;
|
|
});
|
|
}
|
|
};
|
|
|
|
let mut line = String::new();
|
|
loop {
|
|
line.clear();
|
|
let n = reader.read_line(&mut line).await.unwrap_or(0);
|
|
if n == 0 {
|
|
log_to_file("GUI disconnected, stopping tunnel");
|
|
let mut st = state.lock().await;
|
|
if let Some(tx) = st.shutdown_tx.take() {
|
|
let _ = tx.send(true);
|
|
}
|
|
break;
|
|
}
|
|
|
|
let trimmed = line.trim();
|
|
if trimmed.is_empty() { continue; }
|
|
|
|
let cmd: GuiCmd = match serde_json::from_str(trimmed) {
|
|
Ok(c) => c,
|
|
Err(e) => {
|
|
send_msg(HelperMsg::Error { message: format!("Bad command: {}", e) });
|
|
continue;
|
|
}
|
|
};
|
|
|
|
match cmd {
|
|
GuiCmd::Start { config, token } => {
|
|
if token != expected_token {
|
|
log_to_file("Received START command with invalid token");
|
|
send_msg(HelperMsg::Error { message: "Invalid authorization token".to_string() });
|
|
continue;
|
|
}
|
|
log_to_file("Received START command");
|
|
{
|
|
let mut st = state.lock().await;
|
|
if let Some(tx) = st.shutdown_tx.take() {
|
|
let _ = tx.send(true);
|
|
}
|
|
tokio::time::sleep(Duration::from_millis(100)).await;
|
|
}
|
|
|
|
let cfg: ostp_client::config::ClientConfig = match serde_json::from_str(&config) {
|
|
Ok(c) => c,
|
|
Err(e) => {
|
|
log_to_file(&format!("Config parse error: {}", e));
|
|
send_msg(HelperMsg::Error { message: format!("Config parse error: {}", e) });
|
|
continue;
|
|
}
|
|
};
|
|
|
|
let metrics = Arc::new(ostp_client::bridge::BridgeMetrics {
|
|
bytes_sent: portable_atomic::AtomicU64::new(0),
|
|
bytes_recv: portable_atomic::AtomicU64::new(0),
|
|
connection_state: portable_atomic::AtomicU8::new(0),
|
|
rtt_ms: portable_atomic::AtomicU32::new(0),
|
|
});
|
|
|
|
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
|
|
|
{
|
|
let mut st = state.lock().await;
|
|
st.shutdown_tx = Some(shutdown_tx);
|
|
st.metrics = Some(metrics.clone());
|
|
}
|
|
|
|
let metrics_for_runner = metrics.clone();
|
|
let writer_for_err = writer.clone();
|
|
let shutdown_rx_for_core = shutdown_rx.clone();
|
|
tokio::spawn(async move {
|
|
log_to_file("Starting tunnel core...");
|
|
match ostp_client::runner::run_client_core(cfg, metrics_for_runner, shutdown_rx_for_core).await {
|
|
Ok(_) => { log_to_file("Tunnel core stopped normally"); }
|
|
Err(e) => {
|
|
log_to_file(&format!("Tunnel core error: {}", e));
|
|
let json = serde_json::to_string(&HelperMsg::Error { message: e.to_string() }).unwrap_or_default();
|
|
let mut w = writer_for_err.lock().await;
|
|
let _ = w.write_all(format!("{}\n", json).as_bytes()).await;
|
|
}
|
|
}
|
|
});
|
|
|
|
let writer_tick = writer.clone();
|
|
let metrics_tick = metrics.clone();
|
|
let mut shutdown_rx_tick = shutdown_rx.clone();
|
|
tokio::spawn(async move {
|
|
let mut last_state = 99u8;
|
|
loop {
|
|
tokio::select! {
|
|
_ = tokio::time::sleep(Duration::from_secs(1)) => {}
|
|
_ = shutdown_rx_tick.changed() => {
|
|
if *shutdown_rx_tick.borrow() { break; }
|
|
}
|
|
}
|
|
|
|
let cs = metrics_tick.connection_state.load(Ordering::Relaxed);
|
|
let sent = metrics_tick.bytes_sent.load(Ordering::Relaxed);
|
|
let recv = metrics_tick.bytes_recv.load(Ordering::Relaxed);
|
|
|
|
let rtt = metrics_tick.rtt_ms.load(Ordering::Relaxed);
|
|
|
|
let mut w = writer_tick.lock().await;
|
|
if cs != last_state {
|
|
last_state = cs;
|
|
let json = serde_json::to_string(&HelperMsg::Status { value: cs }).unwrap_or_default();
|
|
if w.write_all(format!("{}\n", json).as_bytes()).await.is_err() { break; }
|
|
}
|
|
let json = serde_json::to_string(&HelperMsg::Metrics { bytes_sent: sent, bytes_recv: recv, rtt_ms: rtt }).unwrap_or_default();
|
|
if w.write_all(format!("{}\n", json).as_bytes()).await.is_err() { break; }
|
|
drop(w);
|
|
}
|
|
});
|
|
|
|
send_msg(HelperMsg::Status { value: 1 });
|
|
}
|
|
GuiCmd::Reload { config, token } => {
|
|
if token != expected_token {
|
|
send_msg(HelperMsg::Error { message: "Invalid authorization token".to_string() });
|
|
continue;
|
|
}
|
|
log_to_file("Received RELOAD command");
|
|
|
|
// Signal shutdown to current core
|
|
{
|
|
let mut st = state.lock().await;
|
|
if let Some(tx) = st.shutdown_tx.take() {
|
|
let _ = tx.send(true);
|
|
}
|
|
tokio::time::sleep(Duration::from_millis(500)).await; // give it time to shutdown cleanly
|
|
}
|
|
|
|
let cfg: ostp_client::config::ClientConfig = match serde_json::from_str(&config) {
|
|
Ok(c) => c,
|
|
Err(e) => {
|
|
send_msg(HelperMsg::Error { message: format!("Config parse error during reload: {}", e) });
|
|
continue;
|
|
}
|
|
};
|
|
|
|
let metrics = Arc::new(ostp_client::bridge::BridgeMetrics {
|
|
bytes_sent: portable_atomic::AtomicU64::new(0),
|
|
bytes_recv: portable_atomic::AtomicU64::new(0),
|
|
connection_state: portable_atomic::AtomicU8::new(0),
|
|
rtt_ms: portable_atomic::AtomicU32::new(0),
|
|
});
|
|
|
|
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
|
|
|
{
|
|
let mut st = state.lock().await;
|
|
st.shutdown_tx = Some(shutdown_tx);
|
|
st.metrics = Some(metrics.clone());
|
|
}
|
|
|
|
let metrics_for_runner = metrics.clone();
|
|
let writer_for_err = writer.clone();
|
|
let shutdown_rx_for_core = shutdown_rx.clone();
|
|
tokio::spawn(async move {
|
|
log_to_file("Restarting tunnel core for reload...");
|
|
match ostp_client::runner::run_client_core(cfg, metrics_for_runner, shutdown_rx_for_core).await {
|
|
Ok(_) => { log_to_file("Reloaded core stopped normally"); }
|
|
Err(e) => {
|
|
let json = serde_json::to_string(&HelperMsg::Error { message: e.to_string() }).unwrap_or_default();
|
|
let mut w = writer_for_err.lock().await;
|
|
let _ = w.write_all(format!("{}\n", json).as_bytes()).await;
|
|
}
|
|
}
|
|
});
|
|
|
|
// Status tick loop is already running and using old metrics?
|
|
// Wait! We re-created metrics, so the old tick loop will continue reporting old metrics (which are disconnected)!
|
|
// We should probably share the tick loop or spawn a new one and let the old one die.
|
|
// It's easier if `metrics` in state is a generic watcher, but since we re-spawned it:
|
|
let writer_tick = writer.clone();
|
|
let metrics_tick = metrics.clone();
|
|
let mut shutdown_rx_tick = shutdown_rx.clone();
|
|
tokio::spawn(async move {
|
|
let mut last_state = 99u8;
|
|
loop {
|
|
tokio::select! {
|
|
_ = tokio::time::sleep(Duration::from_secs(1)) => {}
|
|
_ = shutdown_rx_tick.changed() => {
|
|
if *shutdown_rx_tick.borrow() { break; }
|
|
}
|
|
}
|
|
let cs = metrics_tick.connection_state.load(Ordering::Relaxed);
|
|
let sent = metrics_tick.bytes_sent.load(Ordering::Relaxed);
|
|
let recv = metrics_tick.bytes_recv.load(Ordering::Relaxed);
|
|
let rtt = metrics_tick.rtt_ms.load(Ordering::Relaxed);
|
|
|
|
let mut w = writer_tick.lock().await;
|
|
if cs != last_state {
|
|
last_state = cs;
|
|
let json = serde_json::to_string(&HelperMsg::Status { value: cs }).unwrap_or_default();
|
|
if w.write_all(format!("{}\n", json).as_bytes()).await.is_err() { break; }
|
|
}
|
|
let json = serde_json::to_string(&HelperMsg::Metrics { bytes_sent: sent, bytes_recv: recv, rtt_ms: rtt }).unwrap_or_default();
|
|
if w.write_all(format!("{}\n", json).as_bytes()).await.is_err() { break; }
|
|
drop(w);
|
|
}
|
|
});
|
|
|
|
send_msg(HelperMsg::Status { value: 1 });
|
|
}
|
|
GuiCmd::Stop { token } => {
|
|
if token != expected_token {
|
|
log_to_file("Received STOP command with invalid token");
|
|
send_msg(HelperMsg::Error { message: "Invalid authorization token".to_string() });
|
|
continue;
|
|
}
|
|
log_to_file("Received STOP command");
|
|
let mut st = state.lock().await;
|
|
if let Some(tx) = st.shutdown_tx.take() {
|
|
let _ = tx.send(true);
|
|
}
|
|
st.metrics = None;
|
|
send_msg(HelperMsg::Status { value: 0 });
|
|
}
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|