ostp/ostp-core/src/congestion.rs

335 lines
11 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! Congestion control for the OSTP protocol.
//!
//! Implements a simplified BBR-inspired algorithm that estimates bottleneck
//! bandwidth and minimum RTT to determine the optimal sending rate.
//! This replaces the fixed `retransmit_budget = 8` with an adaptive
//! congestion window that responds to network conditions.
use std::collections::VecDeque;
use std::time::{Duration, Instant};
/// Congestion control state for a single OSTP session.
pub struct CongestionController {
/// Current congestion window in bytes (how much can be in-flight)
cwnd: u64,
/// Slow-start threshold in bytes
ssthresh: u64,
/// Current phase
phase: Phase,
/// Minimum RTT observed (used for BDP calculation)
min_rtt: Duration,
/// Maximum bandwidth observed (bytes/sec)
max_bandwidth: u64,
/// RTT samples for smoothing
rtt_samples: VecDeque<RttSample>,
/// Bandwidth samples
bw_samples: VecDeque<BwSample>,
/// Bytes currently in flight (unacknowledged)
bytes_in_flight: u64,
/// Total bytes acknowledged (for bandwidth estimation)
total_acked: u64,
/// Last time we received an ACK
last_ack_time: Instant,
/// Number of loss events in the current window
loss_count: u32,
/// Pacing rate: bytes per second
pacing_rate: u64,
/// MTU estimate (used for cwnd → packet count conversion)
mtu: u64,
/// Probe RTT phase timer
probe_rtt_timer: Option<Instant>,
/// Min RTT expiry: re-probe after 10 seconds
min_rtt_stamp: Instant,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Phase {
/// Exponential growth until loss or ssthresh
SlowStart,
/// Probe bandwidth: cycle through pacing gains
ProbeBandwidth,
/// Periodically drain the queue to measure true min RTT
ProbeRtt,
}
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct RttSample {
rtt: Duration,
time: Instant,
}
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct BwSample {
bytes_per_sec: u64,
time: Instant,
}
/// Maximum number of samples to keep for windowed min/max
const MAX_SAMPLES: usize = 32;
/// Initial congestion window: 10 packets × MTU
const INITIAL_CWND_PACKETS: u64 = 10;
/// Minimum cwnd: 2 packets
const MIN_CWND_PACKETS: u64 = 2;
/// Min RTT expiry window (after which we re-probe)
const MIN_RTT_EXPIRY: Duration = Duration::from_secs(10);
/// ProbeRTT drain duration
const PROBE_RTT_DURATION: Duration = Duration::from_millis(200);
impl CongestionController {
pub fn new(mtu: u64) -> Self {
let now = Instant::now();
let initial_cwnd = INITIAL_CWND_PACKETS * mtu;
Self {
cwnd: initial_cwnd,
ssthresh: u64::MAX,
phase: Phase::SlowStart,
min_rtt: Duration::from_millis(100), // Conservative initial estimate
max_bandwidth: 0,
rtt_samples: VecDeque::with_capacity(MAX_SAMPLES),
bw_samples: VecDeque::with_capacity(MAX_SAMPLES),
bytes_in_flight: 0,
total_acked: 0,
last_ack_time: now,
loss_count: 0,
pacing_rate: initial_cwnd * 10, // initial: ~10 windows/sec
mtu,
probe_rtt_timer: None,
min_rtt_stamp: now,
}
}
/// Returns the current congestion window in bytes.
pub fn cwnd(&self) -> u64 {
self.cwnd
}
/// Returns the current congestion window in packets.
pub fn cwnd_packets(&self) -> usize {
(self.cwnd / self.mtu).max(MIN_CWND_PACKETS) as usize
}
/// Returns the current pacing rate in bytes/sec.
pub fn pacing_rate(&self) -> u64 {
self.pacing_rate
}
/// Returns the smoothed RTT estimate.
pub fn smoothed_rtt(&self) -> Duration {
self.min_rtt
}
/// Returns how many bytes can still be sent.
pub fn available_cwnd(&self) -> u64 {
self.cwnd.saturating_sub(self.bytes_in_flight)
}
/// Returns the recommended retransmit budget per tick.
pub fn retransmit_budget(&self) -> usize {
// Allow retransmitting up to 1/4 of the cwnd in packets per tick
let budget = (self.cwnd_packets() / 4).max(2);
budget.min(64) // cap at 64 to prevent burst
}
/// Check whether we can send more data.
pub fn can_send(&self) -> bool {
self.bytes_in_flight < self.cwnd
}
/// Record that we sent `bytes` of data.
pub fn on_send(&mut self, bytes: u64) {
self.bytes_in_flight = self.bytes_in_flight.saturating_add(bytes);
}
/// Record that `bytes` were acknowledged with the given RTT sample.
pub fn on_ack(&mut self, bytes: u64, rtt: Duration) {
let now = Instant::now();
self.bytes_in_flight = self.bytes_in_flight.saturating_sub(bytes);
self.total_acked = self.total_acked.saturating_add(bytes);
// Update RTT
self.update_rtt(rtt, now);
// Update bandwidth estimate
self.update_bandwidth(bytes, now);
// State machine
match self.phase {
Phase::SlowStart => {
// Exponential growth: increase cwnd by acked bytes
self.cwnd = self.cwnd.saturating_add(bytes);
if self.cwnd >= self.ssthresh {
self.phase = Phase::ProbeBandwidth;
tracing::debug!(cwnd = self.cwnd, "congestion: exiting slow start");
}
}
Phase::ProbeBandwidth => {
// TCP Reno Additive Increase: increase cwnd by ~1 MTU per RTT
self.cwnd = self.cwnd.saturating_add(bytes * self.mtu / self.cwnd.max(1));
}
Phase::ProbeRtt => {
// Drain down to 4 packets to measure true min RTT
self.cwnd = MIN_CWND_PACKETS * self.mtu * 2;
if let Some(timer) = self.probe_rtt_timer {
if now.duration_since(timer) >= PROBE_RTT_DURATION {
// ProbeRTT complete, return to ProbeBandwidth
self.phase = Phase::ProbeBandwidth;
self.probe_rtt_timer = None;
self.cwnd = (MIN_CWND_PACKETS * self.mtu * 4).max(self.cwnd);
tracing::debug!(cwnd = self.cwnd, min_rtt = ?self.min_rtt, "congestion: probe RTT complete");
}
}
}
}
/*
// Periodically enter ProbeRTT to refresh min_rtt
if now.duration_since(self.min_rtt_stamp) >= MIN_RTT_EXPIRY && self.phase != Phase::ProbeRtt {
self.phase = Phase::ProbeRtt;
self.probe_rtt_timer = Some(now);
tracing::debug!("congestion: entering probe RTT phase");
}
*/
self.update_pacing_rate();
self.last_ack_time = now;
}
/// Record a loss event.
pub fn on_loss(&mut self, bytes_lost: u64) {
self.bytes_in_flight = self.bytes_in_flight.saturating_sub(bytes_lost);
self.loss_count += 1;
match self.phase {
Phase::SlowStart => {
// Exit slow start, set ssthresh to half of cwnd
self.ssthresh = self.cwnd / 2;
self.cwnd = self.ssthresh.max(MIN_CWND_PACKETS * self.mtu);
self.phase = Phase::ProbeBandwidth;
tracing::debug!(cwnd = self.cwnd, ssthresh = self.ssthresh, "congestion: loss during slow start");
}
Phase::ProbeBandwidth => {
// Multiplicative decrease: cwnd *= 0.7 (BBR-style, less aggressive than Cubic's 0.5)
self.cwnd = (self.cwnd * 7 / 10).max(MIN_CWND_PACKETS * self.mtu);
tracing::debug!(cwnd = self.cwnd, "congestion: loss, cwnd reduced");
}
Phase::ProbeRtt => {
// Don't react to loss during ProbeRTT
}
}
self.update_pacing_rate();
}
/// Called periodically to update state.
pub fn on_tick(&mut self) {
// Nothing special needed per-tick -- state updates happen on ACK/loss
}
// ── Private ──────────────────────────────────────────────────────────────
fn update_rtt(&mut self, rtt: Duration, now: Instant) {
// Track windowed minimum RTT
if rtt < self.min_rtt || now.duration_since(self.min_rtt_stamp) >= MIN_RTT_EXPIRY {
self.min_rtt = rtt;
self.min_rtt_stamp = now;
}
// Keep sample history
self.rtt_samples.push_back(RttSample { rtt, time: now });
while self.rtt_samples.len() > MAX_SAMPLES {
self.rtt_samples.pop_front();
}
}
fn update_bandwidth(&mut self, acked_bytes: u64, now: Instant) {
let elapsed = now.duration_since(self.last_ack_time);
if elapsed.as_micros() > 0 {
let bw = acked_bytes * 1_000_000 / elapsed.as_micros() as u64;
if bw > self.max_bandwidth {
self.max_bandwidth = bw;
}
self.bw_samples.push_back(BwSample { bytes_per_sec: bw, time: now });
while self.bw_samples.len() > MAX_SAMPLES {
self.bw_samples.pop_front();
}
}
}
fn bandwidth_delay_product(&self) -> u64 {
// BDP = max_bandwidth * min_rtt
let bw = if self.max_bandwidth > 0 {
self.max_bandwidth
} else {
// Fallback: assume 10 Mbps
1_250_000
};
let rtt_secs = self.min_rtt.as_secs_f64();
(bw as f64 * rtt_secs) as u64
}
fn update_pacing_rate(&mut self) {
// Pacing rate = cwnd / min_rtt (with gain)
let rtt_us = self.min_rtt.as_micros().max(1) as u64;
self.pacing_rate = self.cwnd * 1_000_000 / rtt_us;
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_initial_state() {
let cc = CongestionController::new(1200);
assert_eq!(cc.cwnd(), 12000); // 10 * 1200
assert!(cc.can_send());
assert_eq!(cc.cwnd_packets(), 10);
}
#[test]
fn test_slow_start_growth() {
let mut cc = CongestionController::new(1200);
// Simulate sending and ACKing
cc.on_send(1200);
cc.on_ack(1200, Duration::from_millis(50));
// cwnd should grow
assert!(cc.cwnd() > 12000);
}
#[test]
fn test_loss_reduces_cwnd() {
let mut cc = CongestionController::new(1200);
let initial = cc.cwnd();
cc.on_loss(1200);
assert!(cc.cwnd() < initial);
}
#[test]
fn test_can_send_limits() {
let mut cc = CongestionController::new(1200);
// Send until cwnd is exhausted
for _ in 0..10 {
cc.on_send(1200);
}
assert!(!cc.can_send()); // cwnd exhausted
}
#[test]
fn test_retransmit_budget() {
let cc = CongestionController::new(1200);
let budget = cc.retransmit_budget();
assert!(budget >= 2);
assert!(budget <= 64);
}
#[test]
fn test_rtt_tracking() {
let mut cc = CongestionController::new(1200);
cc.on_send(1200);
cc.on_ack(1200, Duration::from_millis(25));
assert_eq!(cc.smoothed_rtt(), Duration::from_millis(25));
}
}