ostp/ostp-core/src/congestion.rs

338 lines
12 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! Congestion control for the OSTP protocol.
//!
//! Implements a simplified loss-based AIMD congestion controller (slow start,
//! then additive increase / multiplicative decrease) with BBR-inspired pacing:
//! the sending rate is the congestion window divided by the windowed minimum
//! RTT. This replaces the fixed `retransmit_budget = 8` with an adaptive
//! congestion window that responds to network conditions.
//!
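//! RTO calculation follows RFC 6298. The first sample R initializes
//! SRTT = R and RTTVAR = R / 2. Each later sample R updates RTTVAR first
//! (using the previous SRTT) and then SRTT:
//! RTTVAR = (1 - β) * RTTVAR + β * |SRTT - R| (β = 1/4)
//! SRTT = (1 - α) * SRTT + α * R (α = 1/8)
//! RTO = SRTT + 4 * RTTVAR, clamped to [RTO_MIN, RTO_MAX]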
use std::time::{Duration, Instant};
/// Congestion control state for a single OSTP session.
///
/// Combines a loss-based window (slow start, then additive increase /
/// multiplicative decrease) with a pacing rate of `cwnd / min_rtt` and an
/// RFC 6298 retransmission timeout.
#[derive(Debug)]
pub struct CongestionController {
    /// Current congestion window in bytes (how much may be in flight).
    cwnd: u64,
    /// Slow-start threshold in bytes; `u64::MAX` until a loss ends slow start.
    ssthresh: u64,
    /// Current phase of the window-growth state machine.
    phase: Phase,
    /// Windowed minimum RTT; the denominator of the pacing-rate computation.
    min_rtt: Duration,
    /// Smoothed RTT (RFC 6298 SRTT).
    srtt: Duration,
    /// RTT variance (RFC 6298 RTTVAR).
    rttvar: Duration,
    /// Whether a first RTT sample has been received; until then the RTT
    /// fields hold placeholder values.
    rtt_initialized: bool,
    /// Bytes currently in flight (sent, not yet acknowledged or declared lost).
    bytes_in_flight: u64,
    /// Cumulative bytes acknowledged over the session.
    total_acked: u64,
    /// Time of the most recent ACK.
    last_ack_time: Instant,
    /// Cumulative number of loss events reported for this session.
    loss_count: u32,
    /// Pacing rate in bytes per second.
    pacing_rate: u64,
    /// MTU in bytes, used to convert between byte and packet counts.
    mtu: u64,
    /// When `min_rtt` was last set; after `MIN_RTT_EXPIRY` the next sample
    /// replaces it so the minimum is re-probed.
    min_rtt_stamp: Instant,
}
/// Window-growth phase of the congestion controller.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Phase {
    /// The window grows by the number of acknowledged bytes (roughly doubling
    /// each RTT) until a loss occurs or `ssthresh` is reached.
    SlowStart,
    /// Steady state: the window grows by about one MTU per RTT.
    ProbeBandwidth,
}
/// Initial congestion window, in packets (multiplied by the MTU to get bytes).
/// Deliberately larger than the IW10 of RFC 6928, which is too conservative
/// for modern links.
const INITIAL_CWND_PACKETS: u64 = 32;
/// Lower bound on the congestion window, in packets.
const MIN_CWND_PACKETS: u64 = 2;
/// How long a minimum-RTT sample stays valid before it is re-probed.
const MIN_RTT_EXPIRY: Duration = Duration::from_millis(10_000);
/// Lower bound on the RTO. RFC 6298 recommends 1 s for TCP; 50 ms is used
/// here because we own the protocol.
const RTO_MIN: Duration = Duration::from_millis(50);
/// Upper bound on the RTO.
const RTO_MAX: Duration = Duration::from_millis(16_000);
/// RTT assumed before the first measurement arrives; 30 ms suits a
/// well-connected VPN server.
const INITIAL_RTT: Duration = Duration::from_micros(30_000);
impl CongestionController {
    /// Creates a controller for a path whose packets carry `mtu` bytes.
    ///
    /// The session starts in slow start with a window of
    /// `INITIAL_CWND_PACKETS * mtu` bytes. The initial pacing rate delivers that
    /// window once per `INITIAL_RTT`. An `mtu` of 0 is treated as 1, so the
    /// byte-to-packet conversion in [`Self::cwnd_packets`] never divides by zero.
    pub fn new(mtu: u64) -> Self {
        let mtu = mtu.max(1);
        let now = Instant::now();
        let initial_cwnd = INITIAL_CWND_PACKETS.saturating_mul(mtu);
        Self {
            cwnd: initial_cwnd,
            ssthresh: u64::MAX,
            phase: Phase::SlowStart,
            min_rtt: INITIAL_RTT,
            srtt: INITIAL_RTT,
            rttvar: INITIAL_RTT / 2,
            rtt_initialized: false,
            bytes_in_flight: 0,
            total_acked: 0,
            last_ack_time: now,
            loss_count: 0,
            // Deliver the whole initial window once per INITIAL_RTT to fill the pipe quickly.
            pacing_rate: Self::rate_for_window(initial_cwnd, INITIAL_RTT),
            mtu,
            min_rtt_stamp: now,
        }
    }

    /// Returns the current congestion window in bytes.
    pub fn cwnd(&self) -> u64 {
        self.cwnd
    }

    /// Returns the current congestion window in whole packets, never less than
    /// `MIN_CWND_PACKETS`.
    pub fn cwnd_packets(&self) -> usize {
        // `mtu` is at least 1 (enforced by `new`), so this cannot divide by zero.
        let packets = (self.cwnd / self.mtu).max(MIN_CWND_PACKETS);
        // Saturate instead of truncating where usize is narrower than u64.
        usize::try_from(packets).unwrap_or(usize::MAX)
    }

    /// Returns the current pacing rate in bytes/sec.
    pub fn pacing_rate(&self) -> u64 {
        self.pacing_rate
    }

    /// Returns the smoothed RTT estimate (SRTT).
    pub fn smoothed_rtt(&self) -> Duration {
        self.srtt
    }

    /// Returns the adaptive RTO computed per RFC 6298:
    /// RTO = SRTT + 4 * RTTVAR, clamped to [RTO_MIN, RTO_MAX].
    ///
    /// This replaces the static `rto_ms` field in ProtocolMachine so that
    /// retransmit timers automatically track changing network conditions.
    pub fn rto(&self) -> Duration {
        let rttvar4 = self.rttvar.saturating_mul(4);
        let rto = self.srtt.saturating_add(rttvar4);
        rto.clamp(RTO_MIN, RTO_MAX)
    }

    /// Returns how many bytes can still be sent before the window is full.
    pub fn available_cwnd(&self) -> u64 {
        self.cwnd.saturating_sub(self.bytes_in_flight)
    }

    /// Returns the recommended retransmit budget per tick: a quarter of the
    /// window in packets, bounded to `[2, 64]` to avoid bursts.
    pub fn retransmit_budget(&self) -> usize {
        (self.cwnd_packets() / 4).clamp(2, 64)
    }

    /// Checks whether the window has room for more data.
    pub fn can_send(&self) -> bool {
        self.bytes_in_flight < self.cwnd
    }

    /// Records that `bytes` of data were sent.
    pub fn on_send(&mut self, bytes: u64) {
        self.bytes_in_flight = self.bytes_in_flight.saturating_add(bytes);
    }

    /// Records that `bytes` were acknowledged, together with an RTT sample.
    ///
    /// Updates the RTT estimators and grows the window. In slow start, the
    /// window grows by the acknowledged bytes. Afterwards, it grows by
    /// `bytes * mtu / cwnd`, which is about one MTU per RTT. The pacing rate is
    /// then recomputed.
    pub fn on_ack(&mut self, bytes: u64, rtt: Duration) {
        let now = Instant::now();
        self.bytes_in_flight = self.bytes_in_flight.saturating_sub(bytes);
        self.total_acked = self.total_acked.saturating_add(bytes);
        self.update_rtt(rtt, now);
        match self.phase {
            Phase::SlowStart => {
                // Exponential growth: increase cwnd by acked bytes (doubles per RTT).
                self.cwnd = self.cwnd.saturating_add(bytes);
                if self.cwnd >= self.ssthresh {
                    self.phase = Phase::ProbeBandwidth;
                    tracing::debug!(cwnd = self.cwnd, "congestion: exiting slow start");
                }
            }
            Phase::ProbeBandwidth => {
                // Reno-style additive increase. The product is computed in u128 so a
                // large ACK cannot overflow.
                let increase =
                    u128::from(bytes) * u128::from(self.mtu) / u128::from(self.cwnd.max(1));
                self.cwnd = self
                    .cwnd
                    .saturating_add(u64::try_from(increase).unwrap_or(u64::MAX));
            }
        }
        self.update_pacing_rate();
        self.last_ack_time = now;
    }

    /// Records a loss event covering `bytes_lost` bytes.
    ///
    /// A loss in slow start sets `ssthresh` to half the window, shrinks the
    /// window to it and leaves slow start. Later losses multiply the window by
    /// 0.7. The window never drops below `MIN_CWND_PACKETS` packets.
    pub fn on_loss(&mut self, bytes_lost: u64) {
        self.bytes_in_flight = self.bytes_in_flight.saturating_sub(bytes_lost);
        self.loss_count = self.loss_count.saturating_add(1);
        let floor = MIN_CWND_PACKETS.saturating_mul(self.mtu);
        match self.phase {
            Phase::SlowStart => {
                self.ssthresh = self.cwnd / 2;
                self.cwnd = self.ssthresh.max(floor);
                self.phase = Phase::ProbeBandwidth;
                tracing::debug!(cwnd = self.cwnd, ssthresh = self.ssthresh, "congestion: loss during slow start");
            }
            Phase::ProbeBandwidth => {
                // Multiplicative decrease by 0.7 (CUBIC's beta, RFC 8312; gentler than
                // Reno's 0.5). Computed in u128 so `cwnd * 7` cannot overflow; the
                // result is <= cwnd and always fits back into u64.
                let reduced = u128::from(self.cwnd) * 7 / 10;
                self.cwnd = u64::try_from(reduced).unwrap_or(u64::MAX).max(floor);
                tracing::debug!(cwnd = self.cwnd, "congestion: loss, cwnd reduced");
            }
        }
        self.update_pacing_rate();
    }

    // ── Private ──────────────────────────────────────────────────────────────

    /// Folds an RTT sample into the windowed minimum RTT and the RFC 6298
    /// SRTT/RTTVAR estimators.
    fn update_rtt(&mut self, rtt: Duration, now: Instant) {
        // The first real sample always replaces the INITIAL_RTT placeholder.
        // Otherwise, a path slower than INITIAL_RTT would keep the optimistic
        // placeholder (and an inflated pacing rate) until MIN_RTT_EXPIRY.
        let min_rtt_expired = now.duration_since(self.min_rtt_stamp) >= MIN_RTT_EXPIRY;
        if !self.rtt_initialized || rtt < self.min_rtt || min_rtt_expired {
            self.min_rtt = rtt;
            self.min_rtt_stamp = now;
        }

        if !self.rtt_initialized {
            // First measurement: SRTT = R, RTTVAR = R / 2.
            self.srtt = rtt;
            self.rttvar = rtt / 2;
            self.rtt_initialized = true;
        } else {
            // RTTVAR = (3/4) * RTTVAR + (1/4) * |SRTT - R|, using the previous SRTT.
            let diff = if rtt > self.srtt {
                rtt - self.srtt
            } else {
                self.srtt - rtt
            };
            self.rttvar = self
                .rttvar
                .saturating_sub(self.rttvar / 4)
                .saturating_add(diff / 4);
            // SRTT = (7/8) * SRTT + (1/8) * R
            self.srtt = self
                .srtt
                .saturating_sub(self.srtt / 8)
                .saturating_add(rtt / 8);
        }
        tracing::trace!(
            srtt_ms = self.srtt.as_millis(),
            rttvar_ms = self.rttvar.as_millis(),
            rto_ms = self.rto().as_millis(),
            "congestion: RTT updated"
        );
    }

    /// Recomputes the pacing rate as `cwnd / min_rtt` (delivery-rate target).
    fn update_pacing_rate(&mut self) {
        self.pacing_rate = Self::rate_for_window(self.cwnd, self.min_rtt);
    }

    /// Returns the bytes per second needed to deliver `window` bytes once per `rtt`.
    ///
    /// Computed in u128 so large windows cannot overflow. The result saturates
    /// at `u64::MAX`, and a zero `rtt` is treated as 1 µs.
    fn rate_for_window(window: u64, rtt: Duration) -> u64 {
        let rtt_us = rtt.as_micros().max(1);
        let rate = u128::from(window) * 1_000_000 / rtt_us;
        u64::try_from(rate).unwrap_or(u64::MAX)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const MTU: u64 = 1200;

    #[test]
    fn test_initial_state() {
        let cc = CongestionController::new(MTU);
        assert_eq!(cc.cwnd(), INITIAL_CWND_PACKETS * MTU); // 32 * 1200
        assert!(cc.can_send());
        assert_eq!(cc.cwnd_packets(), 32);
        // 38_400 bytes delivered once per 30 ms INITIAL_RTT.
        assert_eq!(cc.pacing_rate(), 1_280_000);
    }

    #[test]
    fn test_slow_start_growth() {
        let mut cc = CongestionController::new(MTU);
        let initial = cc.cwnd();
        cc.on_send(MTU);
        cc.on_ack(MTU, Duration::from_millis(50));
        // Slow start grows the window by exactly the acknowledged bytes.
        assert_eq!(cc.cwnd(), initial + MTU);
    }

    #[test]
    fn test_loss_reduces_cwnd() {
        let mut cc = CongestionController::new(MTU);
        let initial = cc.cwnd();
        cc.on_loss(MTU);
        // A loss during slow start halves the window.
        assert_eq!(cc.cwnd(), initial / 2);
    }

    #[test]
    fn test_multiplicative_decrease_after_slow_start() {
        let mut cc = CongestionController::new(MTU);
        cc.on_loss(MTU); // leaves slow start: cwnd = 19_200
        cc.on_loss(MTU); // additive-increase phase: cwnd *= 0.7
        assert_eq!(cc.cwnd(), 19_200 * 7 / 10);
    }

    #[test]
    fn test_cwnd_never_below_minimum() {
        let mut cc = CongestionController::new(MTU);
        for _ in 0..50 {
            cc.on_loss(MTU);
        }
        assert_eq!(cc.cwnd(), MIN_CWND_PACKETS * MTU);
        assert_eq!(cc.cwnd_packets(), MIN_CWND_PACKETS as usize);
    }

    #[test]
    fn test_additive_increase_after_slow_start() {
        let mut cc = CongestionController::new(MTU);
        cc.on_loss(MTU); // cwnd = 19_200, now past slow start
        cc.on_send(MTU);
        cc.on_ack(MTU, Duration::from_millis(20));
        // Each ACK adds bytes * MTU / cwnd = 1200 * 1200 / 19_200 = 75 bytes.
        assert_eq!(cc.cwnd(), 19_200 + 75);
    }

    #[test]
    fn test_can_send_limits() {
        let mut cc = CongestionController::new(MTU);
        // Fill the initial window one packet at a time.
        for _ in 0..INITIAL_CWND_PACKETS {
            assert!(cc.can_send());
            cc.on_send(MTU);
        }
        assert!(!cc.can_send()); // cwnd exhausted
        assert_eq!(cc.available_cwnd(), 0);
    }

    #[test]
    fn test_retransmit_budget() {
        let cc = CongestionController::new(MTU);
        let budget = cc.retransmit_budget();
        assert!((2..=64).contains(&budget));
        // A quarter of the 32-packet initial window.
        assert_eq!(budget, 8);
    }

    #[test]
    fn test_rtt_tracking_first_sample() {
        let mut cc = CongestionController::new(MTU);
        cc.on_send(MTU);
        cc.on_ack(MTU, Duration::from_millis(25));
        // After the first sample: SRTT = 25 ms, RTTVAR = 12.5 ms,
        // RTO = 25 + 4 * 12.5 = 75 ms.
        assert_eq!(cc.smoothed_rtt(), Duration::from_millis(25));
        assert_eq!(cc.rto(), Duration::from_millis(75));
    }

    #[test]
    fn test_pacing_rate_uses_min_rtt() {
        let mut cc = CongestionController::new(MTU);
        cc.on_send(MTU);
        cc.on_ack(MTU, Duration::from_millis(25));
        // cwnd = 39_600 bytes delivered once per 25 ms min RTT.
        assert_eq!(cc.pacing_rate(), 1_584_000);
    }

    #[test]
    fn test_rto_rfc6298() {
        let mut cc = CongestionController::new(MTU);
        // After the first sample with RTT = 50 ms: SRTT = 50 ms, RTTVAR = 25 ms.
        cc.on_send(MTU);
        cc.on_ack(MTU, Duration::from_millis(50));
        let rto = cc.rto();
        // RTO = 50 + 4 * 25 = 150 ms, within [50 ms, 16 s].
        assert!(rto >= RTO_MIN);
        assert!(rto <= RTO_MAX);
        assert_eq!(rto, Duration::from_millis(150));
    }

    #[test]
    fn test_rto_clamp_min() {
        let cc = CongestionController::new(MTU);
        // Even with no RTT samples, the RTO must not go below RTO_MIN.
        assert!(cc.rto() >= RTO_MIN);
        // Placeholder estimate: SRTT = 30 ms, RTTVAR = 15 ms, RTO = 90 ms.
        assert_eq!(cc.rto(), Duration::from_millis(90));
    }

    #[test]
    fn test_rto_adapts_after_multiple_samples() {
        let mut cc = CongestionController::new(MTU);
        // Feed several identical RTT samples.
        for _ in 0..8 {
            cc.on_send(MTU);
            cc.on_ack(MTU, Duration::from_millis(20));
        }
        // Identical samples shrink RTTVAR by 3/4 each time, so SRTT + 4 * RTTVAR
        // is about 25 ms and the RTO settles at its floor.
        assert_eq!(cc.smoothed_rtt(), Duration::from_millis(20));
        assert_eq!(cc.rto(), RTO_MIN);
    }
}