diff --git a/.gitignore b/.gitignore index c1198ad..da7670d 100644 Binary files a/.gitignore and b/.gitignore differ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 29a54e8..3a2e85d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -22,6 +22,7 @@ By contributing to this project, you agree to abide by our code of conduct and l To build and test OSTP locally, you will need: * **Rust Toolchain**: Install via [rustup](https://rustup.rs/) (stable channel). +* **Go 1.20+**: Required to compile the embedded `dnstt` tunnel binaries. * **Node.js (18+) & npm**: Required to compile Tauri GUI resources. * **Git**: For version control. diff --git a/CONTRIBUTING.ru.md b/CONTRIBUTING.ru.md index 79cdbb3..6fa16ad 100644 --- a/CONTRIBUTING.ru.md +++ b/CONTRIBUTING.ru.md @@ -22,6 +22,7 @@ Для локальной сборки и тестирования OSTP вам понадобятся: * **Rust Toolchain**: Установите через [rustup](https://rustup.rs/) (stable канал). +* **Go 1.20+**: Необходимо для сборки встроенного DNS-туннеля dnstt. * **Node.js (18+) и npm**: Необходимы для сборки интерфейса Tauri. * **Git**: Для контроля версий. diff --git a/README.ru.md b/README.ru.md index 78ae77f..0582885 100644 --- a/README.ru.md +++ b/README.ru.md @@ -142,8 +142,13 @@ irm https://raw.githubusercontent.com/ospab/ostp/master/scripts/install.ps1 | ie ## Сборка из исходников +### Зависимости для сборки + +- Rust 1.70+ +- Go 1.20+ (необходимо для сборки встроенного DNS-туннеля dnstt) + +> **Благодарности:** Этот проект использует [dnstt](https://www.bamsoftware.com/software/dnstt/) от Bamsoftware для обеспечения устойчивого туннелирования поверх DNS. Бинарники dnstt автоматически компилируются и встраиваются в ядро OSTP. ```bash -# Требования: Rust 1.75+ cargo build --release # Кросс-компиляция для Linux diff --git a/ostp-client/src/transport/dns.rs b/ostp-client/src/transport/dns.rs index 600721b..e8ada26 100644 --- a/ostp-client/src/transport/dns.rs +++ b/ostp-client/src/transport/dns.rs @@ -1,230 +1 @@ -/// DNS tunnel transport — dnstt-style implementation. -/// -/// Protocol (client → server, embedded in DNS query domain name): -/// Base32([client_id: 8][msg_id: 2 BE][total_frags: 1][frag_idx: 1][payload: ≤MAX_CHUNK]) -/// Split into DNS labels of max 63 chars, suffixed with base_domain. -/// -/// Poll query: payload is empty (total_frags=1, frag_idx=0, len=0). -/// -/// Protocol (server → client, in TXT rdata): -/// Concatenated length-prefixed OSTP packets: [len: 2 BE][data ...]... -/// -/// Polling: adaptive 500ms → 10s, like dnstt. Resets to 500ms on real data. -use std::collections::HashMap; -use std::sync::Arc; -use std::time::Duration; -use bytes::Bytes; -use rand::Rng; -use tokio::net::UdpSocket; -use tokio::sync::{mpsc, Mutex}; -use crate::transport::Transport; -use rand::RngCore; -use ostp_core::dns::{base32_encode, DnsPacket, DnsRecordType}; - -/// Max raw payload bytes we put into one DNS query. -/// Calculation: FQDN ≤ 253 chars. Domain suffix ~30 chars max. -/// Remaining: ~220 chars for base32 labels. 220/8*5 = 137 bytes raw. -/// Header = 12 bytes → payload ≤ 120 bytes (conservative, works for any domain ≤ 40 chars). -const MAX_CHUNK_PAYLOAD: usize = 120; -const CLIENT_ID_LEN: usize = 8; -const INIT_POLL_DELAY: Duration = Duration::from_millis(500); -const MAX_POLL_DELAY: Duration = Duration::from_secs(10); -const POLL_DELAY_MULTIPLIER: f64 = 2.0; - -pub async fn start_dns_transport( - domain: String, - resolver: String, - _pubkey: Option, -) -> std::io::Result { - let (app_tx, transport_rx) = mpsc::channel::(256); - let (transport_tx, app_rx) = mpsc::channel::(256); - - let resolver_addr = if resolver.contains(':') { - resolver.clone() - } else { - format!("{}:53", resolver) - }; - - let socket = UdpSocket::bind("0.0.0.0:0").await?; - socket.connect(&resolver_addr).await?; - let socket = Arc::new(socket); - - // Generate random ClientID for this tunnel session - let mut client_id = [0u8; CLIENT_ID_LEN]; - rand::thread_rng().fill_bytes(&mut client_id); - let client_id = Arc::new(client_id); - - tracing::info!("DNS transport: domain={} resolver={} client_id={}", - domain, resolver_addr, - hex::encode(client_id.as_slice())); - - // ── Send task ───────────────────────────────────────────────────────────── - let sock_send = socket.clone(); - let cid_send = client_id.clone(); - let domain_send = domain.clone(); - tokio::spawn(async move { - let mut rx = transport_rx; - let mut msg_id: u16 = 0; - let mut poll_delay = INIT_POLL_DELAY; - - loop { - let data: Option = tokio::select! { - data = rx.recv() => data, - _ = tokio::time::sleep(poll_delay) => { - poll_delay = Duration::from_secs_f64( - (poll_delay.as_secs_f64() * POLL_DELAY_MULTIPLIER) - .min(MAX_POLL_DELAY.as_secs_f64()) - ); - // Send poll (empty payload) - Some(Bytes::new()) - } - }; - - let data = match data { - Some(d) => d, - None => { - tracing::debug!("DNS send task: channel closed, exiting"); - break; - } - }; - - if data.is_empty() { - // Poll query — one empty chunk - if let Err(e) = send_chunk(&sock_send, &cid_send, msg_id, 1, 0, &[], &domain_send).await { - tracing::warn!("DNS poll send error: {}", e); - } - } else { - // Real OSTP packet — fragment into chunks - poll_delay = INIT_POLL_DELAY; // reset on real data - - let data_slice = data.as_ref(); - let total_chunks = data_slice.chunks(MAX_CHUNK_PAYLOAD).count(); - let total_u8 = total_chunks.min(255) as u8; - - for (idx, chunk) in data_slice.chunks(MAX_CHUNK_PAYLOAD).enumerate() { - if let Err(e) = send_chunk( - &sock_send, &cid_send, - msg_id, total_u8, idx as u8, - chunk, &domain_send, - ).await { - tracing::warn!("DNS chunk send error (idx={}): {}", idx, e); - break; - } - // Brief inter-fragment delay to avoid flooding the resolver - if total_chunks > 1 { - tokio::time::sleep(Duration::from_millis(20)).await; - } - } - msg_id = msg_id.wrapping_add(1); - } - } - }); - - // ── Receive task ────────────────────────────────────────────────────────── - let sock_recv = socket.clone(); - let tx_recv = transport_tx.clone(); - let domain_recv = domain.clone(); - tokio::spawn(async move { - let mut buf = vec![0u8; 65535]; - // Reassembly buffers: msg_id → (total, Vec>) - let reassembly: HashMap>>)> = HashMap::new(); - - loop { - match sock_recv.recv(&mut buf).await { - Ok(n) => { - let Some(pkt) = DnsPacket::decode(&buf[..n]) else { continue }; - - // Only process DNS responses - if pkt.flags & 0x8000 == 0 { continue; } - - for answer in pkt.answers { - if answer.rtype != DnsRecordType::TXT && answer.rtype != DnsRecordType::NULL { - continue; - } - let rdata = answer.rdata; - // Parse length-prefixed OSTP packets packed in rdata: - // [len_hi: 1][len_lo: 1][data: len]... - let mut pos = 0; - while pos + 2 <= rdata.len() { - let pkt_len = u16::from_be_bytes([rdata[pos], rdata[pos + 1]]) as usize; - pos += 2; - if pkt_len == 0 { continue; } - if pos + pkt_len > rdata.len() { - tracing::debug!("DNS recv: truncated packet in rdata"); - break; - } - let payload = Bytes::copy_from_slice(&rdata[pos..pos + pkt_len]); - pos += pkt_len; - - if tx_recv.send(payload).await.is_err() { - return; // app closed - } - } - } - - // Also check for responses packed in the server's extra DNS answer rdata - // that use our fragmentation scheme (server→client fragments) - // This is handled above via the length-prefix protocol. - let _ = &reassembly; // Keep for future upstream fragmentation support - let _ = &domain_recv; - } - Err(e) => { - tracing::warn!("DNS transport recv error: {}", e); - break; - } - } - } - }); - - Ok(Transport::Dns { - tx: app_tx, - rx: Arc::new(Mutex::new(app_rx)), - }) -} - -/// Build and send one DNS TXT query with a framed chunk. -/// -/// Frame format (before base32 encoding): -/// [client_id: 8][msg_id: 2 BE][total_frags: 1][frag_idx: 1][payload: 0–120] -async fn send_chunk( - socket: &UdpSocket, - client_id: &[u8; CLIENT_ID_LEN], - msg_id: u16, - total_frags: u8, - frag_idx: u8, - payload: &[u8], - base_domain: &str, -) -> std::io::Result<()> { - // Build frame - let mut frame = Vec::with_capacity(CLIENT_ID_LEN + 4 + payload.len()); - frame.extend_from_slice(client_id); - frame.extend_from_slice(&msg_id.to_be_bytes()); - frame.push(total_frags); - frame.push(frag_idx); - frame.extend_from_slice(payload); - - // Base32-encode - let encoded = base32_encode(&frame); - - // Split into 63-char labels and append domain - let mut fqdn = String::with_capacity(encoded.len() + base_domain.len() + 10); - let mut start = 0; - while start < encoded.len() { - let end = (start + 63).min(encoded.len()); - fqdn.push_str(&encoded[start..end]); - fqdn.push('.'); - start = end; - } - fqdn.push_str(base_domain); - - // Build DNS TXT query with random ID - let id: u16 = rand::thread_rng().gen(); - let pkt = DnsPacket::new_query(id, &fqdn, DnsRecordType::TXT); - let wire = pkt.encode(); - - tracing::trace!("DNS send chunk: msg_id={} frag={}/{} payload={}B fqdn_len={}", - msg_id, frag_idx + 1, total_frags, payload.len(), fqdn.len()); - - socket.send(&wire).await?; - Ok(()) -} +// Left empty by request diff --git a/ostp-client/src/transport/mod.rs b/ostp-client/src/transport/mod.rs index 1af26d9..f4744a4 100644 --- a/ostp-client/src/transport/mod.rs +++ b/ostp-client/src/transport/mod.rs @@ -1,4 +1,3 @@ -pub mod dns; use std::sync::Arc; use tokio::net::UdpSocket; use bytes::Bytes; @@ -10,9 +9,10 @@ pub enum Transport { tx: tokio::sync::mpsc::Sender, rx: Arc>>, }, - Dns { + Dnstt { tx: tokio::sync::mpsc::Sender, rx: Arc>>, + _guard: Arc>, } } @@ -20,7 +20,7 @@ impl Transport { pub async fn send(&self, frame: &Bytes) -> std::io::Result { match self { Self::Udp(sock) => sock.send(frame).await, - Self::Uot { tx, .. } | Self::Dns { tx, .. } => { + Self::Uot { tx, .. } | Self::Dnstt { tx, .. } => { tx.send(frame.clone()).await.map_err(|_| std::io::Error::new(std::io::ErrorKind::BrokenPipe, "channel closed"))?; Ok(frame.len()) } @@ -30,31 +30,40 @@ impl Transport { pub async fn send_to(&self, frame: &Bytes, target: std::net::SocketAddr) -> std::io::Result { match self { Self::Udp(sock) => sock.send_to(frame, target).await, - Self::Uot { .. } | Self::Dns { .. } => self.send(frame).await, + Self::Uot { .. } | Self::Dnstt { .. } => self.send(frame).await, } } pub async fn recv(&self, buf: &mut [u8]) -> std::io::Result { match self { Self::Udp(sock) => sock.recv(buf).await, - Self::Uot { rx, .. } | Self::Dns { rx, .. } => { + Self::Uot { rx, .. } | Self::Dnstt { rx, .. } => { let mut rx = rx.lock().await; - match rx.recv().await { - Some(bytes) => { - let len = bytes.len().min(buf.len()); - buf[..len].copy_from_slice(&bytes[..len]); - Ok(len) - } - None => Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, "channel closed")), + if let Some(frame) = rx.recv().await { + let len = frame.len().min(buf.len()); + buf[..len].copy_from_slice(&frame[..len]); + Ok(len) + } else { + Err(std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "channel closed")) } } } } + pub async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, std::net::SocketAddr)> { + match self { + Self::Udp(sock) => sock.recv_from(buf).await, + Self::Uot { .. } | Self::Dnstt { .. } => { + let n = self.recv(buf).await?; + Ok((n, "127.0.0.1:0".parse().unwrap())) + } + } + } + pub fn local_addr(&self) -> std::io::Result { match self { Self::Udp(sock) => sock.local_addr(), - Self::Uot { .. } | Self::Dns { .. } => Ok("0.0.0.0:0".parse().unwrap()), + Self::Uot { .. } | Self::Dnstt { .. } => Ok("0.0.0.0:0".parse().unwrap()), } } } diff --git a/ostp-core/src/lib.rs b/ostp-core/src/lib.rs index 4d52f5d..6e6f9a4 100644 --- a/ostp-core/src/lib.rs +++ b/ostp-core/src/lib.rs @@ -6,6 +6,7 @@ pub mod relay; pub mod resumption; pub mod dns; pub mod dns_prober; +pub mod dnstt; pub use crypto::NoiseRole; pub use framing::{TrafficProfile, PaddingStrategy}; diff --git a/ostp-server/src/lib.rs b/ostp-server/src/lib.rs index e77f91e..5821532 100644 --- a/ostp-server/src/lib.rs +++ b/ostp-server/src/lib.rs @@ -1,15 +1,23 @@ use anyhow::Result; use bytes::Bytes; -use std::collections::HashMap; -use std::net::IpAddr; +use std::collections::{HashMap, VecDeque}; +use std::net::{IpAddr, SocketAddr}; +use std::sync::Arc; use dispatcher::{DispatchOutcome, Dispatcher}; use ostp_core::relay::RelayMessage; use signal::wait_for_shutdown_signal; use tokio::net::UdpSocket; -use tokio::sync::mpsc; +use tokio::sync::{mpsc, RwLock}; use tokio::time::{interval, Duration, Instant}; +use std::sync::OnceLock; + +pub fn dns_queue() -> &'static Arc>>> { + static DNS_QUEUE: OnceLock>>>> = OnceLock::new(); + DNS_QUEUE.get_or_init(|| Arc::new(RwLock::new(HashMap::new()))) +} + mod dispatcher; pub mod outbound; pub mod fallback; @@ -120,6 +128,29 @@ pub async fn run_server( let dispatcher = Dispatcher::new(protocol_config, shared_keys.clone()); + // Launch dnstt-server if configured + let _dnstt_guard = if let Some(dns) = &dns_transport { + let pub_ip = server_public_ip.clone().unwrap_or_else(|| { + let p = config_path.as_ref() + .and_then(|p| p.parent()) + .unwrap_or_else(|| std::path::Path::new(".")) + .join(".ostp_public_ip"); + std::fs::read_to_string(p).unwrap_or_else(|_| "127.0.0.1".to_string()).trim().to_string() + }); + + match ostp_core::dnstt::spawn_server(&pub_ip, 50000, &dns.privkey, debug) { + Ok(guard) => { + tracing::info!("dnstt-server initialized on {}:53 with pubkey: {}", pub_ip, dns.pubkey); + Some(guard) + } + Err(e) => { + tracing::error!("Failed to initialize dnstt-server: {}", e); + None + } + } + } else { + None + }; // Background config hot-reloader for access keys let shared_keys_clone = shared_keys.clone(); let user_stats_clone = dispatcher.user_stats_ref(); @@ -455,17 +486,9 @@ async fn run_server_loop( if let Some(dns_cfg) = dns_transport { if dns_cfg.enabled { - let dns_udp_tx = udp_tx.clone(); - let dns_tcp_map = tcp_map.clone(); - let dns_ui_tx = ui_event_tx.clone(); - tokio::spawn(async move { - crate::transport::dns::start_dns_transport_server( - dns_cfg, - dns_udp_tx, - dns_tcp_map, - dns_ui_tx, - ).await; - }); + // DNS transport is now handled entirely by dnstt-server launched at startup. + // We just trace it here. + tracing::info!("DNS Transport via dnstt is enabled"); } } @@ -585,7 +608,11 @@ async fn handle_udp_packet( if !peer_available.get(&peer_ip).copied().unwrap_or(false) { peer_available.insert(peer_ip, true); let is_tcp = tcp_map.read().await.contains_key(&peer_addr); - let proto = if is_tcp { "TCP (UoT)" } else { "UDP" }; + let is_dns = match peer_ip { + std::net::IpAddr::V4(v4) => v4.octets()[0] == 10 && v4.octets()[1] == 255, + _ => false, + }; + let proto = if is_dns { "DNS-tunnel" } else if is_tcp { "TCP (UoT)" } else { "UDP" }; let _ = ui_event_tx.send(UiEvent::Log(format!("Client {peer_ip} connected via {proto}"))); } @@ -609,7 +636,21 @@ async fn handle_udp_packet( } } if !sent_tcp { - let _ = socket.send_to(&resp, peer_addr).await?; + // Check if this is a DNS tunnel virtual IP (10.255.x.x) + let is_dns_ip = match peer_addr.ip() { + std::net::IpAddr::V4(v4) => v4.octets()[0] == 10 && v4.octets()[1] == 255, + _ => false, + }; + if is_dns_ip { + // Queue the packet for the next DNS poll query + let mut dq = crate::dns_queue().write().await; + let queue = dq.entry(peer_addr).or_insert_with(std::collections::VecDeque::new); + if queue.len() < 256 { + queue.push_back(resp); + } + } else { + let _ = socket.send_to(&resp, peer_addr).await?; + } } let _ = ui_event_tx.send(UiEvent::Tx { peer: peer_ip, bytes: resp_len }); } @@ -636,6 +677,9 @@ async fn handle_udp_packet( ).await?; } } + Ok(DispatchOutcome::Ignored) => { + // Handshake replay, safely ignored + } Err(err) => { let _ = ui_event_tx.send(UiEvent::Log(format!("Protocol error for {peer}: {err}"))); } @@ -672,7 +716,19 @@ async fn handle_tick( } } if !sent_tcp { - let _ = socket.send_to(&frame, peer_addr).await?; + let is_dns_ip = match peer_addr.ip() { + std::net::IpAddr::V4(v4) => v4.octets()[0] == 10 && v4.octets()[1] == 255, + _ => false, + }; + if is_dns_ip { + let mut dq = crate::dns_queue().write().await; + let queue = dq.entry(peer_addr).or_insert_with(std::collections::VecDeque::new); + if queue.len() < 256 { + queue.push_back(frame); + } + } else { + let _ = socket.send_to(&frame, peer_addr).await; + } } } for sid in dropped_sessions { diff --git a/ostp-server/src/relay.rs b/ostp-server/src/relay.rs index 64fa5ae..c802bad 100644 --- a/ostp-server/src/relay.rs +++ b/ostp-server/src/relay.rs @@ -247,18 +247,58 @@ pub async fn send_relay_to_stream( tcp_map: &std::sync::Arc>>>, ) -> Result<()> { let payload = Bytes::from(msg.encode()); - if let Some((frame, peer_addr)) = dispatcher.outbound_to_session(session_id, stream_id, payload)? { + for (frame, peer_addr) in dispatcher.outbound_to_session(session_id, stream_id, payload)? { let response_len = frame.len(); let mut sent_tcp = false; { let map = tcp_map.read().await; if let Some(tx) = map.get(&peer_addr) { - let _ = tx.try_send(frame.clone()); - sent_tcp = true; + // Use a bounded async send with a generous timeout instead of try_send. + // try_send silently drops frames when the channel is full (common with + // bursty traffic), causing spurious retransmits and throughput collapse. + // 200ms matches roughly one RTO — if we can't deliver in that window + // the receiver is definitely stalled and we should log it. + let tx = tx.clone(); + let frame_clone = frame.clone(); + match tokio::time::timeout( + std::time::Duration::from_millis(200), + tx.send(frame_clone), + ).await { + Ok(Ok(())) => { sent_tcp = true; } + Ok(Err(_)) => { + tracing::warn!( + "relay: TCP channel closed for peer={}, frame dropped (session={}, stream={})", + peer_addr, session_id, stream_id + ); + sent_tcp = true; // channel gone, don't fall through to UDP + } + Err(_timeout) => { + tracing::warn!( + "relay: TCP channel full / timeout for peer={}, falling back to UDP (session={}, stream={})", + peer_addr, session_id, stream_id + ); + // sent_tcp stays false → will fall through to UDP send below + } + } } } if !sent_tcp { - let _ = socket.send_to(&frame, peer_addr).await?; + let is_dns_ip = match peer_addr.ip() { + std::net::IpAddr::V4(v4) => v4.octets()[0] == 10 && v4.octets()[1] == 255, + _ => false, + }; + if is_dns_ip { + // DNS virtual IP — queue for next poll + let mut dq = crate::dns_queue().write().await; + let queue = dq.entry(peer_addr).or_insert_with(std::collections::VecDeque::new); + if queue.len() < 256 { + queue.push_back(frame); + } else { + tracing::warn!("relay: dns_queue full for peer={}, frame dropped", peer_addr); + } + } else { + let _ = socket.send_to(&frame, peer_addr).await; + } } let _ = ui_event_tx.send(UiEvent::Tx { peer: peer_addr.ip(), @@ -267,3 +307,4 @@ pub async fn send_relay_to_stream( } Ok(()) } + diff --git a/ostp-server/src/transport/dns.rs b/ostp-server/src/transport/dns.rs index 4960189..e8ada26 100644 --- a/ostp-server/src/transport/dns.rs +++ b/ostp-server/src/transport/dns.rs @@ -1,346 +1 @@ -/// DNS tunnel transport — dnstt-style server implementation. -/// -/// Each DNS TXT query from client contains a framed chunk: -/// Base32([client_id: 8][msg_id: 2 BE][total_frags: 1][frag_idx: 1][payload: ≤120]) -/// -/// Server: -/// 1. Decodes ClientID + fragment from query name -/// 2. Reassembles fragments per (client_id, msg_id) -/// 3. Forwards complete OSTP packet to dispatcher (udp_tx) -/// 4. Waits up to MAX_RESPONSE_DELAY for responses -/// 5. Bundles responses as length-prefixed packets in DNS TXT answer -/// -/// Server → client data in TXT rdata: [len_hi][len_lo][data...]... -use std::collections::HashMap; -use std::net::SocketAddr; -use std::sync::Arc; -use bytes::Bytes; -use tokio::net::UdpSocket; -use tokio::sync::{mpsc, RwLock}; -use tokio::time::Duration; - -use ostp_core::dns::{base32_decode, DnsPacket, DnsRecordType}; -use crate::config::DnsTransportConfig; -use crate::UiEvent; - -const CLIENT_ID_LEN: usize = 8; -const HEADER_LEN: usize = CLIENT_ID_LEN + 4; // client_id + msg_id(2) + total(1) + idx(1) -/// How long to wait for downstream OSTP data before sending an empty response. -const MAX_RESPONSE_DELAY: Duration = Duration::from_millis(800); -/// Maximum number of response packets to bundle into one DNS answer. -const MAX_RESPONSE_PACKETS: usize = 8; -/// How long to keep per-client reassembly state without activity. -const CLIENT_EXPIRY: Duration = Duration::from_secs(30); - -#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] -struct ClientId([u8; CLIENT_ID_LEN]); - -struct ReassemblyState { - total: u8, - frags: Vec>>, - received: u8, -} - -impl ReassemblyState { - fn new(total: u8) -> Self { - Self { - total, - frags: vec![None; total as usize], - received: 0, - } - } - - fn insert(&mut self, idx: u8, payload: Vec) -> bool { - let idx = idx as usize; - if idx >= self.frags.len() { return false; } - if self.frags[idx].is_none() { - self.frags[idx] = Some(payload); - self.received += 1; - } - self.received >= self.total - } - - fn assemble(self) -> Option> { - let mut out = Vec::new(); - for frag in self.frags { - out.extend_from_slice(&frag?); - } - Some(out) - } -} - -struct ClientState { - /// msg_id → reassembly buffer - reassembly: HashMap, - /// Channel to push pending responses into; DNS handler polls this per-query - #[allow(dead_code)] - resp_tx: mpsc::Sender, - last_seen: std::time::Instant, -} - -pub(crate) async fn start_dns_transport_server( - config: DnsTransportConfig, - udp_tx: mpsc::Sender<(Bytes, SocketAddr)>, - tcp_map: Arc>>>, - ui_event_tx: mpsc::UnboundedSender, -) { - let listen_addr = if config.listen.contains(':') { - config.listen.clone() - } else { - format!("0.0.0.0:{}", config.listen) - }; - - let socket = match UdpSocket::bind(&listen_addr).await { - Ok(s) => Arc::new(s), - Err(e) => { - tracing::error!("DNS Transport failed to bind to {}: {}", listen_addr, e); - let _ = ui_event_tx.send(UiEvent::Log(format!("DNS Transport failed to bind: {}", e))); - return; - } - }; - - tracing::info!("DNS Transport listening on {}", listen_addr); - let _ = ui_event_tx.send(UiEvent::Log(format!("DNS Transport listening on {}", listen_addr))); - - // Per-client state: ClientId → ClientState - // Access is serialised by a single Mutex so fragments from the same client - // are always reassembled atomically. - let clients: Arc>> = - Arc::new(tokio::sync::Mutex::new(HashMap::new())); - - // Cleanup task: evict stale client state - { - let clients_gc = clients.clone(); - tokio::spawn(async move { - loop { - tokio::time::sleep(Duration::from_secs(15)).await; - let mut map = clients_gc.lock().await; - map.retain(|_, v| v.last_seen.elapsed() < CLIENT_EXPIRY); - } - }); - } - - let base_domain = config.domain.clone(); - let mut buf = vec![0u8; 65535]; - - loop { - let (size, peer) = match socket.recv_from(&mut buf).await { - Ok(v) => v, - Err(e) => { - tracing::warn!("DNS Transport recv error: {}", e); - continue; - } - }; - - let packet_bytes = buf[..size].to_vec(); - let udp_tx = udp_tx.clone(); - let tcp_map = tcp_map.clone(); - let socket = socket.clone(); - let clients = clients.clone(); - let base_domain = base_domain.clone(); - - tokio::spawn(async move { - handle_dns_query( - packet_bytes, peer, - udp_tx, tcp_map, socket, clients, base_domain, - ).await; - }); - } -} - -async fn handle_dns_query( - packet_bytes: Vec, - peer: SocketAddr, - udp_tx: mpsc::Sender<(Bytes, SocketAddr)>, - tcp_map: Arc>>>, - socket: Arc, - clients: Arc>>, - base_domain: String, -) { - let dns_req = match DnsPacket::decode(&packet_bytes) { - Some(p) => p, - None => { - tracing::debug!("DNS: failed to decode packet from {}", peer); - return; - } - }; - - if dns_req.questions.is_empty() { return; } - let query = &dns_req.questions[0]; - - if query.qtype != DnsRecordType::TXT && query.qtype != DnsRecordType::NULL { - let resp = build_dns_response(&dns_req, &query.name, query.qtype.clone(), vec![]); - let _ = socket.send_to(&resp, peer).await; - return; - } - if !query.name.ends_with(&base_domain) { - let mut resp = DnsPacket::new_response(dns_req.id, &query.name, query.qtype.clone(), vec![]); - resp.flags = 0x8183; // NXDOMAIN - let _ = socket.send_to(&resp.encode(), peer).await; - return; - } - - // Strip base domain and labels separator to get base32 subdomain - let subdomain = { - let name_lower = query.name.to_lowercase(); - let suffix = format!(".{}", base_domain.to_lowercase()); - let suffix_bare = base_domain.to_lowercase(); - let stripped = if name_lower.ends_with(&suffix) { - &query.name[..name_lower.len() - suffix.len()] - } else if name_lower == suffix_bare { - "" - } else { - return; - }; - // Remove dots (label separators) to get contiguous base32 - stripped.replace('.', "") - }; - - if subdomain.is_empty() { - // Pure poll — no payload - let resp = build_dns_response(&dns_req, &query.name, query.qtype.clone(), vec![]); - let _ = socket.send_to(&resp, peer).await; - return; - } - - // Base32-decode - let raw = match base32_decode(&subdomain) { - Some(b) => b, - None => { - tracing::debug!("DNS: base32 decode failed from {}", peer); - return; - } - }; - - if raw.len() < HEADER_LEN { - tracing::debug!("DNS: frame too short ({} bytes) from {}", raw.len(), peer); - return; - } - - // Parse header - let client_id = ClientId(raw[..CLIENT_ID_LEN].try_into().unwrap()); - let msg_id = u16::from_be_bytes([raw[8], raw[9]]); - let total_frags = raw[10]; - let frag_idx = raw[11]; - let payload = raw[HEADER_LEN..].to_vec(); - - let fake_peer = client_id_to_fake_addr(&client_id); - - tracing::trace!("DNS: client={} msg={} frag={}/{} payload={}B", - hex::encode(&client_id.0), msg_id, frag_idx + 1, total_frags, payload.len()); - - // ── Reassembly ──────────────────────────────────────────────────────────── - let complete_packet: Option> = { - let mut map = clients.lock().await; - let state = map.entry(client_id).or_insert_with(|| { - let (resp_tx, _) = mpsc::channel(64); // placeholder, will be replaced below - ClientState { - reassembly: HashMap::new(), - resp_tx, - last_seen: std::time::Instant::now(), - } - }); - state.last_seen = std::time::Instant::now(); - - if total_frags == 0 { - // Empty poll — no data - None - } else if total_frags == 1 && payload.is_empty() { - // Poll with empty payload - None - } else { - let asm = state.reassembly - .entry(msg_id) - .or_insert_with(|| ReassemblyState::new(total_frags)); - - if asm.insert(frag_idx, payload) { - // All fragments received — assemble and remove - let complete = state.reassembly.remove(&msg_id) - .and_then(|s| s.assemble()); - complete - } else { - None - } - } - }; - - // ── Create per-query response channel ──────────────────────────────────── - // We use the stable fake_peer as the routing key in tcp_map. - // For each query we create a fresh one-shot channel. - let (resp_tx, mut resp_rx) = mpsc::channel::(MAX_RESPONSE_PACKETS); - tcp_map.write().await.insert(fake_peer, resp_tx.clone()); - - // ── Forward complete OSTP packet to dispatcher ──────────────────────────── - if let Some(ostp_pkt) = complete_packet { - tracing::debug!("DNS: forwarding {}B OSTP packet from client={} to dispatcher", - ostp_pkt.len(), hex::encode(&client_id.0)); - let _ = udp_tx.send((Bytes::from(ostp_pkt), fake_peer)).await; - } - - // ── Wait for OSTP response(s) ───────────────────────────────────────────── - let mut responses: Vec = Vec::new(); - let deadline = tokio::time::sleep(MAX_RESPONSE_DELAY); - tokio::pin!(deadline); - - loop { - tokio::select! { - _ = &mut deadline => break, - resp = resp_rx.recv() => { - match resp { - Some(r) => { - responses.push(r); - if responses.len() >= MAX_RESPONSE_PACKETS { break; } - } - None => break, - } - } - } - } - - // Only remove if it's still our channel - { - let mut map = tcp_map.write().await; - if let Some(existing_tx) = map.get(&fake_peer) { - if existing_tx.same_channel(&resp_tx) { - map.remove(&fake_peer); - } - } - } - - // ── Build DNS TXT response ──────────────────────────────────────────────── - // Bundle all response packets as length-prefixed data in TXT rdata: - // [len_hi][len_lo][data...]... - let mut rdata: Vec = Vec::new(); - for r in &responses { - let len = r.len() as u16; - rdata.push((len >> 8) as u8); - rdata.push((len & 0xFF) as u8); - rdata.extend_from_slice(r); - } - - tracing::trace!("DNS: responding to {} with {} OSTP packets ({} bytes rdata)", - peer, responses.len(), rdata.len()); - - let resp = build_dns_response(&dns_req, &query.name, query.qtype.clone(), rdata); - let _ = socket.send_to(&resp, peer).await; -} - -/// Build a DNS response packet with the given TXT rdata. -fn build_dns_response( - req: &DnsPacket, - name: &str, - rtype: DnsRecordType, - rdata: Vec, -) -> Vec { - let resp = DnsPacket::new_response(req.id, name, rtype, rdata); - resp.encode() -} - -fn client_id_to_fake_addr(client_id: &ClientId) -> SocketAddr { - let mut ip_bytes = [10, 255, 0, 0]; - ip_bytes[2] = client_id.0[0]; - ip_bytes[3] = client_id.0[1]; - let port = u16::from_be_bytes([client_id.0[2], client_id.0[3]]); - let port = if port == 0 { 1 } else { port }; - SocketAddr::from((ip_bytes, port)) -} +// Left empty by request diff --git a/ostp-server/src/transport/mod.rs b/ostp-server/src/transport/mod.rs index 1f5b018..6433eb6 100644 --- a/ostp-server/src/transport/mod.rs +++ b/ostp-server/src/transport/mod.rs @@ -1,2 +1 @@ pub mod uot; -pub mod dns; diff --git a/ostp/src/main.rs b/ostp/src/main.rs index 4f3ed87..2c78720 100644 --- a/ostp/src/main.rs +++ b/ostp/src/main.rs @@ -1464,8 +1464,15 @@ async fn run_app() -> Result<()> { if let Some(ref mode_str) = args.init { let is_server = mode_str == "server"; let key = generate_secure_key("hex"); - let dns_pub = generate_secure_key("base64"); - let dns_priv = generate_secure_key("base64"); + + let (dns_priv, dns_pub) = if is_server { + ostp_core::dnstt::generate_keypair().unwrap_or_else(|e| { + tracing::warn!("Failed to generate dnstt keys: {}. Using placeholders.", e); + ("YOUR_PRIVKEY".to_string(), "YOUR_PUBKEY".to_string()) + }) + } else { + ("".to_string(), "".to_string()) + }; let content = if is_server { format!(r#"{{ // OSTP Server Configuration