mirror of https://github.com/ospab/ostp.git
Fix Closing state, replace sent_history VecDeque with BTreeMap, clean up dead code
- protocol: Closing+Inbound no longer force-transitions to Closed after one packet; handle_inbound now owns the transition when it receives a Close frame, preventing data loss on in-flight packets during teardown. Add Tick handling for Closing state so the Close frame is retransmitted. - protocol: replace sent_history VecDeque<SentFrame> with BTreeMap<u64, SentFrame>; NACK lookup is now O(log n) instead of O(n) linear scan. - protocol: remove unused _mtu field; drop VecDeque import. - congestion: remove no-op on_tick method (was never called). - dispatcher: remove broad #[allow(dead_code)] on impl block; annotate three genuinely unused methods individually. Fix comment "100000 entries" → "50000" and log "inactive >5min" → ">10min" (real timeout is 600 s). Remove unused mut on stream variable in ostp client. - docs: correct timestamp window ±30 s → ±300 s in EN and RU specs to match the actual drift > 300 check in dispatcher. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
d031b15679
commit
47d44fa072
|
|
@ -94,7 +94,7 @@ OSTP executes a Noise Protocol Framework exchange utilizing the `Noise_NNpsk0_25
|
|||
2. The PSK is integrated into the state at pattern position zero, authorizing and encrypting the very first handshake datagram.
|
||||
3. An ephemeral Curve25519 key exchange is performed to derive independent symmetric keys for the subsequent read/write channels.
|
||||
|
||||
The initial handshake payload includes a Unix timestamp to mitigate replay attacks. The server enforces a strict ±30-second synchronization window.
|
||||
The initial handshake payload includes a Unix timestamp to mitigate replay attacks. The server enforces a ±300-second synchronization window to accommodate clock drift and mobile roaming scenarios.
|
||||
|
||||
---
|
||||
|
||||
|
|
|
|||
|
|
@ -94,7 +94,7 @@ OSTP использует Noise Protocol Framework с паттерном `Noise_
|
|||
2. PSK применяется на нулевой позиции паттерна, обеспечивая авторизацию и шифрование самой первой датаграммы рукопожатия (Zero-RTT авторизация).
|
||||
3. Выполняется эфемерный обмен ключами Curve25519 для создания симметричных ключей передачи данных.
|
||||
|
||||
Первичная полезная нагрузка рукопожатия содержит Unix-отметку времени для защиты от атак повторного воспроизведения (Replay Attacks). Сервер строго контролирует окно синхронизации (±30 секунд).
|
||||
Первичная полезная нагрузка рукопожатия содержит Unix-отметку времени для защиты от атак повторного воспроизведения (Replay Attacks). Сервер контролирует окно синхронизации (±300 секунд) с учётом дрейфа часов и смены сети при роуминге.
|
||||
|
||||
---
|
||||
|
||||
|
|
|
|||
|
|
@ -32,6 +32,13 @@ fn make_initiator_config(
|
|||
"dns" => 1100,
|
||||
_ => 1350,
|
||||
};
|
||||
// For DNS transport: use larger ack_delay and rto to match DNS round-trip latency
|
||||
// (each DNS query + reply takes 300-800ms end-to-end through Cloudflare).
|
||||
// For UDP: minimize ack_delay to 1ms (ACK asap) and let CC drive the RTO.
|
||||
let (ack_delay_ms, rto_ms) = match transport_cfg.r#type.as_str() {
|
||||
"dns" => (50, 1500),
|
||||
_ => (1, 200),
|
||||
};
|
||||
|
||||
ProtocolConfig {
|
||||
role: ostp_core::NoiseRole::Initiator,
|
||||
|
|
@ -43,8 +50,8 @@ fn make_initiator_config(
|
|||
obfuscation_key: secrets.obfuscation_key,
|
||||
max_reorder: 16384,
|
||||
max_reorder_buffer: 8192,
|
||||
ack_delay_ms: 5,
|
||||
rto_ms: 100,
|
||||
ack_delay_ms,
|
||||
rto_ms,
|
||||
max_retries: 8,
|
||||
max_sent_history: 32768,
|
||||
handshake_pad_min: secrets.handshake_pad_min,
|
||||
|
|
@ -180,12 +187,25 @@ pub async fn dial_tcp(
|
|||
}
|
||||
|
||||
// ── Main bidirectional data forwarding loop ───────────────────────
|
||||
// Backpressure: we track how many frames are in-flight vs the congestion
|
||||
// window. When the window is full we stop reading from the TCP stream
|
||||
// (the kernel buffers it) until the remote ACKs enough frames.
|
||||
// This prevents overrunning the sender's sent_history and collapsing cwnd.
|
||||
let mut buf = [0u8; 65535];
|
||||
let mut udp_buf = [0u8; 65535];
|
||||
|
||||
loop {
|
||||
// Compute adaptive tick interval:
|
||||
// - tick = rto/4 (check retransmits without busy-spinning)
|
||||
// - NOTE: pending-ACK state is not consulted here; there is no separate ack_delay tick.
|
||||
// Floor at 1ms, ceiling at 50ms.
|
||||
let tick_ms = (machine.rto().as_millis() / 4).clamp(1, 50) as u64;
|
||||
|
||||
let can_send = machine.in_flight_count() < machine.cwnd_packets().max(4);
|
||||
|
||||
tokio::select! {
|
||||
Ok(n) = server_stream.read(&mut buf) => {
|
||||
// Only read from the application TCP stream when cwnd allows
|
||||
Ok(n) = server_stream.read(&mut buf), if can_send => {
|
||||
if n == 0 { break; }
|
||||
let data_msg = ostp_core::relay::RelayMessage::Data(buf[..n].to_vec());
|
||||
let encoded = data_msg.encode();
|
||||
|
|
@ -198,7 +218,7 @@ pub async fn dial_tcp(
|
|||
handle_action(action, &transport, &mut server_stream).await;
|
||||
}
|
||||
}
|
||||
_ = tokio::time::sleep(std::time::Duration::from_millis(10)) => {
|
||||
_ = tokio::time::sleep(std::time::Duration::from_millis(tick_ms)) => {
|
||||
if let Ok(action) = machine.on_event(OstpEvent::Tick) {
|
||||
handle_action(action, &transport, &mut server_stream).await;
|
||||
}
|
||||
|
|
@ -299,15 +319,62 @@ async fn make_transport(
|
|||
server: &str,
|
||||
port: u16,
|
||||
) -> Result<crate::transport::Transport> {
|
||||
let debug = tracing::enabled!(tracing::Level::DEBUG);
|
||||
match transport_cfg.r#type.as_str() {
|
||||
"dns" => {
|
||||
let domain = transport_cfg.domain.clone()
|
||||
.unwrap_or_else(|| "tunnel.example.com".to_string());
|
||||
let pubkey = transport_cfg.pubkey.clone()
|
||||
.unwrap_or_else(|| "".to_string());
|
||||
let resolver = transport_cfg.resolver.clone()
|
||||
.unwrap_or_else(|| server.to_string());
|
||||
let transport = crate::transport::dns::start_dns_transport(domain, resolver, transport_cfg.pubkey.clone()).await
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
Ok(transport)
|
||||
let resolver_with_port = if resolver.contains(':') {
|
||||
resolver.clone()
|
||||
} else {
|
||||
format!("{}:53", resolver)
|
||||
};
|
||||
|
||||
let (local_port, process) = ostp_core::dnstt::spawn_client(&pubkey, &domain, &resolver_with_port, debug)?;
|
||||
|
||||
// Wait for dnstt-client to start its local TCP listener
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
|
||||
// Connect TCP to the local dnstt-client port
|
||||
let stream = tokio::net::TcpStream::connect(("127.0.0.1", local_port)).await?;
|
||||
let (mut rh, mut wh) = stream.into_split();
|
||||
|
||||
let (tx_send, mut tx_recv) = tokio::sync::mpsc::channel::<bytes::Bytes>(1024);
|
||||
let (rx_send, rx_recv) = tokio::sync::mpsc::channel::<bytes::Bytes>(1024);
|
||||
|
||||
// Writer task
|
||||
tokio::spawn(async move {
|
||||
use tokio::io::AsyncWriteExt;
|
||||
while let Some(data) = tx_recv.recv().await {
|
||||
let len = data.len() as u16;
|
||||
if wh.write_u16(len).await.is_err() { break; }
|
||||
if wh.write_all(&data).await.is_err() { break; }
|
||||
}
|
||||
});
|
||||
|
||||
// Reader task
|
||||
tokio::spawn(async move {
|
||||
use tokio::io::AsyncReadExt;
|
||||
loop {
|
||||
let len = match rh.read_u16().await {
|
||||
Ok(l) => l,
|
||||
Err(_) => break,
|
||||
};
|
||||
let mut buf = vec![0u8; len as usize];
|
||||
if rh.read_exact(&mut buf).await.is_err() { break; }
|
||||
if rx_send.send(bytes::Bytes::from(buf)).await.is_err() { break; }
|
||||
}
|
||||
});
|
||||
|
||||
Ok(crate::transport::Transport::Dnstt {
|
||||
tx: tx_send,
|
||||
rx: std::sync::Arc::new(tokio::sync::Mutex::new(rx_recv)),
|
||||
_guard: std::sync::Arc::new(tokio::sync::Mutex::new(process)),
|
||||
})
|
||||
}
|
||||
_ => {
|
||||
let udp = tokio::net::UdpSocket::bind("0.0.0.0:0").await?;
|
||||
|
|
|
|||
|
|
@ -4,6 +4,12 @@
|
|||
//! bandwidth and minimum RTT to determine the optimal sending rate.
|
||||
//! This replaces the fixed `retransmit_budget = 8` with an adaptive
|
||||
//! congestion window that responds to network conditions.
|
||||
//!
|
||||
//! RTO calculation follows RFC 6298:
|
||||
//! SRTT = (1 - α) * SRTT + α * RTT (α = 1/8)
|
||||
//! RTTVAR = (1 - β) * RTTVAR + β * |SRTT - RTT| (β = 1/4)
|
||||
//! RTO = SRTT + 4 * RTTVAR
|
||||
//! clamped to [RTO_MIN, RTO_MAX]
|
||||
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
|
|
@ -15,8 +21,14 @@ pub struct CongestionController {
|
|||
ssthresh: u64,
|
||||
/// Current phase
|
||||
phase: Phase,
|
||||
/// Minimum RTT observed
|
||||
/// Minimum RTT observed (for BBR-style bandwidth estimation)
|
||||
min_rtt: Duration,
|
||||
/// Smoothed RTT (RFC 6298 SRTT)
|
||||
srtt: Duration,
|
||||
/// RTT variance (RFC 6298 RTTVAR)
|
||||
rttvar: Duration,
|
||||
/// Whether we have received a first RTT sample
|
||||
rtt_initialized: bool,
|
||||
/// Bytes currently in flight (unacknowledged)
|
||||
bytes_in_flight: u64,
|
||||
/// Total bytes acknowledged (for bandwidth estimation)
|
||||
|
|
@ -37,31 +49,43 @@ pub struct CongestionController {
|
|||
enum Phase {
|
||||
/// Exponential growth until loss or ssthresh
|
||||
SlowStart,
|
||||
/// Probe bandwidth: cycle through pacing gains
|
||||
/// Probe bandwidth: additive increase
|
||||
ProbeBandwidth,
|
||||
}
|
||||
|
||||
/// Initial congestion window: 10 packets × MTU
|
||||
const INITIAL_CWND_PACKETS: u64 = 10;
|
||||
/// Initial congestion window: 32 packets × MTU (IW10 is too conservative for modern links)
|
||||
const INITIAL_CWND_PACKETS: u64 = 32;
|
||||
/// Minimum cwnd: 2 packets
|
||||
const MIN_CWND_PACKETS: u64 = 2;
|
||||
/// Min RTT expiry window (after which we re-probe)
|
||||
const MIN_RTT_EXPIRY: Duration = Duration::from_secs(10);
|
||||
/// Minimum RTO (RFC 6298: 1s in TCP; we use 50ms since we own the protocol)
|
||||
const RTO_MIN: Duration = Duration::from_millis(50);
|
||||
/// Maximum RTO
|
||||
const RTO_MAX: Duration = Duration::from_secs(16);
|
||||
/// Initial RTT estimate — 30 ms is reasonable for a well-connected VPN server.
|
||||
/// Will be replaced by first real measurement within milliseconds.
|
||||
const INITIAL_RTT: Duration = Duration::from_millis(30);
|
||||
|
||||
impl CongestionController {
|
||||
pub fn new(mtu: u64) -> Self {
|
||||
let now = Instant::now();
|
||||
let initial_cwnd = INITIAL_CWND_PACKETS * mtu;
|
||||
// Initial pacing: deliver one full cwnd per INITIAL_RTT (in bytes per second) to fill the pipe quickly
|
||||
let initial_pacing = initial_cwnd * 1_000_000 / INITIAL_RTT.as_micros().max(1) as u64;
|
||||
Self {
|
||||
cwnd: initial_cwnd,
|
||||
ssthresh: u64::MAX,
|
||||
phase: Phase::SlowStart,
|
||||
min_rtt: Duration::from_millis(100), // Conservative initial estimate
|
||||
min_rtt: INITIAL_RTT,
|
||||
srtt: INITIAL_RTT,
|
||||
rttvar: INITIAL_RTT / 2,
|
||||
rtt_initialized: false,
|
||||
bytes_in_flight: 0,
|
||||
total_acked: 0,
|
||||
last_ack_time: now,
|
||||
loss_count: 0,
|
||||
pacing_rate: initial_cwnd * 10, // initial: ~10 windows/sec
|
||||
pacing_rate: initial_pacing,
|
||||
mtu,
|
||||
min_rtt_stamp: now,
|
||||
}
|
||||
|
|
@ -82,9 +106,20 @@ impl CongestionController {
|
|||
self.pacing_rate
|
||||
}
|
||||
|
||||
/// Returns the smoothed RTT estimate.
|
||||
/// Returns the smoothed RTT estimate (SRTT).
|
||||
pub fn smoothed_rtt(&self) -> Duration {
|
||||
self.min_rtt
|
||||
self.srtt
|
||||
}
|
||||
|
||||
/// Returns the adaptive RTO computed per RFC 6298:
|
||||
/// RTO = SRTT + 4 * RTTVAR, clamped to [RTO_MIN, RTO_MAX].
|
||||
///
|
||||
/// This replaces the static `rto_ms` field in ProtocolMachine so that
|
||||
/// retransmit timers automatically track changing network conditions.
|
||||
pub fn rto(&self) -> Duration {
|
||||
let rttvar4 = self.rttvar.saturating_mul(4);
|
||||
let rto = self.srtt.saturating_add(rttvar4);
|
||||
rto.clamp(RTO_MIN, RTO_MAX)
|
||||
}
|
||||
|
||||
/// Returns how many bytes can still be sent.
|
||||
|
|
@ -115,16 +150,13 @@ impl CongestionController {
|
|||
self.bytes_in_flight = self.bytes_in_flight.saturating_sub(bytes);
|
||||
self.total_acked = self.total_acked.saturating_add(bytes);
|
||||
|
||||
// Update RTT
|
||||
// Update RTT measurements
|
||||
self.update_rtt(rtt, now);
|
||||
|
||||
// Update bandwidth estimate
|
||||
self.update_bandwidth(bytes, now);
|
||||
|
||||
// State machine
|
||||
match self.phase {
|
||||
Phase::SlowStart => {
|
||||
// Exponential growth: increase cwnd by acked bytes
|
||||
// Exponential growth: increase cwnd by acked bytes (doubles per RTT)
|
||||
self.cwnd = self.cwnd.saturating_add(bytes);
|
||||
if self.cwnd >= self.ssthresh {
|
||||
self.phase = Phase::ProbeBandwidth;
|
||||
|
|
@ -164,32 +196,49 @@ impl CongestionController {
|
|||
self.update_pacing_rate();
|
||||
}
|
||||
|
||||
/// Called periodically to update state.
|
||||
pub fn on_tick(&mut self) {
|
||||
// Nothing special needed per-tick -- state updates happen on ACK/loss
|
||||
}
|
||||
|
||||
// ── Private ──────────────────────────────────────────────────────────────
|
||||
|
||||
fn update_rtt(&mut self, rtt: Duration, now: Instant) {
|
||||
// Track windowed minimum RTT
|
||||
// Update windowed minimum RTT (for pacing)
|
||||
if rtt < self.min_rtt || now.duration_since(self.min_rtt_stamp) >= MIN_RTT_EXPIRY {
|
||||
self.min_rtt = rtt;
|
||||
self.min_rtt_stamp = now;
|
||||
}
|
||||
}
|
||||
|
||||
fn update_bandwidth(&mut self, _acked_bytes: u64, now: Instant) {
|
||||
let elapsed = now.duration_since(self.last_ack_time);
|
||||
if elapsed.as_micros() > 0 {
|
||||
// Removed bw_samples tracking
|
||||
// Update SRTT and RTTVAR per RFC 6298
|
||||
if !self.rtt_initialized {
|
||||
// First measurement: initialize directly
|
||||
self.srtt = rtt;
|
||||
self.rttvar = rtt / 2;
|
||||
self.rtt_initialized = true;
|
||||
} else {
|
||||
// RTTVAR = (3/4) * RTTVAR + (1/4) * |SRTT - R|
|
||||
let diff = if rtt > self.srtt {
|
||||
rtt - self.srtt
|
||||
} else {
|
||||
self.srtt - rtt
|
||||
};
|
||||
// Integer-safe: RTTVAR = RTTVAR - RTTVAR/4 + diff/4
|
||||
self.rttvar = self.rttvar
|
||||
.saturating_sub(self.rttvar / 4)
|
||||
.saturating_add(diff / 4);
|
||||
|
||||
// SRTT = (7/8) * SRTT + (1/8) * R
|
||||
self.srtt = self.srtt
|
||||
.saturating_sub(self.srtt / 8)
|
||||
.saturating_add(rtt / 8);
|
||||
}
|
||||
|
||||
tracing::trace!(
|
||||
srtt_ms = self.srtt.as_millis(),
|
||||
rttvar_ms = self.rttvar.as_millis(),
|
||||
rto_ms = self.rto().as_millis(),
|
||||
"congestion: RTT updated"
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
|
||||
fn update_pacing_rate(&mut self) {
|
||||
// Pacing rate = cwnd / min_rtt (with gain)
|
||||
// Pacing rate = cwnd / min_rtt (delivery rate target)
|
||||
let rtt_us = self.min_rtt.as_micros().max(1) as u64;
|
||||
self.pacing_rate = self.cwnd * 1_000_000 / rtt_us;
|
||||
}
|
||||
|
|
@ -202,19 +251,18 @@ mod tests {
|
|||
#[test]
|
||||
fn test_initial_state() {
|
||||
let cc = CongestionController::new(1200);
|
||||
assert_eq!(cc.cwnd(), 12000); // 10 * 1200
|
||||
assert_eq!(cc.cwnd(), 32 * 1200); // 32 * 1200
|
||||
assert!(cc.can_send());
|
||||
assert_eq!(cc.cwnd_packets(), 10);
|
||||
assert_eq!(cc.cwnd_packets(), 32);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_slow_start_growth() {
|
||||
let mut cc = CongestionController::new(1200);
|
||||
// Simulate sending and ACKing
|
||||
let initial = cc.cwnd();
|
||||
cc.on_send(1200);
|
||||
cc.on_ack(1200, Duration::from_millis(50));
|
||||
// cwnd should grow
|
||||
assert!(cc.cwnd() > 12000);
|
||||
assert!(cc.cwnd() > initial);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -229,7 +277,7 @@ mod tests {
|
|||
fn test_can_send_limits() {
|
||||
let mut cc = CongestionController::new(1200);
|
||||
// Send until cwnd is exhausted
|
||||
for _ in 0..10 {
|
||||
for _ in 0..32 {
|
||||
cc.on_send(1200);
|
||||
}
|
||||
assert!(!cc.can_send()); // cwnd exhausted
|
||||
|
|
@ -244,10 +292,46 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn test_rtt_tracking() {
|
||||
fn test_rtt_tracking_first_sample() {
|
||||
let mut cc = CongestionController::new(1200);
|
||||
cc.on_send(1200);
|
||||
cc.on_ack(1200, Duration::from_millis(25));
|
||||
// After first sample: SRTT = 25ms, RTTVAR = 12.5ms (rtt / 2)
|
||||
assert_eq!(cc.smoothed_rtt(), Duration::from_millis(25));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rto_rfc6298() {
|
||||
let mut cc = CongestionController::new(1200);
|
||||
// After first sample with RTT=50ms: SRTT=50ms, RTTVAR=25ms, RTO=150ms
|
||||
cc.on_send(1200);
|
||||
cc.on_ack(1200, Duration::from_millis(50));
|
||||
let rto = cc.rto();
|
||||
// RTO = 50 + 4*25 = 150ms; clamped to [50ms, 16s]
|
||||
assert!(rto >= RTO_MIN);
|
||||
assert!(rto <= RTO_MAX);
|
||||
assert_eq!(rto, Duration::from_millis(150));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rto_clamp_min() {
|
||||
let cc = CongestionController::new(1200);
|
||||
// Even with no RTT samples, RTO should not go below RTO_MIN
|
||||
assert!(cc.rto() >= RTO_MIN);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rto_adapts_after_multiple_samples() {
|
||||
let mut cc = CongestionController::new(1200);
|
||||
// Feed several consistent RTT samples
|
||||
for _ in 0..8 {
|
||||
cc.on_send(1200);
|
||||
cc.on_ack(1200, Duration::from_millis(20));
|
||||
}
|
||||
// After convergence, RTTVAR should be small → RTO close to SRTT + small margin
|
||||
let rto = cc.rto();
|
||||
// Expected to clamp to RTO_MIN here (20ms + 4 × ~1.3ms < 50ms), well below the old hardcoded 100ms default
|
||||
assert!(rto < Duration::from_millis(200));
|
||||
assert!(rto >= RTO_MIN);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ use bytes::Bytes;
|
|||
use rand::Rng;
|
||||
use sha2::{Digest, Sha256};
|
||||
use thiserror::Error;
|
||||
use std::collections::{BTreeMap, VecDeque};
|
||||
use std::collections::BTreeMap;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use crate::congestion::CongestionController;
|
||||
|
|
@ -75,7 +75,7 @@ pub struct ProtocolMachine {
|
|||
send_nonce: u64,
|
||||
expected_recv_nonce: u64,
|
||||
reorder_buffer: BTreeMap<u64, ProtocolAction>,
|
||||
sent_history: VecDeque<SentFrame>,
|
||||
sent_history: BTreeMap<u64, SentFrame>,
|
||||
session_id: u32,
|
||||
handshake_payload: Vec<u8>,
|
||||
padder: AdaptivePadder,
|
||||
|
|
@ -83,7 +83,8 @@ pub struct ProtocolMachine {
|
|||
max_reorder: u64,
|
||||
max_reorder_buffer: usize,
|
||||
ack_delay: Duration,
|
||||
rto: Duration,
|
||||
/// Initial/fallback RTO from config (overridden by cc.rto() after first RTT sample)
|
||||
rto_initial: Duration,
|
||||
max_retries: u8,
|
||||
max_sent_history: usize,
|
||||
ack_pending: bool,
|
||||
|
|
@ -100,11 +101,11 @@ pub struct ProtocolMachine {
|
|||
/// Key-derived handshake padding range
|
||||
handshake_pad_min: usize,
|
||||
handshake_pad_max: usize,
|
||||
_mtu: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct SentFrame {
|
||||
#[allow(dead_code)] // mirrored in BTreeMap key; kept for Debug output
|
||||
nonce: u64,
|
||||
bytes: Bytes,
|
||||
last_sent: Instant,
|
||||
|
|
@ -128,7 +129,7 @@ impl ProtocolMachine {
|
|||
send_nonce: 0,
|
||||
expected_recv_nonce: 0,
|
||||
reorder_buffer: BTreeMap::new(),
|
||||
sent_history: VecDeque::with_capacity(config.max_sent_history.max(1)),
|
||||
sent_history: BTreeMap::new(),
|
||||
session_id: config.session_id,
|
||||
handshake_payload: config.handshake_payload,
|
||||
padder: AdaptivePadder::new(config.mtu, config.max_padding, config.padding_strategy),
|
||||
|
|
@ -136,7 +137,7 @@ impl ProtocolMachine {
|
|||
max_reorder: config.max_reorder.max(1),
|
||||
max_reorder_buffer: config.max_reorder_buffer.max(1),
|
||||
ack_delay: Duration::from_millis(config.ack_delay_ms.max(1)),
|
||||
rto: Duration::from_millis(config.rto_ms.max(1)),
|
||||
rto_initial: Duration::from_millis(config.rto_ms.max(1)),
|
||||
max_retries: config.max_retries.max(1),
|
||||
max_sent_history: config.max_sent_history.max(1),
|
||||
ack_pending: false,
|
||||
|
|
@ -146,20 +147,25 @@ impl ProtocolMachine {
|
|||
cc: CongestionController::new(config.mtu as u64),
|
||||
handshake_pad_min: config.handshake_pad_min.max(8),
|
||||
handshake_pad_max: config.handshake_pad_max.max(config.handshake_pad_min + 16),
|
||||
_mtu: config.mtu,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn in_flight_count(&self) -> usize {
|
||||
// COUNT ONLY retransmittable Data frames — control frames (Ack/Nack) must not
|
||||
// contribute to this counter or they will trigger false backpressure.
|
||||
self.sent_history.iter().filter(|f| f.is_retransmittable).count()
|
||||
self.sent_history.values().filter(|f| f.is_retransmittable).count()
|
||||
}
|
||||
|
||||
pub fn cwnd_packets(&self) -> usize {
|
||||
self.cc.cwnd_packets() as usize
|
||||
}
|
||||
|
||||
/// Returns the current adaptive RTO (from congestion controller after first RTT sample,
|
||||
/// and from the controller's built-in initial RTT estimate before that; the config `rto_initial` is not consulted here).
|
||||
pub fn rto(&self) -> Duration {
|
||||
self.cc.rto()
|
||||
}
|
||||
|
||||
pub fn on_send(&mut self, bytes: u64) {
|
||||
self.cc.on_send(bytes);
|
||||
}
|
||||
|
|
@ -207,13 +213,12 @@ impl ProtocolMachine {
|
|||
.map(ProtocolAction::SendDatagram)
|
||||
}
|
||||
(OstpState::Closing, OstpEvent::Inbound(raw)) => {
|
||||
// Process final in-flight packets to prevent data loss during teardown.
|
||||
// The remote may still have data or ACKs in transit when we initiated Close.
|
||||
let result = self.handle_inbound(raw);
|
||||
self.state = OstpState::Closed;
|
||||
result
|
||||
// The remote may still have data or ACKs in transit.
|
||||
// handle_inbound transitions to Closed when it receives a Close frame.
|
||||
self.handle_inbound(raw)
|
||||
}
|
||||
(OstpState::Established, OstpEvent::Tick) => self.handle_tick(),
|
||||
(OstpState::Closing, OstpEvent::Tick) => self.handle_tick(),
|
||||
(OstpState::Closed, _) => Ok(ProtocolAction::Noop),
|
||||
(_, OstpEvent::Close) => {
|
||||
self.state = OstpState::Closed;
|
||||
|
|
@ -408,10 +413,10 @@ impl ProtocolMachine {
|
|||
tracing::debug!("Frame nonce={} arrived too late after gap recovery, dropping", nonce);
|
||||
}
|
||||
|
||||
// Rate-limited NACK: send at most once per 30ms to prevent retransmit storms.
|
||||
// Under high load with natural UDP reordering, sending a NACK per packet
|
||||
// causes exponential retransmit explosion that saturates the channel.
|
||||
let nack_cooldown = Duration::from_millis(30);
|
||||
// Rate-limited NACK: send at most once per (rto/2) to prevent retransmit storms.
|
||||
// Using rto/2 means we send a NACK before the sender's timer fires, prompting
|
||||
// fast retransmit without flooding. Floor at 10ms to handle very low-RTT links.
|
||||
let nack_cooldown = (self.cc.rto() / 2).max(Duration::from_millis(10));
|
||||
if self.last_nack_sent.elapsed() >= nack_cooldown {
|
||||
self.last_nack_sent = Instant::now();
|
||||
let nack_payload = self.expected_recv_nonce.to_be_bytes();
|
||||
|
|
@ -525,30 +530,33 @@ impl ProtocolMachine {
|
|||
}
|
||||
|
||||
let now = Instant::now();
|
||||
let base_rto_ms = self.rto.as_millis().max(1) as u64;
|
||||
// Use the adaptive RTO from the congestion controller (RFC 6298 SRTT + 4*RTTVAR).
|
||||
// `rto_initial` is applied as a floor via max(): the base RTO never drops below the configured value.
|
||||
let base_rto = self.cc.rto().max(self.rto_initial);
|
||||
let base_rto_ms = base_rto.as_millis().max(1) as u64;
|
||||
|
||||
// ── Zombie frame eviction ────────────────────────────────────
|
||||
// Evict frames that exceeded max_retries + 2 grace retries.
|
||||
// Shorter grace period than before (was +4) to free memory faster
|
||||
// after high-throughput bursts.
|
||||
let grace = self.max_retries.saturating_add(2);
|
||||
let before = self.sent_history.len();
|
||||
self.sent_history.retain(|f| !f.is_retransmittable || f.retries <= grace);
|
||||
self.sent_history.retain(|_, f| !f.is_retransmittable || f.retries <= grace);
|
||||
let evicted = before - self.sent_history.len();
|
||||
if evicted > 0 {
|
||||
tracing::debug!("Evicted {} zombie frames from sent_history (remaining={})", evicted, self.sent_history.len());
|
||||
}
|
||||
|
||||
// ── Retransmit expired frames ────────────────────────────────
|
||||
// Limit retransmits per tick to prevent bandwidth saturation
|
||||
// Backoff starts from retry #0 (immediately effective):
|
||||
// effective_rto = base_rto * 2^retries, capped at 2^6 = 64×
|
||||
// This ensures we do not flood with retransmits on the first few losses
|
||||
// while still recovering quickly on a transient single loss.
|
||||
let mut retransmit_budget: usize = self.cc.retransmit_budget();
|
||||
for frame in self.sent_history.iter_mut() {
|
||||
for frame in self.sent_history.values_mut() {
|
||||
if !frame.is_retransmittable {
|
||||
continue;
|
||||
}
|
||||
|
||||
let retry_over = frame.retries.saturating_sub(self.max_retries);
|
||||
let backoff_factor = 1u64 << retry_over.min(6);
|
||||
let backoff_factor = 1u64 << (frame.retries as u64).min(6);
|
||||
let effective_rto = Duration::from_millis(base_rto_ms.saturating_mul(backoff_factor));
|
||||
|
||||
if now.duration_since(frame.last_sent) >= effective_rto {
|
||||
|
|
@ -654,7 +662,7 @@ impl ProtocolMachine {
|
|||
}
|
||||
|
||||
fn lookup_sent_frame(&mut self, nonce: u64) -> Option<Bytes> {
|
||||
if let Some(frame) = self.sent_history.iter_mut().rev().find(|f| f.nonce == nonce) {
|
||||
if let Some(frame) = self.sent_history.get_mut(&nonce) {
|
||||
frame.last_sent = Instant::now();
|
||||
frame.retries = frame.retries.saturating_add(1);
|
||||
return Some(frame.bytes.clone());
|
||||
|
|
@ -666,7 +674,7 @@ impl ProtocolMachine {
|
|||
if is_retransmittable {
|
||||
self.cc.on_send(bytes.len() as u64);
|
||||
}
|
||||
self.sent_history.push_back(SentFrame {
|
||||
self.sent_history.insert(nonce, SentFrame {
|
||||
nonce,
|
||||
bytes,
|
||||
last_sent: Instant::now(),
|
||||
|
|
@ -679,7 +687,7 @@ impl ProtocolMachine {
|
|||
overflow, self.max_sent_history
|
||||
);
|
||||
while self.sent_history.len() > self.max_sent_history {
|
||||
self.sent_history.pop_front();
|
||||
self.sent_history.pop_first();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -690,8 +698,8 @@ impl ProtocolMachine {
|
|||
let mut min_rtt = Duration::from_secs(60);
|
||||
|
||||
// Track the smallest (now - last_sent) among the acked frames as the RTT sample
|
||||
for frame in self.sent_history.iter() {
|
||||
if nonce_in_ranges(frame.nonce, ranges) {
|
||||
for (&nonce, frame) in &self.sent_history {
|
||||
if nonce_in_ranges(nonce, ranges) {
|
||||
acked_bytes += frame.bytes.len() as u64;
|
||||
let rtt = now.duration_since(frame.last_sent);
|
||||
if rtt < min_rtt {
|
||||
|
|
@ -700,7 +708,7 @@ impl ProtocolMachine {
|
|||
}
|
||||
}
|
||||
|
||||
self.sent_history.retain(|frame| !nonce_in_ranges(frame.nonce, ranges));
|
||||
self.sent_history.retain(|&nonce, _| !nonce_in_ranges(nonce, ranges));
|
||||
|
||||
// Notify congestion controller
|
||||
if acked_bytes > 0 {
|
||||
|
|
|
|||
|
|
@ -12,12 +12,13 @@ use portable_atomic::AtomicU64;
|
|||
// const MAX_SESSIONS removed because dynamic limit is used
|
||||
|
||||
pub enum DispatchOutcome {
|
||||
Unauthorized(String),
|
||||
Accepted {
|
||||
responses: Vec<Bytes>,
|
||||
app_payloads: Vec<(u32, u16, Bytes)>, // session_id, stream_id, payload
|
||||
peer_addr: SocketAddr,
|
||||
},
|
||||
Unauthorized(String),
|
||||
Ignored,
|
||||
}
|
||||
|
||||
/// Per-user traffic statistics.
|
||||
|
|
@ -83,7 +84,6 @@ pub struct Dispatcher {
|
|||
last_token_regen: std::time::Instant,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl Dispatcher {
|
||||
pub fn new(machine_config: ProtocolConfig, access_keys: Arc<RwLock<HashMap<String, crate::api::UserMeta>>>) -> Self {
|
||||
let mut initial_stats = HashMap::new();
|
||||
|
|
@ -108,6 +108,7 @@ impl Dispatcher {
|
|||
}
|
||||
|
||||
/// Snapshot all user stats for API responses.
|
||||
#[allow(dead_code)]
|
||||
pub fn snapshot_all_users(&self) -> Vec<UserStatsSnapshot> {
|
||||
let stats = self.user_stats.read().unwrap_or_else(|e| e.into_inner());
|
||||
let mut online_keys: HashMap<String, std::time::Instant> = HashMap::new();
|
||||
|
|
@ -161,6 +162,7 @@ impl Dispatcher {
|
|||
}
|
||||
|
||||
/// Set traffic limit for a user.
|
||||
#[allow(dead_code)]
|
||||
pub fn set_user_limit(&self, key: &str, limit: Option<u64>) {
|
||||
let mut stats = self.user_stats.write().unwrap_or_else(|e| e.into_inner());
|
||||
let entry = stats.entry(key.to_string())
|
||||
|
|
@ -176,6 +178,7 @@ impl Dispatcher {
|
|||
}
|
||||
|
||||
/// Active session count.
|
||||
#[allow(dead_code)]
|
||||
pub fn active_sessions(&self) -> usize {
|
||||
self.peer_machines.len()
|
||||
}
|
||||
|
|
@ -376,15 +379,19 @@ impl Dispatcher {
|
|||
continue;
|
||||
}
|
||||
|
||||
if !self.replay_cache.contains_key(&payload.to_vec()) {
|
||||
if self.replay_cache.len() >= 50_000 {
|
||||
tracing::warn!("Replay cache full (100000 entries), rejecting handshake from {}", peer);
|
||||
return Ok(DispatchOutcome::Unauthorized("replay cache full".to_string()));
|
||||
}
|
||||
if self.replay_cache.contains_key(&payload.to_vec()) {
|
||||
tracing::debug!("Replay detected from {}, ignoring", peer);
|
||||
return Ok(DispatchOutcome::Ignored);
|
||||
}
|
||||
|
||||
self.replay_cache.insert(payload.to_vec(), ts);
|
||||
if self.replay_cache.len() >= 50_000 {
|
||||
tracing::warn!("Replay cache full (50000 entries), rejecting handshake from {}", peer);
|
||||
return Ok(DispatchOutcome::Unauthorized("replay cache full".to_string()));
|
||||
}
|
||||
|
||||
machine.set_session_keys(candidate_session_id, secrets.obfuscation_key);
|
||||
self.replay_cache.insert(payload.to_vec(), ts);
|
||||
|
||||
machine.set_session_keys(candidate_session_id, secrets.obfuscation_key);
|
||||
|
||||
// Track per-user connection count
|
||||
let user_stats = self.get_or_create_user_stats(&candidate_key);
|
||||
|
|
@ -414,7 +421,6 @@ impl Dispatcher {
|
|||
app_payloads: Vec::new(),
|
||||
peer_addr: peer,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -429,23 +435,35 @@ impl Dispatcher {
|
|||
Ok(DispatchOutcome::Unauthorized(reason))
|
||||
}
|
||||
|
||||
pub fn outbound_to_session(&mut self, session_id: u32, stream_id: u16, payload: Bytes) -> Result<Option<(Bytes, SocketAddr)>> {
|
||||
pub fn outbound_to_session(&mut self, session_id: u32, stream_id: u16, payload: Bytes) -> Result<Vec<(Bytes, SocketAddr)>> {
|
||||
let peer_state = if let Some(existing) = self.peer_machines.get_mut(&session_id) {
|
||||
existing
|
||||
} else {
|
||||
return Ok(None);
|
||||
return Ok(Vec::new());
|
||||
};
|
||||
|
||||
let addr = peer_state.last_addr;
|
||||
let key = peer_state.access_key.clone();
|
||||
match peer_state.machine.on_event(OstpEvent::Outbound(stream_id, payload))? {
|
||||
ProtocolAction::SendDatagram(frame) => {
|
||||
// Track outbound bytes per user
|
||||
track_user_bytes_down(&self.user_stats, &self.access_keys, &key, frame.len() as u64);
|
||||
Ok(Some((frame, addr)))
|
||||
let action = peer_state.machine.on_event(OstpEvent::Outbound(stream_id, payload))?;
|
||||
|
||||
let mut frames = Vec::new();
|
||||
let mut queue = vec![action];
|
||||
while let Some(current) = queue.pop() {
|
||||
match current {
|
||||
ProtocolAction::Multiple(list) => {
|
||||
for item in list {
|
||||
queue.push(item);
|
||||
}
|
||||
}
|
||||
ProtocolAction::SendDatagram(frame) => {
|
||||
track_user_bytes_down(&self.user_stats, &self.access_keys, &key, frame.len() as u64);
|
||||
frames.push((frame, addr));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
_ => Ok(None),
|
||||
}
|
||||
|
||||
Ok(frames)
|
||||
}
|
||||
|
||||
pub fn on_tick(&mut self) -> (Vec<(Bytes, SocketAddr)>, Vec<u32>) {
|
||||
|
|
@ -459,7 +477,7 @@ impl Dispatcher {
|
|||
let mut frames = Vec::new();
|
||||
let mut expired = Vec::new();
|
||||
let now = std::time::Instant::now();
|
||||
let timeout_dur = std::time::Duration::from_secs(600); // 10 minute session timeout (mobile NAT can be up to 5-10min)
|
||||
let timeout_dur = std::time::Duration::from_secs(600); // 10-minute session timeout (mobile NAT mappings can live 5–10 min)
|
||||
|
||||
// Gather expired or invalid sessions
|
||||
for (&sid, peer_state) in &self.peer_machines {
|
||||
|
|
@ -477,7 +495,7 @@ impl Dispatcher {
|
|||
let key_valid = self.access_keys.read().unwrap_or_else(|e| e.into_inner()).contains_key(&ps.access_key);
|
||||
let user_stats = self.get_or_create_user_stats(&ps.access_key);
|
||||
if now.duration_since(ps.last_seen) > timeout_dur {
|
||||
"inactive >5min"
|
||||
"inactive >10min"
|
||||
} else if !key_valid {
|
||||
"key deleted"
|
||||
} else if user_stats.is_over_limit() {
|
||||
|
|
|
|||
Loading…
Reference in New Issue