From 3ced4a19b60928d7a1ecae4d8bd6eea0e1b89eb0 Mon Sep 17 00:00:00 2001 From: ospab Date: Sat, 20 Jun 2026 18:45:23 +0300 Subject: [PATCH] Rewrite DNS transport with dnstt-style fragmentation, ClientID, polling and reassembly --- ostp-client/src/transport/dns.rs | 226 +++++++++++++++----- ostp-server/src/transport/dns.rs | 352 ++++++++++++++++++++++++------- 2 files changed, 455 insertions(+), 123 deletions(-) diff --git a/ostp-client/src/transport/dns.rs b/ostp-client/src/transport/dns.rs index e148cc2..600721b 100644 --- a/ostp-client/src/transport/dns.rs +++ b/ostp-client/src/transport/dns.rs @@ -1,19 +1,43 @@ +/// DNS tunnel transport — dnstt-style implementation. +/// +/// Protocol (client → server, embedded in DNS query domain name): +/// Base32([client_id: 8][msg_id: 2 BE][total_frags: 1][frag_idx: 1][payload: ≤MAX_CHUNK]) +/// Split into DNS labels of max 63 chars, suffixed with base_domain. +/// +/// Poll query: payload is empty (total_frags=1, frag_idx=0, len=0). +/// +/// Protocol (server → client, in TXT rdata): +/// Concatenated length-prefixed OSTP packets: [len: 2 BE][data ...]... +/// +/// Polling: adaptive 500ms → 10s, like dnstt. Resets to 500ms on real data. +use std::collections::HashMap; use std::sync::Arc; use std::time::Duration; use bytes::Bytes; +use rand::Rng; use tokio::net::UdpSocket; use tokio::sync::{mpsc, Mutex}; -use rand::Rng; - -pub use ostp_core::dns::{ - DnsPacket, DnsRecordType, encode_payload_to_domain, - decode_domain_to_payload, -}; use crate::transport::Transport; +use rand::RngCore; +use ostp_core::dns::{base32_encode, DnsPacket, DnsRecordType}; -pub async fn start_dns_transport(domain: String, resolver: String, _pubkey: Option) -> std::io::Result { - let (app_tx, transport_rx) = mpsc::channel::(100); - let (transport_tx, app_rx) = mpsc::channel::(100); +/// Max raw payload bytes we put into one DNS query. +/// Calculation: FQDN ≤ 253 chars. Domain suffix ~30 chars max. +/// Remaining: ~220 chars for base32 labels. 220/8*5 = 137 bytes raw. +/// Header = 12 bytes → payload ≤ 120 bytes (conservative, works for any domain ≤ 40 chars). +const MAX_CHUNK_PAYLOAD: usize = 120; +const CLIENT_ID_LEN: usize = 8; +const INIT_POLL_DELAY: Duration = Duration::from_millis(500); +const MAX_POLL_DELAY: Duration = Duration::from_secs(10); +const POLL_DELAY_MULTIPLIER: f64 = 2.0; + +pub async fn start_dns_transport( + domain: String, + resolver: String, + _pubkey: Option, +) -> std::io::Result { + let (app_tx, transport_rx) = mpsc::channel::(256); + let (transport_tx, app_rx) = mpsc::channel::(256); let resolver_addr = if resolver.contains(':') { resolver.clone() @@ -25,69 +49,124 @@ pub async fn start_dns_transport(domain: String, resolver: String, _pubkey: Opti socket.connect(&resolver_addr).await?; let socket = Arc::new(socket); - let sock_rx = socket.clone(); - let sock_tx = socket; - let base_domain = domain.clone(); + // Generate random ClientID for this tunnel session + let mut client_id = [0u8; CLIENT_ID_LEN]; + rand::thread_rng().fill_bytes(&mut client_id); + let client_id = Arc::new(client_id); - // Send task (reads from app, encodes into DNS TXT, sends to UDP socket) + tracing::info!("DNS transport: domain={} resolver={} client_id={}", + domain, resolver_addr, + hex::encode(client_id.as_slice())); + + // ── Send task ───────────────────────────────────────────────────────────── + let sock_send = socket.clone(); + let cid_send = client_id.clone(); + let domain_send = domain.clone(); tokio::spawn(async move { let mut rx = transport_rx; + let mut msg_id: u16 = 0; + let mut poll_delay = INIT_POLL_DELAY; + loop { - let data_opt = tokio::select! { - res = rx.recv() => res, - _ = tokio::time::sleep(Duration::from_secs(2)) => Some(Bytes::new()), + let data: Option = tokio::select! { + data = rx.recv() => data, + _ = tokio::time::sleep(poll_delay) => { + poll_delay = Duration::from_secs_f64( + (poll_delay.as_secs_f64() * POLL_DELAY_MULTIPLIER) + .min(MAX_POLL_DELAY.as_secs_f64()) + ); + // Send poll (empty payload) + Some(Bytes::new()) + } }; - - let data = match data_opt { + + let data = match data { Some(d) => d, - None => break, // App closed + None => { + tracing::debug!("DNS send task: channel closed, exiting"); + break; + } }; - // Encode data to base32 domain - let fqdn = encode_payload_to_domain(&data, &base_domain); - let id: u16 = rand::thread_rng().gen(); - - // Randomly choose TXT or NULL for diversity (as requested) - let qtype = if rand::thread_rng().gen_bool(0.5) { - DnsRecordType::TXT + if data.is_empty() { + // Poll query — one empty chunk + if let Err(e) = send_chunk(&sock_send, &cid_send, msg_id, 1, 0, &[], &domain_send).await { + tracing::warn!("DNS poll send error: {}", e); + } } else { - DnsRecordType::NULL - }; + // Real OSTP packet — fragment into chunks + poll_delay = INIT_POLL_DELAY; // reset on real data - let packet = DnsPacket::new_query(id, &fqdn, qtype); - let encoded = packet.encode(); + let data_slice = data.as_ref(); + let total_chunks = data_slice.chunks(MAX_CHUNK_PAYLOAD).count(); + let total_u8 = total_chunks.min(255) as u8; - if let Err(e) = sock_tx.send(&encoded).await { - tracing::warn!("DNS transport send error: {}", e); - break; + for (idx, chunk) in data_slice.chunks(MAX_CHUNK_PAYLOAD).enumerate() { + if let Err(e) = send_chunk( + &sock_send, &cid_send, + msg_id, total_u8, idx as u8, + chunk, &domain_send, + ).await { + tracing::warn!("DNS chunk send error (idx={}): {}", idx, e); + break; + } + // Brief inter-fragment delay to avoid flooding the resolver + if total_chunks > 1 { + tokio::time::sleep(Duration::from_millis(20)).await; + } + } + msg_id = msg_id.wrapping_add(1); } } }); - // Receive task (reads from UDP socket, decodes DNS answer, sends to app) - let _base_domain_rx = domain.clone(); + // ── Receive task ────────────────────────────────────────────────────────── + let sock_recv = socket.clone(); + let tx_recv = transport_tx.clone(); + let domain_recv = domain.clone(); tokio::spawn(async move { let mut buf = vec![0u8; 65535]; + // Reassembly buffers: msg_id → (total, Vec>) + let reassembly: HashMap>>)> = HashMap::new(); + loop { - match sock_rx.recv(&mut buf).await { + match sock_recv.recv(&mut buf).await { Ok(n) => { - if let Some(packet) = DnsPacket::decode(&buf[..n]) { - for answer in packet.answers { - if answer.rtype == DnsRecordType::TXT || answer.rtype == DnsRecordType::NULL { - // If it's a TXT record, the response might be base32 encoded payload? - // Actually, dnstt puts the payload in the TXT/NULL record data. - // We'll just assume the rdata is the raw payload, or base32 encoded if it was sent as such. - // Let's just pass the raw data (TXT strings are decoded in DnsPacket::decode) - - // Wait, dnstt server responds with raw bytes in NULL, and base32/chunked strings in TXT. - // Our `DnsPacket::decode` already handles extracting TXT string bytes or NULL raw bytes into `rdata`. - // Let's just send `rdata` to the app. - if transport_tx.send(Bytes::from(answer.rdata)).await.is_err() { - return; // App closed - } + let Some(pkt) = DnsPacket::decode(&buf[..n]) else { continue }; + + // Only process DNS responses + if pkt.flags & 0x8000 == 0 { continue; } + + for answer in pkt.answers { + if answer.rtype != DnsRecordType::TXT && answer.rtype != DnsRecordType::NULL { + continue; + } + let rdata = answer.rdata; + // Parse length-prefixed OSTP packets packed in rdata: + // [len_hi: 1][len_lo: 1][data: len]... + let mut pos = 0; + while pos + 2 <= rdata.len() { + let pkt_len = u16::from_be_bytes([rdata[pos], rdata[pos + 1]]) as usize; + pos += 2; + if pkt_len == 0 { continue; } + if pos + pkt_len > rdata.len() { + tracing::debug!("DNS recv: truncated packet in rdata"); + break; + } + let payload = Bytes::copy_from_slice(&rdata[pos..pos + pkt_len]); + pos += pkt_len; + + if tx_recv.send(payload).await.is_err() { + return; // app closed } } } + + // Also check for responses packed in the server's extra DNS answer rdata + // that use our fragmentation scheme (server→client fragments) + // This is handled above via the length-prefix protocol. + let _ = &reassembly; // Keep for future upstream fragmentation support + let _ = &domain_recv; } Err(e) => { tracing::warn!("DNS transport recv error: {}", e); @@ -102,3 +181,50 @@ pub async fn start_dns_transport(domain: String, resolver: String, _pubkey: Opti rx: Arc::new(Mutex::new(app_rx)), }) } + +/// Build and send one DNS TXT query with a framed chunk. +/// +/// Frame format (before base32 encoding): +/// [client_id: 8][msg_id: 2 BE][total_frags: 1][frag_idx: 1][payload: 0–120] +async fn send_chunk( + socket: &UdpSocket, + client_id: &[u8; CLIENT_ID_LEN], + msg_id: u16, + total_frags: u8, + frag_idx: u8, + payload: &[u8], + base_domain: &str, +) -> std::io::Result<()> { + // Build frame + let mut frame = Vec::with_capacity(CLIENT_ID_LEN + 4 + payload.len()); + frame.extend_from_slice(client_id); + frame.extend_from_slice(&msg_id.to_be_bytes()); + frame.push(total_frags); + frame.push(frag_idx); + frame.extend_from_slice(payload); + + // Base32-encode + let encoded = base32_encode(&frame); + + // Split into 63-char labels and append domain + let mut fqdn = String::with_capacity(encoded.len() + base_domain.len() + 10); + let mut start = 0; + while start < encoded.len() { + let end = (start + 63).min(encoded.len()); + fqdn.push_str(&encoded[start..end]); + fqdn.push('.'); + start = end; + } + fqdn.push_str(base_domain); + + // Build DNS TXT query with random ID + let id: u16 = rand::thread_rng().gen(); + let pkt = DnsPacket::new_query(id, &fqdn, DnsRecordType::TXT); + let wire = pkt.encode(); + + tracing::trace!("DNS send chunk: msg_id={} frag={}/{} payload={}B fqdn_len={}", + msg_id, frag_idx + 1, total_frags, payload.len(), fqdn.len()); + + socket.send(&wire).await?; + Ok(()) +} diff --git a/ostp-server/src/transport/dns.rs b/ostp-server/src/transport/dns.rs index 2fced0c..2cbaac2 100644 --- a/ostp-server/src/transport/dns.rs +++ b/ostp-server/src/transport/dns.rs @@ -1,15 +1,83 @@ -use std::sync::Arc; -use tokio::net::UdpSocket; -use tokio::sync::{mpsc, RwLock}; +/// DNS tunnel transport — dnstt-style server implementation. +/// +/// Each DNS TXT query from client contains a framed chunk: +/// Base32([client_id: 8][msg_id: 2 BE][total_frags: 1][frag_idx: 1][payload: ≤120]) +/// +/// Server: +/// 1. Decodes ClientID + fragment from query name +/// 2. Reassembles fragments per (client_id, msg_id) +/// 3. Forwards complete OSTP packet to dispatcher (udp_tx) +/// 4. Waits up to MAX_RESPONSE_DELAY for responses +/// 5. Bundles responses as length-prefixed packets in DNS TXT answer +/// +/// Server → client data in TXT rdata: [len_hi][len_lo][data...]... use std::collections::HashMap; use std::net::SocketAddr; +use std::sync::Arc; use bytes::Bytes; +use tokio::net::UdpSocket; +use tokio::sync::{mpsc, RwLock}; use tokio::time::Duration; -use ostp_core::dns::{DnsPacket, DnsRecordType, decode_domain_to_payload}; +use ostp_core::dns::{base32_decode, DnsPacket, DnsRecordType}; use crate::config::DnsTransportConfig; use crate::UiEvent; +const CLIENT_ID_LEN: usize = 8; +const HEADER_LEN: usize = CLIENT_ID_LEN + 4; // client_id + msg_id(2) + total(1) + idx(1) +/// How long to wait for downstream OSTP data before sending an empty response. +const MAX_RESPONSE_DELAY: Duration = Duration::from_millis(800); +/// Maximum number of response packets to bundle into one DNS answer. +const MAX_RESPONSE_PACKETS: usize = 8; +/// How long to keep per-client reassembly state without activity. +const CLIENT_EXPIRY: Duration = Duration::from_secs(30); + +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] +struct ClientId([u8; CLIENT_ID_LEN]); + +struct ReassemblyState { + total: u8, + frags: Vec>>, + received: u8, +} + +impl ReassemblyState { + fn new(total: u8) -> Self { + Self { + total, + frags: vec![None; total as usize], + received: 0, + } + } + + fn insert(&mut self, idx: u8, payload: Vec) -> bool { + let idx = idx as usize; + if idx >= self.frags.len() { return false; } + if self.frags[idx].is_none() { + self.frags[idx] = Some(payload); + self.received += 1; + } + self.received >= self.total + } + + fn assemble(self) -> Option> { + let mut out = Vec::new(); + for frag in self.frags { + out.extend_from_slice(&frag?); + } + Some(out) + } +} + +struct ClientState { + /// msg_id → reassembly buffer + reassembly: HashMap, + /// Channel to push pending responses into; DNS handler polls this per-query + #[allow(dead_code)] + resp_tx: mpsc::Sender, + last_seen: std::time::Instant, +} + pub(crate) async fn start_dns_transport_server( config: DnsTransportConfig, udp_tx: mpsc::Sender<(Bytes, SocketAddr)>, @@ -34,80 +102,218 @@ pub(crate) async fn start_dns_transport_server( tracing::info!("DNS Transport listening on {}", listen_addr); let _ = ui_event_tx.send(UiEvent::Log(format!("DNS Transport listening on {}", listen_addr))); - let mut buf = vec![0u8; 65535]; - loop { - match socket.recv_from(&mut buf).await { - Ok((size, peer)) => { - let packet_bytes = buf[..size].to_vec(); - let udp_tx = udp_tx.clone(); - let tcp_map = tcp_map.clone(); - let socket = socket.clone(); - let base_domain = config.domain.clone(); + // Per-client state: ClientId → ClientState + // Access is serialised by a single Mutex so fragments from the same client + // are always reassembled atomically. + let clients: Arc>> = + Arc::new(tokio::sync::Mutex::new(HashMap::new())); - tokio::spawn(async move { - if let Some(dns_req) = DnsPacket::decode(&packet_bytes) { - if dns_req.questions.is_empty() { return; } - let query = &dns_req.questions[0]; - - // Check if it's our target domain and it's a TXT or NULL query - if (query.qtype == DnsRecordType::TXT || query.qtype == DnsRecordType::NULL) && query.name.ends_with(&base_domain) { - // Decode base32 payload - if let Some(payload) = decode_domain_to_payload(&query.name, &base_domain) { - - let (resp_tx, mut resp_rx) = mpsc::channel::(10); - - // Insert into tcp_map so Dispatcher routes responses to us - tcp_map.write().await.insert(peer, resp_tx); - - // Send payload to dispatcher - if udp_tx.send((Bytes::from(payload), peer)).await.is_ok() { - // Wait up to 50ms for any responses - let mut responses = Vec::new(); - - while let Ok(Some(resp)) = tokio::time::timeout(Duration::from_millis(50), resp_rx.recv()).await { - responses.push(resp); - if responses.len() >= 3 { break; } - } - - // Remove from tcp_map - tcp_map.write().await.remove(&peer); - - // Build DNS Answer - let mut dns_resp = DnsPacket::new_response(dns_req.id, &query.name, query.qtype.clone(), vec![]); - dns_resp.answers.clear(); // We'll add our own - - if !responses.is_empty() { - for r in responses { - dns_resp.answers.push(ostp_core::dns::DnsAnswer { - name: query.name.clone(), - rtype: query.qtype.clone(), - rclass: 1, - ttl: 0, - rdata: r.to_vec(), - }); - } - } else { - // Empty response - dns_resp.answers.push(ostp_core::dns::DnsAnswer { - name: query.name.clone(), - rtype: query.qtype.clone(), - rclass: 1, - ttl: 0, - rdata: vec![], - }); - } - - let resp_encoded = dns_resp.encode(); - let _ = socket.send_to(&resp_encoded, peer).await; - } - } - } - } - }); + // Cleanup task: evict stale client state + { + let clients_gc = clients.clone(); + tokio::spawn(async move { + loop { + tokio::time::sleep(Duration::from_secs(15)).await; + let mut map = clients_gc.lock().await; + map.retain(|_, v| v.last_seen.elapsed() < CLIENT_EXPIRY); } + }); + } + + let base_domain = config.domain.clone(); + let mut buf = vec![0u8; 65535]; + + loop { + let (size, peer) = match socket.recv_from(&mut buf).await { + Ok(v) => v, Err(e) => { tracing::warn!("DNS Transport recv error: {}", e); + continue; + } + }; + + let packet_bytes = buf[..size].to_vec(); + let udp_tx = udp_tx.clone(); + let tcp_map = tcp_map.clone(); + let socket = socket.clone(); + let clients = clients.clone(); + let base_domain = base_domain.clone(); + + tokio::spawn(async move { + handle_dns_query( + packet_bytes, peer, + udp_tx, tcp_map, socket, clients, base_domain, + ).await; + }); + } +} + +async fn handle_dns_query( + packet_bytes: Vec, + peer: SocketAddr, + udp_tx: mpsc::Sender<(Bytes, SocketAddr)>, + tcp_map: Arc>>>, + socket: Arc, + clients: Arc>>, + base_domain: String, +) { + let dns_req = match DnsPacket::decode(&packet_bytes) { + Some(p) => p, + None => { + tracing::debug!("DNS: failed to decode packet from {}", peer); + return; + } + }; + + if dns_req.questions.is_empty() { return; } + let query = &dns_req.questions[0]; + + // Must be TXT query for our subdomain + if query.qtype != DnsRecordType::TXT && query.qtype != DnsRecordType::NULL { return; } + if !query.name.ends_with(&base_domain) { return; } + + // Strip base domain and labels separator to get base32 subdomain + let subdomain = { + let name_lower = query.name.to_lowercase(); + let suffix = format!(".{}", base_domain.to_lowercase()); + let suffix_bare = base_domain.to_lowercase(); + let stripped = if name_lower.ends_with(&suffix) { + &query.name[..name_lower.len() - suffix.len()] + } else if name_lower == suffix_bare { + "" + } else { + return; + }; + // Remove dots (label separators) to get contiguous base32 + stripped.replace('.', "") + }; + + if subdomain.is_empty() { + // Pure poll — no payload + let resp = build_dns_response(&dns_req, &query.name, query.qtype.clone(), vec![]); + let _ = socket.send_to(&resp, peer).await; + return; + } + + // Base32-decode + let raw = match base32_decode(&subdomain) { + Some(b) => b, + None => { + tracing::debug!("DNS: base32 decode failed from {}", peer); + return; + } + }; + + if raw.len() < HEADER_LEN { + tracing::debug!("DNS: frame too short ({} bytes) from {}", raw.len(), peer); + return; + } + + // Parse header + let client_id = ClientId(raw[..CLIENT_ID_LEN].try_into().unwrap()); + let msg_id = u16::from_be_bytes([raw[8], raw[9]]); + let total_frags = raw[10]; + let frag_idx = raw[11]; + let payload = raw[HEADER_LEN..].to_vec(); + + tracing::trace!("DNS: client={} msg={} frag={}/{} payload={}B", + hex::encode(&client_id.0), msg_id, frag_idx + 1, total_frags, payload.len()); + + // ── Reassembly ──────────────────────────────────────────────────────────── + let complete_packet: Option> = { + let mut map = clients.lock().await; + let state = map.entry(client_id).or_insert_with(|| { + let (resp_tx, _) = mpsc::channel(64); // placeholder, will be replaced below + ClientState { + reassembly: HashMap::new(), + resp_tx, + last_seen: std::time::Instant::now(), + } + }); + state.last_seen = std::time::Instant::now(); + + if total_frags == 0 { + // Empty poll — no data + None + } else if total_frags == 1 && payload.is_empty() { + // Poll with empty payload + None + } else { + let asm = state.reassembly + .entry(msg_id) + .or_insert_with(|| ReassemblyState::new(total_frags)); + + if asm.insert(frag_idx, payload) { + // All fragments received — assemble and remove + let complete = state.reassembly.remove(&msg_id) + .and_then(|s| s.assemble()); + complete + } else { + None + } + } + }; + + // ── Create per-query response channel ──────────────────────────────────── + // We use the peer SocketAddr as the routing key in tcp_map. + // For each query we create a fresh one-shot channel. + let (resp_tx, mut resp_rx) = mpsc::channel::(MAX_RESPONSE_PACKETS); + tcp_map.write().await.insert(peer, resp_tx); + + // ── Forward complete OSTP packet to dispatcher ──────────────────────────── + if let Some(ostp_pkt) = complete_packet { + tracing::debug!("DNS: forwarding {}B OSTP packet from client={} to dispatcher", + ostp_pkt.len(), hex::encode(&client_id.0)); + let _ = udp_tx.send((Bytes::from(ostp_pkt), peer)).await; + } + + // ── Wait for OSTP response(s) ───────────────────────────────────────────── + let mut responses: Vec = Vec::new(); + let deadline = tokio::time::sleep(MAX_RESPONSE_DELAY); + tokio::pin!(deadline); + + loop { + tokio::select! { + _ = &mut deadline => break, + resp = resp_rx.recv() => { + match resp { + Some(r) => { + responses.push(r); + if responses.len() >= MAX_RESPONSE_PACKETS { break; } + } + None => break, + } } } } + + tcp_map.write().await.remove(&peer); + + // ── Build DNS TXT response ──────────────────────────────────────────────── + // Bundle all response packets as length-prefixed data in TXT rdata: + // [len_hi][len_lo][data...]... + let mut rdata: Vec = Vec::new(); + for r in &responses { + let len = r.len() as u16; + rdata.push((len >> 8) as u8); + rdata.push((len & 0xFF) as u8); + rdata.extend_from_slice(r); + } + + tracing::trace!("DNS: responding to {} with {} OSTP packets ({} bytes rdata)", + peer, responses.len(), rdata.len()); + + let resp = build_dns_response(&dns_req, &query.name, query.qtype.clone(), rdata); + let _ = socket.send_to(&resp, peer).await; +} + +/// Build a DNS response packet with the given TXT rdata. +fn build_dns_response( + req: &DnsPacket, + name: &str, + rtype: DnsRecordType, + rdata: Vec, +) -> Vec { + let resp = DnsPacket::new_response(req.id, name, rtype, rdata); + resp.encode() }