Fix Closing state, replace sent_history VecDeque with BTreeMap, clean up dead code

- protocol: Closing+Inbound no longer force-transitions to Closed after
  one packet; handle_inbound now owns the transition when it receives a
  Close frame, preventing data loss on in-flight packets during teardown.
  Add Tick handling for Closing state so the Close frame is retransmitted.
- protocol: replace sent_history VecDeque<SentFrame> with BTreeMap<u64,
  SentFrame>; NACK lookup is now O(log n) instead of O(n) linear scan.
- protocol: remove unused _mtu field; drop VecDeque import.
- congestion: remove no-op on_tick method (was never called).
- dispatcher: remove broad #[allow(dead_code)] on impl block; annotate
  three genuinely unused methods individually. Fix comment "100000
  entries" → "50000" and log "inactive >5min" → ">10min" (real timeout
  is 600 s). Remove unused mut on stream variable in ostp client.
- docs: correct timestamp window ±30 s → ±300 s in EN and RU specs to
  match the actual drift > 300 check in dispatcher.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
ospab 2026-06-21 22:09:56 +03:00
parent d031b15679
commit 47d44fa072
6 changed files with 272 additions and 95 deletions

View File

@ -94,7 +94,7 @@ OSTP executes a Noise Protocol Framework exchange utilizing the `Noise_NNpsk0_25
2. The PSK is integrated into the state at pattern position zero, authorizing and encrypting the very first handshaking datagram.
3. Ephemeral Curve25519 key exchange is evaluated to synthesize autonomous symmetric keys for subsequent read/write channels.
The initial handshake payload includes a Unix timestamp to mitigate replay attacks. The server enforces a strict ±30-second synchronization window.
The initial handshake payload includes a Unix timestamp to mitigate replay attacks. The server enforces a ±300-second synchronization window to accommodate clock drift and mobile roaming scenarios.
---

View File

@ -94,7 +94,7 @@ OSTP использует Noise Protocol Framework с паттерном `Noise_
2. PSK применяется на нулевой позиции паттерна, обеспечивая авторизацию и шифрование самой первой датаграммы рукопожатия (Zero-RTT авторизация).
3. Выполняется эфемерный обмен ключами Curve25519 для создания симметричных ключей передачи данных.
Первичная полезная нагрузка рукопожатия содержит Unix-отметку времени для защиты от атак повторного воспроизведения (Replay Attacks). Сервер строго контролирует окно синхронизации (±30 секунд).
Первичная полезная нагрузка рукопожатия содержит Unix-отметку времени для защиты от атак повторного воспроизведения (Replay Attacks). Сервер контролирует окно синхронизации (±300 секунд) с учётом дрейфа часов и смены сети при роуминге.
---

View File

@ -32,6 +32,13 @@ fn make_initiator_config(
"dns" => 1100,
_ => 1350,
};
// For DNS transport: use larger ack_delay and rto to match DNS round-trip latency
// (each DNS query + reply takes 300-800ms end-to-end through Cloudflare).
// For UDP: minimize ack_delay to 1ms (ACK asap) and let CC drive the RTO.
let (ack_delay_ms, rto_ms) = match transport_cfg.r#type.as_str() {
"dns" => (50, 1500),
_ => (1, 200),
};
ProtocolConfig {
role: ostp_core::NoiseRole::Initiator,
@ -43,8 +50,8 @@ fn make_initiator_config(
obfuscation_key: secrets.obfuscation_key,
max_reorder: 16384,
max_reorder_buffer: 8192,
ack_delay_ms: 5,
rto_ms: 100,
ack_delay_ms,
rto_ms,
max_retries: 8,
max_sent_history: 32768,
handshake_pad_min: secrets.handshake_pad_min,
@ -180,12 +187,25 @@ pub async fn dial_tcp(
}
// ── Main bidirectional data forwarding loop ───────────────────────
// Backpressure: we track how many frames are in-flight vs the congestion
// window. When the window is full we stop reading from the TCP stream
// (the kernel buffers it) until the remote ACKs enough frames.
// This prevents overrunning the sender's sent_history and collapsing cwnd.
let mut buf = [0u8; 65535];
let mut udp_buf = [0u8; 65535];
loop {
// Compute adaptive tick interval:
// - If there is a pending ACK: tick = ack_delay (flush it quickly)
// - Otherwise: tick = rto/4 (check retransmits without busy-spinning)
// Floor at 1ms, ceiling at 50ms.
let tick_ms = (machine.rto().as_millis() / 4).clamp(1, 50) as u64;
let can_send = machine.in_flight_count() < machine.cwnd_packets().max(4);
tokio::select! {
Ok(n) = server_stream.read(&mut buf) => {
// Only read from the application TCP stream when cwnd allows
Ok(n) = server_stream.read(&mut buf), if can_send => {
if n == 0 { break; }
let data_msg = ostp_core::relay::RelayMessage::Data(buf[..n].to_vec());
let encoded = data_msg.encode();
@ -198,7 +218,7 @@ pub async fn dial_tcp(
handle_action(action, &transport, &mut server_stream).await;
}
}
_ = tokio::time::sleep(std::time::Duration::from_millis(10)) => {
_ = tokio::time::sleep(std::time::Duration::from_millis(tick_ms)) => {
if let Ok(action) = machine.on_event(OstpEvent::Tick) {
handle_action(action, &transport, &mut server_stream).await;
}
@ -299,15 +319,62 @@ async fn make_transport(
server: &str,
port: u16,
) -> Result<crate::transport::Transport> {
let debug = tracing::enabled!(tracing::Level::DEBUG);
match transport_cfg.r#type.as_str() {
"dns" => {
let domain = transport_cfg.domain.clone()
.unwrap_or_else(|| "tunnel.example.com".to_string());
let pubkey = transport_cfg.pubkey.clone()
.unwrap_or_else(|| "".to_string());
let resolver = transport_cfg.resolver.clone()
.unwrap_or_else(|| server.to_string());
let transport = crate::transport::dns::start_dns_transport(domain, resolver, transport_cfg.pubkey.clone()).await
.map_err(|e| anyhow::anyhow!(e))?;
Ok(transport)
let resolver_with_port = if resolver.contains(':') {
resolver.clone()
} else {
format!("{}:53", resolver)
};
let (local_port, process) = ostp_core::dnstt::spawn_client(&pubkey, &domain, &resolver_with_port, debug)?;
// Wait for dnstt-client to start its local TCP listener
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
// Connect TCP to the local dnstt-client port
let stream = tokio::net::TcpStream::connect(("127.0.0.1", local_port)).await?;
let (mut rh, mut wh) = stream.into_split();
let (tx_send, mut tx_recv) = tokio::sync::mpsc::channel::<bytes::Bytes>(1024);
let (rx_send, rx_recv) = tokio::sync::mpsc::channel::<bytes::Bytes>(1024);
// Writer task
tokio::spawn(async move {
use tokio::io::AsyncWriteExt;
while let Some(data) = tx_recv.recv().await {
let len = data.len() as u16;
if wh.write_u16(len).await.is_err() { break; }
if wh.write_all(&data).await.is_err() { break; }
}
});
// Reader task
tokio::spawn(async move {
use tokio::io::AsyncReadExt;
loop {
let len = match rh.read_u16().await {
Ok(l) => l,
Err(_) => break,
};
let mut buf = vec![0u8; len as usize];
if rh.read_exact(&mut buf).await.is_err() { break; }
if rx_send.send(bytes::Bytes::from(buf)).await.is_err() { break; }
}
});
Ok(crate::transport::Transport::Dnstt {
tx: tx_send,
rx: std::sync::Arc::new(tokio::sync::Mutex::new(rx_recv)),
_guard: std::sync::Arc::new(tokio::sync::Mutex::new(process)),
})
}
_ => {
let udp = tokio::net::UdpSocket::bind("0.0.0.0:0").await?;

View File

@ -4,6 +4,12 @@
//! bandwidth and minimum RTT to determine the optimal sending rate.
//! This replaces the fixed `retransmit_budget = 8` with an adaptive
//! congestion window that responds to network conditions.
//!
//! RTO calculation follows RFC 6298:
//! SRTT = (1 - α) * SRTT + α * RTT (α = 1/8)
//! RTTVAR = (1 - β) * RTTVAR + β * |SRTT - RTT| (β = 1/4)
//! RTO = SRTT + 4 * RTTVAR
//! clamped to [RTO_MIN, RTO_MAX]
use std::time::{Duration, Instant};
@ -15,8 +21,14 @@ pub struct CongestionController {
ssthresh: u64,
/// Current phase
phase: Phase,
/// Minimum RTT observed
/// Minimum RTT observed (for BBR-style bandwidth estimation)
min_rtt: Duration,
/// Smoothed RTT (RFC 6298 SRTT)
srtt: Duration,
/// RTT variance (RFC 6298 RTTVAR)
rttvar: Duration,
/// Whether we have received a first RTT sample
rtt_initialized: bool,
/// Bytes currently in flight (unacknowledged)
bytes_in_flight: u64,
/// Total bytes acknowledged (for bandwidth estimation)
@ -37,31 +49,43 @@ pub struct CongestionController {
enum Phase {
/// Exponential growth until loss or ssthresh
SlowStart,
/// Probe bandwidth: cycle through pacing gains
/// Probe bandwidth: additive increase
ProbeBandwidth,
}
/// Initial congestion window: 10 packets × MTU
const INITIAL_CWND_PACKETS: u64 = 10;
/// Initial congestion window: 32 packets × MTU (IW10 is too conservative for modern links)
const INITIAL_CWND_PACKETS: u64 = 32;
/// Minimum cwnd: 2 packets
const MIN_CWND_PACKETS: u64 = 2;
/// Min RTT expiry window (after which we re-probe)
const MIN_RTT_EXPIRY: Duration = Duration::from_secs(10);
/// Minimum RTO (RFC 6298: 1s in TCP; we use 50ms since we own the protocol)
const RTO_MIN: Duration = Duration::from_millis(50);
/// Maximum RTO
const RTO_MAX: Duration = Duration::from_secs(16);
/// Initial RTT estimate — 30 ms is reasonable for a well-connected VPN server.
/// Will be replaced by first real measurement within milliseconds.
const INITIAL_RTT: Duration = Duration::from_millis(30);
impl CongestionController {
pub fn new(mtu: u64) -> Self {
let now = Instant::now();
let initial_cwnd = INITIAL_CWND_PACKETS * mtu;
// Initial pacing: deliver cwnd in ~2 RTTs to fill the pipe quickly
let initial_pacing = initial_cwnd * 1_000_000 / INITIAL_RTT.as_micros().max(1) as u64;
Self {
cwnd: initial_cwnd,
ssthresh: u64::MAX,
phase: Phase::SlowStart,
min_rtt: Duration::from_millis(100), // Conservative initial estimate
min_rtt: INITIAL_RTT,
srtt: INITIAL_RTT,
rttvar: INITIAL_RTT / 2,
rtt_initialized: false,
bytes_in_flight: 0,
total_acked: 0,
last_ack_time: now,
loss_count: 0,
pacing_rate: initial_cwnd * 10, // initial: ~10 windows/sec
pacing_rate: initial_pacing,
mtu,
min_rtt_stamp: now,
}
@ -82,9 +106,20 @@ impl CongestionController {
self.pacing_rate
}
/// Returns the smoothed RTT estimate.
/// Returns the smoothed RTT estimate (SRTT).
pub fn smoothed_rtt(&self) -> Duration {
self.min_rtt
self.srtt
}
/// Returns the adaptive RTO computed per RFC 6298:
/// RTO = SRTT + 4 * RTTVAR, clamped to [RTO_MIN, RTO_MAX].
///
/// This replaces the static `rto_ms` field in ProtocolMachine so that
/// retransmit timers automatically track changing network conditions.
pub fn rto(&self) -> Duration {
let rttvar4 = self.rttvar.saturating_mul(4);
let rto = self.srtt.saturating_add(rttvar4);
rto.clamp(RTO_MIN, RTO_MAX)
}
/// Returns how many bytes can still be sent.
@ -115,16 +150,13 @@ impl CongestionController {
self.bytes_in_flight = self.bytes_in_flight.saturating_sub(bytes);
self.total_acked = self.total_acked.saturating_add(bytes);
// Update RTT
// Update RTT measurements
self.update_rtt(rtt, now);
// Update bandwidth estimate
self.update_bandwidth(bytes, now);
// State machine
match self.phase {
Phase::SlowStart => {
// Exponential growth: increase cwnd by acked bytes
// Exponential growth: increase cwnd by acked bytes (doubles per RTT)
self.cwnd = self.cwnd.saturating_add(bytes);
if self.cwnd >= self.ssthresh {
self.phase = Phase::ProbeBandwidth;
@ -164,32 +196,49 @@ impl CongestionController {
self.update_pacing_rate();
}
/// Called periodically to update state.
pub fn on_tick(&mut self) {
// Nothing special needed per-tick -- state updates happen on ACK/loss
}
// ── Private ──────────────────────────────────────────────────────────────
fn update_rtt(&mut self, rtt: Duration, now: Instant) {
// Track windowed minimum RTT
// Update windowed minimum RTT (for pacing)
if rtt < self.min_rtt || now.duration_since(self.min_rtt_stamp) >= MIN_RTT_EXPIRY {
self.min_rtt = rtt;
self.min_rtt_stamp = now;
}
}
fn update_bandwidth(&mut self, _acked_bytes: u64, now: Instant) {
let elapsed = now.duration_since(self.last_ack_time);
if elapsed.as_micros() > 0 {
// Removed bw_samples tracking
// Update SRTT and RTTVAR per RFC 6298
if !self.rtt_initialized {
// First measurement: initialize directly
self.srtt = rtt;
self.rttvar = rtt / 2;
self.rtt_initialized = true;
} else {
// RTTVAR = (3/4) * RTTVAR + (1/4) * |SRTT - R|
let diff = if rtt > self.srtt {
rtt - self.srtt
} else {
self.srtt - rtt
};
// Integer-safe: RTTVAR = RTTVAR - RTTVAR/4 + diff/4
self.rttvar = self.rttvar
.saturating_sub(self.rttvar / 4)
.saturating_add(diff / 4);
// SRTT = (7/8) * SRTT + (1/8) * R
self.srtt = self.srtt
.saturating_sub(self.srtt / 8)
.saturating_add(rtt / 8);
}
tracing::trace!(
srtt_ms = self.srtt.as_millis(),
rttvar_ms = self.rttvar.as_millis(),
rto_ms = self.rto().as_millis(),
"congestion: RTT updated"
);
}
fn update_pacing_rate(&mut self) {
// Pacing rate = cwnd / min_rtt (with gain)
// Pacing rate = cwnd / min_rtt (delivery rate target)
let rtt_us = self.min_rtt.as_micros().max(1) as u64;
self.pacing_rate = self.cwnd * 1_000_000 / rtt_us;
}
@ -202,19 +251,18 @@ mod tests {
#[test]
fn test_initial_state() {
let cc = CongestionController::new(1200);
assert_eq!(cc.cwnd(), 12000); // 10 * 1200
assert_eq!(cc.cwnd(), 32 * 1200); // 32 * 1200
assert!(cc.can_send());
assert_eq!(cc.cwnd_packets(), 10);
assert_eq!(cc.cwnd_packets(), 32);
}
#[test]
fn test_slow_start_growth() {
let mut cc = CongestionController::new(1200);
// Simulate sending and ACKing
let initial = cc.cwnd();
cc.on_send(1200);
cc.on_ack(1200, Duration::from_millis(50));
// cwnd should grow
assert!(cc.cwnd() > 12000);
assert!(cc.cwnd() > initial);
}
#[test]
@ -229,7 +277,7 @@ mod tests {
fn test_can_send_limits() {
let mut cc = CongestionController::new(1200);
// Send until cwnd is exhausted
for _ in 0..10 {
for _ in 0..32 {
cc.on_send(1200);
}
assert!(!cc.can_send()); // cwnd exhausted
@ -244,10 +292,46 @@ mod tests {
}
#[test]
fn test_rtt_tracking() {
fn test_rtt_tracking_first_sample() {
let mut cc = CongestionController::new(1200);
cc.on_send(1200);
cc.on_ack(1200, Duration::from_millis(25));
// After first sample: SRTT = 25ms, RTTVAR = 12ms
assert_eq!(cc.smoothed_rtt(), Duration::from_millis(25));
}
#[test]
fn test_rto_rfc6298() {
let mut cc = CongestionController::new(1200);
// After first sample with RTT=50ms: SRTT=50ms, RTTVAR=25ms, RTO=150ms
cc.on_send(1200);
cc.on_ack(1200, Duration::from_millis(50));
let rto = cc.rto();
// RTO = 50 + 4*25 = 150ms; clamped to [50ms, 16s]
assert!(rto >= RTO_MIN);
assert!(rto <= RTO_MAX);
assert_eq!(rto, Duration::from_millis(150));
}
#[test]
fn test_rto_clamp_min() {
let cc = CongestionController::new(1200);
// Even with no RTT samples, RTO should not go below RTO_MIN
assert!(cc.rto() >= RTO_MIN);
}
#[test]
fn test_rto_adapts_after_multiple_samples() {
let mut cc = CongestionController::new(1200);
// Feed several consistent RTT samples
for _ in 0..8 {
cc.on_send(1200);
cc.on_ack(1200, Duration::from_millis(20));
}
// After convergence, RTTVAR should be small → RTO close to SRTT + small margin
let rto = cc.rto();
// Should be well below 100ms (the old hardcoded default)
assert!(rto < Duration::from_millis(200));
assert!(rto >= RTO_MIN);
}
}

View File

@ -2,7 +2,7 @@ use bytes::Bytes;
use rand::Rng;
use sha2::{Digest, Sha256};
use thiserror::Error;
use std::collections::{BTreeMap, VecDeque};
use std::collections::BTreeMap;
use std::time::{Duration, Instant};
use crate::congestion::CongestionController;
@ -75,7 +75,7 @@ pub struct ProtocolMachine {
send_nonce: u64,
expected_recv_nonce: u64,
reorder_buffer: BTreeMap<u64, ProtocolAction>,
sent_history: VecDeque<SentFrame>,
sent_history: BTreeMap<u64, SentFrame>,
session_id: u32,
handshake_payload: Vec<u8>,
padder: AdaptivePadder,
@ -83,7 +83,8 @@ pub struct ProtocolMachine {
max_reorder: u64,
max_reorder_buffer: usize,
ack_delay: Duration,
rto: Duration,
/// Initial/fallback RTO from config (overridden by cc.rto() after first RTT sample)
rto_initial: Duration,
max_retries: u8,
max_sent_history: usize,
ack_pending: bool,
@ -100,11 +101,11 @@ pub struct ProtocolMachine {
/// Key-derived handshake padding range
handshake_pad_min: usize,
handshake_pad_max: usize,
_mtu: usize,
}
#[derive(Debug, Clone)]
struct SentFrame {
#[allow(dead_code)] // mirrored in BTreeMap key; kept for Debug output
nonce: u64,
bytes: Bytes,
last_sent: Instant,
@ -128,7 +129,7 @@ impl ProtocolMachine {
send_nonce: 0,
expected_recv_nonce: 0,
reorder_buffer: BTreeMap::new(),
sent_history: VecDeque::with_capacity(config.max_sent_history.max(1)),
sent_history: BTreeMap::new(),
session_id: config.session_id,
handshake_payload: config.handshake_payload,
padder: AdaptivePadder::new(config.mtu, config.max_padding, config.padding_strategy),
@ -136,7 +137,7 @@ impl ProtocolMachine {
max_reorder: config.max_reorder.max(1),
max_reorder_buffer: config.max_reorder_buffer.max(1),
ack_delay: Duration::from_millis(config.ack_delay_ms.max(1)),
rto: Duration::from_millis(config.rto_ms.max(1)),
rto_initial: Duration::from_millis(config.rto_ms.max(1)),
max_retries: config.max_retries.max(1),
max_sent_history: config.max_sent_history.max(1),
ack_pending: false,
@ -146,20 +147,25 @@ impl ProtocolMachine {
cc: CongestionController::new(config.mtu as u64),
handshake_pad_min: config.handshake_pad_min.max(8),
handshake_pad_max: config.handshake_pad_max.max(config.handshake_pad_min + 16),
_mtu: config.mtu,
})
}
pub fn in_flight_count(&self) -> usize {
// COUNT ONLY retransmittable Data frames — control frames (Ack/Nack) must not
// contribute to this counter or they will trigger false backpressure.
self.sent_history.iter().filter(|f| f.is_retransmittable).count()
self.sent_history.values().filter(|f| f.is_retransmittable).count()
}
pub fn cwnd_packets(&self) -> usize {
self.cc.cwnd_packets() as usize
}
/// Returns the current adaptive RTO (from congestion controller after first RTT sample,
/// falls back to the config-specified initial value before any ACK is received).
pub fn rto(&self) -> Duration {
self.cc.rto()
}
pub fn on_send(&mut self, bytes: u64) {
self.cc.on_send(bytes);
}
@ -207,13 +213,12 @@ impl ProtocolMachine {
.map(ProtocolAction::SendDatagram)
}
(OstpState::Closing, OstpEvent::Inbound(raw)) => {
// Process final in-flight packets to prevent data loss during teardown.
// The remote may still have data or ACKs in transit when we initiated Close.
let result = self.handle_inbound(raw);
self.state = OstpState::Closed;
result
// The remote may still have data or ACKs in transit.
// handle_inbound transitions to Closed when it receives a Close frame.
self.handle_inbound(raw)
}
(OstpState::Established, OstpEvent::Tick) => self.handle_tick(),
(OstpState::Closing, OstpEvent::Tick) => self.handle_tick(),
(OstpState::Closed, _) => Ok(ProtocolAction::Noop),
(_, OstpEvent::Close) => {
self.state = OstpState::Closed;
@ -408,10 +413,10 @@ impl ProtocolMachine {
tracing::debug!("Frame nonce={} arrived too late after gap recovery, dropping", nonce);
}
// Rate-limited NACK: send at most once per 30ms to prevent retransmit storms.
// Under high load with natural UDP reordering, sending a NACK per packet
// causes exponential retransmit explosion that saturates the channel.
let nack_cooldown = Duration::from_millis(30);
// Rate-limited NACK: send at most once per (rto/2) to prevent retransmit storms.
// Using rto/2 means we send a NACK before the sender's timer fires, prompting
// fast retransmit without flooding. Floor at 10ms to handle very low-RTT links.
let nack_cooldown = (self.cc.rto() / 2).max(Duration::from_millis(10));
if self.last_nack_sent.elapsed() >= nack_cooldown {
self.last_nack_sent = Instant::now();
let nack_payload = self.expected_recv_nonce.to_be_bytes();
@ -525,30 +530,33 @@ impl ProtocolMachine {
}
let now = Instant::now();
let base_rto_ms = self.rto.as_millis().max(1) as u64;
// Use the adaptive RTO from the congestion controller (RFC 6298 SRTT + 4*RTTVAR).
// Falls back to rto_initial before the first ACK is received.
let base_rto = self.cc.rto().max(self.rto_initial);
let base_rto_ms = base_rto.as_millis().max(1) as u64;
// ── Zombie frame eviction ────────────────────────────────────
// Evict frames that exceeded max_retries + 2 grace retries.
// Shorter grace period than before (was +4) to free memory faster
// after high-throughput bursts.
let grace = self.max_retries.saturating_add(2);
let before = self.sent_history.len();
self.sent_history.retain(|f| !f.is_retransmittable || f.retries <= grace);
self.sent_history.retain(|_, f| !f.is_retransmittable || f.retries <= grace);
let evicted = before - self.sent_history.len();
if evicted > 0 {
tracing::debug!("Evicted {} zombie frames from sent_history (remaining={})", evicted, self.sent_history.len());
}
// ── Retransmit expired frames ────────────────────────────────
// Limit retransmits per tick to prevent bandwidth saturation
// Backoff starts from retry #0 (immediately effective):
// effective_rto = base_rto * 2^retries, capped at 2^6 = 64×
// This ensures we do not flood with retransmits on the first few losses
// while still recovering quickly on a transient single loss.
let mut retransmit_budget: usize = self.cc.retransmit_budget();
for frame in self.sent_history.iter_mut() {
for frame in self.sent_history.values_mut() {
if !frame.is_retransmittable {
continue;
}
let retry_over = frame.retries.saturating_sub(self.max_retries);
let backoff_factor = 1u64 << retry_over.min(6);
let backoff_factor = 1u64 << (frame.retries as u64).min(6);
let effective_rto = Duration::from_millis(base_rto_ms.saturating_mul(backoff_factor));
if now.duration_since(frame.last_sent) >= effective_rto {
@ -654,7 +662,7 @@ impl ProtocolMachine {
}
fn lookup_sent_frame(&mut self, nonce: u64) -> Option<Bytes> {
if let Some(frame) = self.sent_history.iter_mut().rev().find(|f| f.nonce == nonce) {
if let Some(frame) = self.sent_history.get_mut(&nonce) {
frame.last_sent = Instant::now();
frame.retries = frame.retries.saturating_add(1);
return Some(frame.bytes.clone());
@ -666,7 +674,7 @@ impl ProtocolMachine {
if is_retransmittable {
self.cc.on_send(bytes.len() as u64);
}
self.sent_history.push_back(SentFrame {
self.sent_history.insert(nonce, SentFrame {
nonce,
bytes,
last_sent: Instant::now(),
@ -679,7 +687,7 @@ impl ProtocolMachine {
overflow, self.max_sent_history
);
while self.sent_history.len() > self.max_sent_history {
self.sent_history.pop_front();
self.sent_history.pop_first();
}
}
}
@ -690,8 +698,8 @@ impl ProtocolMachine {
let mut min_rtt = Duration::from_secs(60);
// Compute RTT from the oldest acked frame's send timestamp
for frame in self.sent_history.iter() {
if nonce_in_ranges(frame.nonce, ranges) {
for (&nonce, frame) in &self.sent_history {
if nonce_in_ranges(nonce, ranges) {
acked_bytes += frame.bytes.len() as u64;
let rtt = now.duration_since(frame.last_sent);
if rtt < min_rtt {
@ -700,7 +708,7 @@ impl ProtocolMachine {
}
}
self.sent_history.retain(|frame| !nonce_in_ranges(frame.nonce, ranges));
self.sent_history.retain(|&nonce, _| !nonce_in_ranges(nonce, ranges));
// Notify congestion controller
if acked_bytes > 0 {

View File

@ -12,12 +12,13 @@ use portable_atomic::AtomicU64;
// const MAX_SESSIONS removed because dynamic limit is used
pub enum DispatchOutcome {
Unauthorized(String),
Accepted {
responses: Vec<Bytes>,
app_payloads: Vec<(u32, u16, Bytes)>, // session_id, stream_id, payload
peer_addr: SocketAddr,
},
Unauthorized(String),
Ignored,
}
/// Per-user traffic statistics.
@ -83,7 +84,6 @@ pub struct Dispatcher {
last_token_regen: std::time::Instant,
}
#[allow(dead_code)]
impl Dispatcher {
pub fn new(machine_config: ProtocolConfig, access_keys: Arc<RwLock<HashMap<String, crate::api::UserMeta>>>) -> Self {
let mut initial_stats = HashMap::new();
@ -108,6 +108,7 @@ impl Dispatcher {
}
/// Snapshot all user stats for API responses.
#[allow(dead_code)]
pub fn snapshot_all_users(&self) -> Vec<UserStatsSnapshot> {
let stats = self.user_stats.read().unwrap_or_else(|e| e.into_inner());
let mut online_keys: HashMap<String, std::time::Instant> = HashMap::new();
@ -161,6 +162,7 @@ impl Dispatcher {
}
/// Set traffic limit for a user.
#[allow(dead_code)]
pub fn set_user_limit(&self, key: &str, limit: Option<u64>) {
let mut stats = self.user_stats.write().unwrap_or_else(|e| e.into_inner());
let entry = stats.entry(key.to_string())
@ -176,6 +178,7 @@ impl Dispatcher {
}
/// Active session count.
#[allow(dead_code)]
pub fn active_sessions(&self) -> usize {
self.peer_machines.len()
}
@ -376,15 +379,19 @@ impl Dispatcher {
continue;
}
if !self.replay_cache.contains_key(&payload.to_vec()) {
if self.replay_cache.len() >= 50_000 {
tracing::warn!("Replay cache full (100000 entries), rejecting handshake from {}", peer);
return Ok(DispatchOutcome::Unauthorized("replay cache full".to_string()));
}
if self.replay_cache.contains_key(&payload.to_vec()) {
tracing::debug!("Replay detected from {}, ignoring", peer);
return Ok(DispatchOutcome::Ignored);
}
self.replay_cache.insert(payload.to_vec(), ts);
if self.replay_cache.len() >= 50_000 {
tracing::warn!("Replay cache full (50000 entries), rejecting handshake from {}", peer);
return Ok(DispatchOutcome::Unauthorized("replay cache full".to_string()));
}
machine.set_session_keys(candidate_session_id, secrets.obfuscation_key);
self.replay_cache.insert(payload.to_vec(), ts);
machine.set_session_keys(candidate_session_id, secrets.obfuscation_key);
// Track per-user connection count
let user_stats = self.get_or_create_user_stats(&candidate_key);
@ -414,7 +421,6 @@ impl Dispatcher {
app_payloads: Vec::new(),
peer_addr: peer,
});
}
}
}
}
@ -429,23 +435,35 @@ impl Dispatcher {
Ok(DispatchOutcome::Unauthorized(reason))
}
pub fn outbound_to_session(&mut self, session_id: u32, stream_id: u16, payload: Bytes) -> Result<Option<(Bytes, SocketAddr)>> {
pub fn outbound_to_session(&mut self, session_id: u32, stream_id: u16, payload: Bytes) -> Result<Vec<(Bytes, SocketAddr)>> {
let peer_state = if let Some(existing) = self.peer_machines.get_mut(&session_id) {
existing
} else {
return Ok(None);
return Ok(Vec::new());
};
let addr = peer_state.last_addr;
let key = peer_state.access_key.clone();
match peer_state.machine.on_event(OstpEvent::Outbound(stream_id, payload))? {
ProtocolAction::SendDatagram(frame) => {
// Track outbound bytes per user
track_user_bytes_down(&self.user_stats, &self.access_keys, &key, frame.len() as u64);
Ok(Some((frame, addr)))
let action = peer_state.machine.on_event(OstpEvent::Outbound(stream_id, payload))?;
let mut frames = Vec::new();
let mut queue = vec![action];
while let Some(current) = queue.pop() {
match current {
ProtocolAction::Multiple(list) => {
for item in list {
queue.push(item);
}
}
ProtocolAction::SendDatagram(frame) => {
track_user_bytes_down(&self.user_stats, &self.access_keys, &key, frame.len() as u64);
frames.push((frame, addr));
}
_ => {}
}
_ => Ok(None),
}
Ok(frames)
}
pub fn on_tick(&mut self) -> (Vec<(Bytes, SocketAddr)>, Vec<u32>) {
@ -459,7 +477,7 @@ impl Dispatcher {
let mut frames = Vec::new();
let mut expired = Vec::new();
let now = std::time::Instant::now();
let timeout_dur = std::time::Duration::from_secs(600); // 10 minute session timeout (mobile NAT can be up to 5-10min)
let timeout_dur = std::time::Duration::from_secs(600); // 10-minute session timeout (mobile NAT mappings can live 510 min)
// Gather expired or invalid sessions
for (&sid, peer_state) in &self.peer_machines {
@ -477,7 +495,7 @@ impl Dispatcher {
let key_valid = self.access_keys.read().unwrap_or_else(|e| e.into_inner()).contains_key(&ps.access_key);
let user_stats = self.get_or_create_user_stats(&ps.access_key);
if now.duration_since(ps.last_seen) > timeout_dur {
"inactive >5min"
"inactive >10min"
} else if !key_valid {
"key deleted"
} else if user_stats.is_over_limit() {