//! Congestion control for the OSTP protocol.
//!
//! Implements a simplified BBR-inspired algorithm that estimates bottleneck
//! bandwidth and minimum RTT to determine the optimal sending rate.
//! This replaces the fixed `retransmit_budget = 8` with an adaptive
//! congestion window that responds to network conditions.

use std::collections::VecDeque;
use std::time::{Duration, Instant};

/// Congestion control state for a single OSTP session.
#[derive(Debug)]
pub struct CongestionController {
    /// Current congestion window in bytes (how much can be in-flight)
    cwnd: u64,
    /// Slow-start threshold in bytes
    ssthresh: u64,
    /// Current phase
    phase: Phase,
    /// Minimum RTT observed (used for BDP calculation)
    min_rtt: Duration,
    /// Maximum bandwidth observed (bytes/sec)
    max_bandwidth: u64,
    /// Recent RTT samples, kept bounded to `MAX_SAMPLES` entries.
    rtt_samples: VecDeque<RttSample>,
    /// Recent bandwidth samples, kept bounded to `MAX_SAMPLES` entries.
    bw_samples: VecDeque<BwSample>,
    /// Bytes currently in flight (unacknowledged)
    bytes_in_flight: u64,
    /// Total bytes acknowledged (for bandwidth estimation)
    total_acked: u64,
    /// Last time we received an ACK
    last_ack_time: Instant,
    /// Number of loss events in the current window
    loss_count: u32,
    /// Pacing rate: bytes per second
    pacing_rate: u64,
    /// MTU estimate (used for cwnd → packet count conversion)
    mtu: u64,
    /// Start time of the current ProbeRTT phase; `None` outside ProbeRTT.
    probe_rtt_timer: Option<Instant>,
    /// When `min_rtt` was last refreshed; the estimate expires after
    /// `MIN_RTT_EXPIRY`, which triggers a re-probe.
    min_rtt_stamp: Instant,
}

/// Phases of the BBR-inspired state machine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Phase {
    /// Exponential growth until loss or ssthresh
    SlowStart,
    /// Probe bandwidth: cycle through pacing gains
    ProbeBandwidth,
    /// Periodically drain the queue to measure true min RTT
    ProbeRtt,
}

/// One RTT measurement and when it was taken.
#[derive(Debug, Clone)]
// Some fields are recorded but not read by the current estimator.
#[allow(dead_code)]
struct RttSample {
    rtt: Duration,
    time: Instant,
}

/// One delivery-rate measurement (bytes/sec) and when it was taken.
#[derive(Debug, Clone)]
// Some fields are recorded but not read by the current estimator.
#[allow(dead_code)]
struct BwSample {
    bytes_per_sec: u64,
    time: Instant,
}

/// Maximum number of samples to keep for windowed min/max
const MAX_SAMPLES: usize = 32;
/// Initial congestion window: 10 packets × MTU
const INITIAL_CWND_PACKETS: u64 = 10;
/// Minimum cwnd: 2 packets const MIN_CWND_PACKETS: u64 = 2; /// Min RTT expiry window (after which we re-probe) const MIN_RTT_EXPIRY: Duration = Duration::from_secs(10); /// ProbeRTT drain duration const PROBE_RTT_DURATION: Duration = Duration::from_millis(200); impl CongestionController { pub fn new(mtu: u64) -> Self { let now = Instant::now(); let initial_cwnd = INITIAL_CWND_PACKETS * mtu; Self { cwnd: initial_cwnd, ssthresh: u64::MAX, phase: Phase::SlowStart, min_rtt: Duration::from_millis(100), // Conservative initial estimate max_bandwidth: 0, rtt_samples: VecDeque::with_capacity(MAX_SAMPLES), bw_samples: VecDeque::with_capacity(MAX_SAMPLES), bytes_in_flight: 0, total_acked: 0, last_ack_time: now, loss_count: 0, pacing_rate: initial_cwnd * 10, // initial: ~10 windows/sec mtu, probe_rtt_timer: None, min_rtt_stamp: now, } } /// Returns the current congestion window in bytes. pub fn cwnd(&self) -> u64 { self.cwnd } /// Returns the current congestion window in packets. pub fn cwnd_packets(&self) -> usize { (self.cwnd / self.mtu).max(MIN_CWND_PACKETS) as usize } /// Returns the current pacing rate in bytes/sec. pub fn pacing_rate(&self) -> u64 { self.pacing_rate } /// Returns the smoothed RTT estimate. pub fn smoothed_rtt(&self) -> Duration { self.min_rtt } /// Returns how many bytes can still be sent. pub fn available_cwnd(&self) -> u64 { self.cwnd.saturating_sub(self.bytes_in_flight) } /// Returns the recommended retransmit budget per tick. pub fn retransmit_budget(&self) -> usize { // Allow retransmitting up to 1/4 of the cwnd in packets per tick let budget = (self.cwnd_packets() / 4).max(2); budget.min(64) // cap at 64 to prevent burst } /// Check whether we can send more data. pub fn can_send(&self) -> bool { self.bytes_in_flight < self.cwnd } /// Record that we sent `bytes` of data. pub fn on_send(&mut self, bytes: u64) { self.bytes_in_flight = self.bytes_in_flight.saturating_add(bytes); } /// Record that `bytes` were acknowledged with the given RTT sample. 
pub fn on_ack(&mut self, bytes: u64, rtt: Duration) { let now = Instant::now(); self.bytes_in_flight = self.bytes_in_flight.saturating_sub(bytes); self.total_acked = self.total_acked.saturating_add(bytes); // Update RTT self.update_rtt(rtt, now); // Update bandwidth estimate self.update_bandwidth(bytes, now); // State machine match self.phase { Phase::SlowStart => { // Exponential growth: increase cwnd by acked bytes self.cwnd = self.cwnd.saturating_add(bytes); if self.cwnd >= self.ssthresh { self.phase = Phase::ProbeBandwidth; tracing::debug!(cwnd = self.cwnd, "congestion: exiting slow start"); } } Phase::ProbeBandwidth => { // BBR-style: target cwnd = BDP * gain let bdp = self.bandwidth_delay_product(); // Apply gain of 1.25 during probe bandwidth let target = (bdp * 5 / 4).max(MIN_CWND_PACKETS * self.mtu); // Smooth transition if self.cwnd < target { self.cwnd = self.cwnd.saturating_add(bytes * self.mtu / self.cwnd.max(1)); } else { self.cwnd = target; } } Phase::ProbeRtt => { // Drain down to 4 packets to measure true min RTT self.cwnd = MIN_CWND_PACKETS * self.mtu * 2; if let Some(timer) = self.probe_rtt_timer { if now.duration_since(timer) >= PROBE_RTT_DURATION { // ProbeRTT complete, return to ProbeBandwidth self.phase = Phase::ProbeBandwidth; self.probe_rtt_timer = None; let bdp = self.bandwidth_delay_product(); self.cwnd = bdp.max(MIN_CWND_PACKETS * self.mtu); tracing::debug!(cwnd = self.cwnd, min_rtt = ?self.min_rtt, "congestion: probe RTT complete"); } } } } // Periodically enter ProbeRTT to refresh min_rtt if now.duration_since(self.min_rtt_stamp) >= MIN_RTT_EXPIRY && self.phase != Phase::ProbeRtt { self.phase = Phase::ProbeRtt; self.probe_rtt_timer = Some(now); tracing::debug!("congestion: entering probe RTT phase"); } self.update_pacing_rate(); self.last_ack_time = now; } /// Record a loss event. 
pub fn on_loss(&mut self, bytes_lost: u64) { self.bytes_in_flight = self.bytes_in_flight.saturating_sub(bytes_lost); self.loss_count += 1; match self.phase { Phase::SlowStart => { // Exit slow start, set ssthresh to half of cwnd self.ssthresh = self.cwnd / 2; self.cwnd = self.ssthresh.max(MIN_CWND_PACKETS * self.mtu); self.phase = Phase::ProbeBandwidth; tracing::debug!(cwnd = self.cwnd, ssthresh = self.ssthresh, "congestion: loss during slow start"); } Phase::ProbeBandwidth => { // Multiplicative decrease: cwnd *= 0.7 (BBR-style, less aggressive than Cubic's 0.5) self.cwnd = (self.cwnd * 7 / 10).max(MIN_CWND_PACKETS * self.mtu); tracing::debug!(cwnd = self.cwnd, "congestion: loss, cwnd reduced"); } Phase::ProbeRtt => { // Don't react to loss during ProbeRTT } } self.update_pacing_rate(); } /// Called periodically to update state. pub fn on_tick(&mut self) { // Nothing special needed per-tick -- state updates happen on ACK/loss } // ── Private ────────────────────────────────────────────────────────────── fn update_rtt(&mut self, rtt: Duration, now: Instant) { // Track windowed minimum RTT if rtt < self.min_rtt || now.duration_since(self.min_rtt_stamp) >= MIN_RTT_EXPIRY { self.min_rtt = rtt; self.min_rtt_stamp = now; } // Keep sample history self.rtt_samples.push_back(RttSample { rtt, time: now }); while self.rtt_samples.len() > MAX_SAMPLES { self.rtt_samples.pop_front(); } } fn update_bandwidth(&mut self, acked_bytes: u64, now: Instant) { let elapsed = now.duration_since(self.last_ack_time); if elapsed.as_micros() > 0 { let bw = acked_bytes * 1_000_000 / elapsed.as_micros() as u64; if bw > self.max_bandwidth { self.max_bandwidth = bw; } self.bw_samples.push_back(BwSample { bytes_per_sec: bw, time: now }); while self.bw_samples.len() > MAX_SAMPLES { self.bw_samples.pop_front(); } } } fn bandwidth_delay_product(&self) -> u64 { // BDP = max_bandwidth * min_rtt let bw = if self.max_bandwidth > 0 { self.max_bandwidth } else { // Fallback: assume 10 Mbps 1_250_000 }; let 
rtt_secs = self.min_rtt.as_secs_f64(); (bw as f64 * rtt_secs) as u64 } fn update_pacing_rate(&mut self) { // Pacing rate = cwnd / min_rtt (with gain) let rtt_us = self.min_rtt.as_micros().max(1) as u64; self.pacing_rate = self.cwnd * 1_000_000 / rtt_us; } } #[cfg(test)] mod tests { use super::*; #[test] fn test_initial_state() { let cc = CongestionController::new(1200); assert_eq!(cc.cwnd(), 12000); // 10 * 1200 assert!(cc.can_send()); assert_eq!(cc.cwnd_packets(), 10); } #[test] fn test_slow_start_growth() { let mut cc = CongestionController::new(1200); // Simulate sending and ACKing cc.on_send(1200); cc.on_ack(1200, Duration::from_millis(50)); // cwnd should grow assert!(cc.cwnd() > 12000); } #[test] fn test_loss_reduces_cwnd() { let mut cc = CongestionController::new(1200); let initial = cc.cwnd(); cc.on_loss(1200); assert!(cc.cwnd() < initial); } #[test] fn test_can_send_limits() { let mut cc = CongestionController::new(1200); // Send until cwnd is exhausted for _ in 0..10 { cc.on_send(1200); } assert!(!cc.can_send()); // cwnd exhausted } #[test] fn test_retransmit_budget() { let cc = CongestionController::new(1200); let budget = cc.retransmit_budget(); assert!(budget >= 2); assert!(budget <= 64); } #[test] fn test_rtt_tracking() { let mut cc = CongestionController::new(1200); cc.on_send(1200); cc.on_ack(1200, Duration::from_millis(25)); assert_eq!(cc.smoothed_rtt(), Duration::from_millis(25)); } }