ostp/ostp-server/src/dispatcher.rs

298 lines
11 KiB
Rust

use anyhow::Result;
use bytes::Bytes;
use ostp_core::{OstpEvent, ProtocolAction, ProtocolConfig, ProtocolMachine};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, RwLock};
pub enum DispatchOutcome {
Unauthorized,
Accepted {
responses: Vec<Bytes>,
app_payloads: Vec<(u32, u16, Bytes)>, // session_id, stream_id, payload
peer_addr: SocketAddr,
},
}
pub struct PeerState {
pub machine: ProtocolMachine,
pub last_addr: SocketAddr,
pub obfuscation_key: [u8; 8],
pub last_seen: std::time::Instant,
}
pub struct Dispatcher {
peer_machines: HashMap<u32, PeerState>,
addr_to_session: HashMap<SocketAddr, u32>,
machine_config: ProtocolConfig,
access_keys: Arc<RwLock<HashMap<String, ()>>>,
replay_cache: std::collections::HashMap<Vec<u8>, u64>,
roaming_tokens: f64,
last_token_regen: std::time::Instant,
}
impl Dispatcher {
pub fn new(machine_config: ProtocolConfig, access_keys: Arc<RwLock<HashMap<String, ()>>>) -> Self {
Self {
peer_machines: HashMap::new(),
addr_to_session: HashMap::new(),
machine_config,
access_keys,
replay_cache: std::collections::HashMap::new(),
roaming_tokens: 50.0,
last_token_regen: std::time::Instant::now(),
}
}
pub fn on_datagram(&mut self, peer: SocketAddr, packet: Bytes) -> Result<DispatchOutcome> {
if packet.len() < 4 {
return Ok(DispatchOutcome::Unauthorized);
}
let mut session_id_opt = None;
if let Some(&sid) = self.addr_to_session.get(&peer) {
if let Some(peer_state) = self.peer_machines.get(&sid) {
let mut header = [0u8; 12];
if packet.len() >= 12 {
header.copy_from_slice(&packet[0..12]);
ostp_core::crypto::deobfuscate_packet_inplace(&mut header, &peer_state.obfuscation_key, false);
let candidate_sid = u32::from_be_bytes([header[0], header[1], header[2], header[3]]);
if candidate_sid == sid {
session_id_opt = Some(sid);
}
}
}
}
if session_id_opt.is_none() {
// Token Bucket rate limiter: mitigate seamless roaming CPU DoS vector
let now = std::time::Instant::now();
let elapsed = now.duration_since(self.last_token_regen).as_secs_f64();
self.last_token_regen = now;
self.roaming_tokens = (self.roaming_tokens + elapsed * 50.0).min(50.0);
if self.roaming_tokens >= 1.0 {
self.roaming_tokens -= 1.0;
// Try seamless roaming over all peers
for (&sid, peer_state) in &self.peer_machines {
if packet.len() >= 12 {
let mut header = [0u8; 12];
header.copy_from_slice(&packet[0..12]);
ostp_core::crypto::deobfuscate_packet_inplace(&mut header, &peer_state.obfuscation_key, false);
let candidate_sid = u32::from_be_bytes([header[0], header[1], header[2], header[3]]);
if candidate_sid == sid {
session_id_opt = Some(sid);
break;
}
}
}
}
}
if let Some(session_id) = session_id_opt {
if let Some(peer_state) = self.peer_machines.get_mut(&session_id) {
peer_state.last_addr = peer;
peer_state.last_seen = std::time::Instant::now();
self.addr_to_session.insert(peer, session_id);
let action = match peer_state.machine.on_event(OstpEvent::Inbound(packet)) {
Ok(a) => a,
Err(_) => return Ok(DispatchOutcome::Unauthorized),
};
let mut responses = Vec::new();
let mut app_payloads = Vec::new();
fn collect_action(
act: ProtocolAction,
sid: u32,
resps: &mut Vec<Bytes>,
loads: &mut Vec<(u32, u16, Bytes)>,
) {
match act {
ProtocolAction::SendDatagram(frame) => {
resps.push(frame);
}
ProtocolAction::DeliverApp(stream_id, data) => {
loads.push((sid, stream_id, data));
}
ProtocolAction::Multiple(list) => {
for item in list {
collect_action(item, sid, resps, loads);
}
}
_ => {}
}
}
collect_action(action, session_id, &mut responses, &mut app_payloads);
return Ok(DispatchOutcome::Accepted {
responses,
app_payloads,
peer_addr: peer,
});
}
}
// Not an existing session — try each registered access key's derived obfuscation key
let keys_snapshot: Vec<String> = self.access_keys.read().unwrap().keys().cloned().collect();
for candidate_key in keys_snapshot {
let obf_key = ostp_core::crypto::derive_obfuscation_key(candidate_key.as_bytes());
let psk = ostp_core::crypto::derive_psk(candidate_key.as_bytes());
// Decode the session_id using this key's obfuscation
let mut header = [0u8; 4];
header.copy_from_slice(&packet[0..4]);
ostp_core::crypto::deobfuscate_packet_inplace(&mut header, &obf_key, true);
let candidate_session_id = u32::from_be_bytes(header);
let mut cfg = self.machine_config.clone();
cfg.session_id = candidate_session_id;
cfg.psk = psk;
cfg.handshake_payload = vec![];
cfg.obfuscation_key = obf_key;
let mut machine = match ProtocolMachine::new(cfg) {
Ok(m) => m,
Err(_) => continue,
};
let action = match machine.on_event(OstpEvent::Inbound(packet.clone())) {
Ok(a) => a,
Err(_) => continue,
};
if let ProtocolAction::HandshakePayload(payload, response_opt) = action {
if payload.len() >= 12 {
let mut ts_bytes = [0_u8; 8];
ts_bytes.copy_from_slice(&payload[..8]);
let ts = u64::from_be_bytes(ts_bytes);
let mut sid_bytes = [0_u8; 4];
sid_bytes.copy_from_slice(&payload[8..12]);
let sid_from_payload = u32::from_be_bytes(sid_bytes);
if sid_from_payload != candidate_session_id {
continue;
}
let key_bytes = &payload[12..];
if let Ok(key_from_payload) = std::str::from_utf8(key_bytes) {
// The key embedded in the payload must match the candidate key we decoded with
if key_from_payload != candidate_key {
continue;
}
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
let drift = (now as i64 - ts as i64).abs();
if drift > 300 {
// Narrow window (5 mins) limits replay risk and bounds cache memory
continue;
}
if !self.replay_cache.contains_key(&payload.to_vec()) {
self.replay_cache.insert(payload.to_vec(), ts);
machine.set_session_keys(candidate_session_id, obf_key);
self.peer_machines.insert(candidate_session_id, PeerState {
machine,
last_addr: peer,
obfuscation_key: obf_key,
last_seen: std::time::Instant::now(),
});
self.addr_to_session.insert(peer, candidate_session_id);
return Ok(DispatchOutcome::Accepted {
responses: response_opt.into_iter().collect(),
app_payloads: Vec::new(),
peer_addr: peer,
});
}
}
}
}
}
Ok(DispatchOutcome::Unauthorized)
}
pub fn outbound_to_session(&mut self, session_id: u32, stream_id: u16, payload: Bytes) -> Result<Option<(Bytes, SocketAddr)>> {
let peer_state = if let Some(existing) = self.peer_machines.get_mut(&session_id) {
existing
} else {
return Ok(None);
};
let addr = peer_state.last_addr;
match peer_state.machine.on_event(OstpEvent::Outbound(stream_id, payload))? {
ProtocolAction::SendDatagram(frame) => Ok(Some((frame, addr))),
_ => Ok(None),
}
}
pub fn on_tick(&mut self) -> (Vec<(Bytes, SocketAddr)>, Vec<u32>) {
// Purge expired handshakes from replay cache (older than 5 min drift allowance)
let current_sys_time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
self.replay_cache.retain(|_, &mut ts| (current_sys_time as i64 - ts as i64).abs() <= 300);
let mut frames = Vec::new();
let mut expired = Vec::new();
let now = std::time::Instant::now();
let timeout_dur = std::time::Duration::from_secs(300); // 5 minutes session timeout
// Gather expired sessions
for (&sid, peer_state) in &self.peer_machines {
if now.duration_since(peer_state.last_seen) > timeout_dur {
expired.push(sid);
}
}
// Clear expired sessions from internal state
for sid in &expired {
self.drop_session(*sid);
}
// Drive ticks for remaining active sessions
for peer_state in self.peer_machines.values_mut() {
let action = match peer_state.machine.on_event(OstpEvent::Tick) {
Ok(a) => a,
Err(_) => continue,
};
let mut queue = vec![action];
while let Some(current) = queue.pop() {
match current {
ProtocolAction::Multiple(list) => {
for item in list {
queue.push(item);
}
}
ProtocolAction::SendDatagram(frame) => {
frames.push((frame, peer_state.last_addr));
}
_ => {}
}
}
}
(frames, expired)
}
pub fn drop_session(&mut self, session_id: u32) {
if let Some(state) = self.peer_machines.remove(&session_id) {
self.addr_to_session.remove(&state.last_addr);
}
}
}