Rewrite DNS transport with dnstt-style fragmentation, ClientID, polling and reassembly

This commit is contained in:
ospab 2026-06-20 18:45:23 +03:00
parent 6987ac5344
commit 3ced4a19b6
2 changed files with 455 additions and 123 deletions

View File

@ -1,19 +1,43 @@
/// DNS tunnel transport — dnstt-style implementation.
///
/// Protocol (client → server, embedded in DNS query domain name):
/// Base32([client_id: 8][msg_id: 2 BE][total_frags: 1][frag_idx: 1][payload: ≤MAX_CHUNK])
/// Split into DNS labels of max 63 chars, suffixed with base_domain.
///
/// Poll query: payload is empty (total_frags=1, frag_idx=0, len=0).
///
/// Protocol (server → client, in TXT rdata):
/// Concatenated length-prefixed OSTP packets: [len: 2 BE][data ...]...
///
/// Polling: adaptive 500ms → 10s, like dnstt. Resets to 500ms on real data.
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use bytes::Bytes;
use rand::Rng;
use tokio::net::UdpSocket;
use tokio::sync::{mpsc, Mutex};
use rand::Rng;
pub use ostp_core::dns::{
DnsPacket, DnsRecordType, encode_payload_to_domain,
decode_domain_to_payload,
};
use crate::transport::Transport;
use rand::RngCore;
use ostp_core::dns::{base32_encode, DnsPacket, DnsRecordType};
pub async fn start_dns_transport(domain: String, resolver: String, _pubkey: Option<String>) -> std::io::Result<Transport> {
let (app_tx, transport_rx) = mpsc::channel::<Bytes>(100);
let (transport_tx, app_rx) = mpsc::channel::<Bytes>(100);
/// Max raw payload bytes we put into one DNS query.
/// Calculation: FQDN ≤ 253 chars. Domain suffix ~30 chars max.
/// Remaining: ~220 chars for base32 labels. 220/8*5 = 137 bytes raw.
/// Header = 12 bytes → payload ≤ 120 bytes (conservative, works for any domain ≤ 40 chars).
const MAX_CHUNK_PAYLOAD: usize = 120;
const CLIENT_ID_LEN: usize = 8;
const INIT_POLL_DELAY: Duration = Duration::from_millis(500);
const MAX_POLL_DELAY: Duration = Duration::from_secs(10);
const POLL_DELAY_MULTIPLIER: f64 = 2.0;
pub async fn start_dns_transport(
domain: String,
resolver: String,
_pubkey: Option<String>,
) -> std::io::Result<Transport> {
let (app_tx, transport_rx) = mpsc::channel::<Bytes>(256);
let (transport_tx, app_rx) = mpsc::channel::<Bytes>(256);
let resolver_addr = if resolver.contains(':') {
resolver.clone()
@ -25,69 +49,124 @@ pub async fn start_dns_transport(domain: String, resolver: String, _pubkey: Opti
socket.connect(&resolver_addr).await?;
let socket = Arc::new(socket);
let sock_rx = socket.clone();
let sock_tx = socket;
let base_domain = domain.clone();
// Generate random ClientID for this tunnel session
let mut client_id = [0u8; CLIENT_ID_LEN];
rand::thread_rng().fill_bytes(&mut client_id);
let client_id = Arc::new(client_id);
// Send task (reads from app, encodes into DNS TXT, sends to UDP socket)
tracing::info!("DNS transport: domain={} resolver={} client_id={}",
domain, resolver_addr,
hex::encode(client_id.as_slice()));
// ── Send task ─────────────────────────────────────────────────────────────
let sock_send = socket.clone();
let cid_send = client_id.clone();
let domain_send = domain.clone();
tokio::spawn(async move {
let mut rx = transport_rx;
let mut msg_id: u16 = 0;
let mut poll_delay = INIT_POLL_DELAY;
loop {
let data_opt = tokio::select! {
res = rx.recv() => res,
_ = tokio::time::sleep(Duration::from_secs(2)) => Some(Bytes::new()),
let data: Option<Bytes> = tokio::select! {
data = rx.recv() => data,
_ = tokio::time::sleep(poll_delay) => {
poll_delay = Duration::from_secs_f64(
(poll_delay.as_secs_f64() * POLL_DELAY_MULTIPLIER)
.min(MAX_POLL_DELAY.as_secs_f64())
);
// Send poll (empty payload)
Some(Bytes::new())
}
};
let data = match data_opt {
let data = match data {
Some(d) => d,
None => break, // App closed
None => {
tracing::debug!("DNS send task: channel closed, exiting");
break;
}
};
// Encode data to base32 domain
let fqdn = encode_payload_to_domain(&data, &base_domain);
let id: u16 = rand::thread_rng().gen();
// Randomly choose TXT or NULL for diversity (as requested)
let qtype = if rand::thread_rng().gen_bool(0.5) {
DnsRecordType::TXT
if data.is_empty() {
// Poll query — one empty chunk
if let Err(e) = send_chunk(&sock_send, &cid_send, msg_id, 1, 0, &[], &domain_send).await {
tracing::warn!("DNS poll send error: {}", e);
}
} else {
DnsRecordType::NULL
};
// Real OSTP packet — fragment into chunks
poll_delay = INIT_POLL_DELAY; // reset on real data
let packet = DnsPacket::new_query(id, &fqdn, qtype);
let encoded = packet.encode();
let data_slice = data.as_ref();
let total_chunks = data_slice.chunks(MAX_CHUNK_PAYLOAD).count();
let total_u8 = total_chunks.min(255) as u8;
if let Err(e) = sock_tx.send(&encoded).await {
tracing::warn!("DNS transport send error: {}", e);
break;
for (idx, chunk) in data_slice.chunks(MAX_CHUNK_PAYLOAD).enumerate() {
if let Err(e) = send_chunk(
&sock_send, &cid_send,
msg_id, total_u8, idx as u8,
chunk, &domain_send,
).await {
tracing::warn!("DNS chunk send error (idx={}): {}", idx, e);
break;
}
// Brief inter-fragment delay to avoid flooding the resolver
if total_chunks > 1 {
tokio::time::sleep(Duration::from_millis(20)).await;
}
}
msg_id = msg_id.wrapping_add(1);
}
}
});
// Receive task (reads from UDP socket, decodes DNS answer, sends to app)
let _base_domain_rx = domain.clone();
// ── Receive task ──────────────────────────────────────────────────────────
let sock_recv = socket.clone();
let tx_recv = transport_tx.clone();
let domain_recv = domain.clone();
tokio::spawn(async move {
let mut buf = vec![0u8; 65535];
// Reassembly buffers: msg_id → (total, Vec<Option<chunk>>)
let reassembly: HashMap<u16, (u8, Vec<Option<Vec<u8>>>)> = HashMap::new();
loop {
match sock_rx.recv(&mut buf).await {
match sock_recv.recv(&mut buf).await {
Ok(n) => {
if let Some(packet) = DnsPacket::decode(&buf[..n]) {
for answer in packet.answers {
if answer.rtype == DnsRecordType::TXT || answer.rtype == DnsRecordType::NULL {
// If it's a TXT record, the response might be base32 encoded payload?
// Actually, dnstt puts the payload in the TXT/NULL record data.
// We'll just assume the rdata is the raw payload, or base32 encoded if it was sent as such.
// Let's just pass the raw data (TXT strings are decoded in DnsPacket::decode)
// Wait, dnstt server responds with raw bytes in NULL, and base32/chunked strings in TXT.
// Our `DnsPacket::decode` already handles extracting TXT string bytes or NULL raw bytes into `rdata`.
// Let's just send `rdata` to the app.
if transport_tx.send(Bytes::from(answer.rdata)).await.is_err() {
return; // App closed
}
let Some(pkt) = DnsPacket::decode(&buf[..n]) else { continue };
// Only process DNS responses
if pkt.flags & 0x8000 == 0 { continue; }
for answer in pkt.answers {
if answer.rtype != DnsRecordType::TXT && answer.rtype != DnsRecordType::NULL {
continue;
}
let rdata = answer.rdata;
// Parse length-prefixed OSTP packets packed in rdata:
// [len_hi: 1][len_lo: 1][data: len]...
let mut pos = 0;
while pos + 2 <= rdata.len() {
let pkt_len = u16::from_be_bytes([rdata[pos], rdata[pos + 1]]) as usize;
pos += 2;
if pkt_len == 0 { continue; }
if pos + pkt_len > rdata.len() {
tracing::debug!("DNS recv: truncated packet in rdata");
break;
}
let payload = Bytes::copy_from_slice(&rdata[pos..pos + pkt_len]);
pos += pkt_len;
if tx_recv.send(payload).await.is_err() {
return; // app closed
}
}
}
// Also check for responses packed in the server's extra DNS answer rdata
// that use our fragmentation scheme (server→client fragments)
// This is handled above via the length-prefix protocol.
let _ = &reassembly; // Keep for future upstream fragmentation support
let _ = &domain_recv;
}
Err(e) => {
tracing::warn!("DNS transport recv error: {}", e);
@ -102,3 +181,50 @@ pub async fn start_dns_transport(domain: String, resolver: String, _pubkey: Opti
rx: Arc::new(Mutex::new(app_rx)),
})
}
/// Build and send one DNS TXT query with a framed chunk.
///
/// Frame format (before base32 encoding):
/// [client_id: 8][msg_id: 2 BE][total_frags: 1][frag_idx: 1][payload: 0120]
async fn send_chunk(
socket: &UdpSocket,
client_id: &[u8; CLIENT_ID_LEN],
msg_id: u16,
total_frags: u8,
frag_idx: u8,
payload: &[u8],
base_domain: &str,
) -> std::io::Result<()> {
// Build frame
let mut frame = Vec::with_capacity(CLIENT_ID_LEN + 4 + payload.len());
frame.extend_from_slice(client_id);
frame.extend_from_slice(&msg_id.to_be_bytes());
frame.push(total_frags);
frame.push(frag_idx);
frame.extend_from_slice(payload);
// Base32-encode
let encoded = base32_encode(&frame);
// Split into 63-char labels and append domain
let mut fqdn = String::with_capacity(encoded.len() + base_domain.len() + 10);
let mut start = 0;
while start < encoded.len() {
let end = (start + 63).min(encoded.len());
fqdn.push_str(&encoded[start..end]);
fqdn.push('.');
start = end;
}
fqdn.push_str(base_domain);
// Build DNS TXT query with random ID
let id: u16 = rand::thread_rng().gen();
let pkt = DnsPacket::new_query(id, &fqdn, DnsRecordType::TXT);
let wire = pkt.encode();
tracing::trace!("DNS send chunk: msg_id={} frag={}/{} payload={}B fqdn_len={}",
msg_id, frag_idx + 1, total_frags, payload.len(), fqdn.len());
socket.send(&wire).await?;
Ok(())
}

View File

@ -1,15 +1,83 @@
use std::sync::Arc;
use tokio::net::UdpSocket;
use tokio::sync::{mpsc, RwLock};
/// DNS tunnel transport — dnstt-style server implementation.
///
/// Each DNS TXT query from client contains a framed chunk:
/// Base32([client_id: 8][msg_id: 2 BE][total_frags: 1][frag_idx: 1][payload: ≤120])
///
/// Server:
/// 1. Decodes ClientID + fragment from query name
/// 2. Reassembles fragments per (client_id, msg_id)
/// 3. Forwards complete OSTP packet to dispatcher (udp_tx)
/// 4. Waits up to MAX_RESPONSE_DELAY for responses
/// 5. Bundles responses as length-prefixed packets in DNS TXT answer
///
/// Server → client data in TXT rdata: [len_hi][len_lo][data...]...
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use bytes::Bytes;
use tokio::net::UdpSocket;
use tokio::sync::{mpsc, RwLock};
use tokio::time::Duration;
use ostp_core::dns::{DnsPacket, DnsRecordType, decode_domain_to_payload};
use ostp_core::dns::{base32_decode, DnsPacket, DnsRecordType};
use crate::config::DnsTransportConfig;
use crate::UiEvent;
const CLIENT_ID_LEN: usize = 8;
const HEADER_LEN: usize = CLIENT_ID_LEN + 4; // client_id + msg_id(2) + total(1) + idx(1)
/// How long to wait for downstream OSTP data before sending an empty response.
const MAX_RESPONSE_DELAY: Duration = Duration::from_millis(800);
/// Maximum number of response packets to bundle into one DNS answer.
const MAX_RESPONSE_PACKETS: usize = 8;
/// How long to keep per-client reassembly state without activity.
const CLIENT_EXPIRY: Duration = Duration::from_secs(30);
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
struct ClientId([u8; CLIENT_ID_LEN]);
struct ReassemblyState {
total: u8,
frags: Vec<Option<Vec<u8>>>,
received: u8,
}
impl ReassemblyState {
fn new(total: u8) -> Self {
Self {
total,
frags: vec![None; total as usize],
received: 0,
}
}
fn insert(&mut self, idx: u8, payload: Vec<u8>) -> bool {
let idx = idx as usize;
if idx >= self.frags.len() { return false; }
if self.frags[idx].is_none() {
self.frags[idx] = Some(payload);
self.received += 1;
}
self.received >= self.total
}
fn assemble(self) -> Option<Vec<u8>> {
let mut out = Vec::new();
for frag in self.frags {
out.extend_from_slice(&frag?);
}
Some(out)
}
}
struct ClientState {
/// msg_id → reassembly buffer
reassembly: HashMap<u16, ReassemblyState>,
/// Channel to push pending responses into; DNS handler polls this per-query
#[allow(dead_code)]
resp_tx: mpsc::Sender<Bytes>,
last_seen: std::time::Instant,
}
pub(crate) async fn start_dns_transport_server(
config: DnsTransportConfig,
udp_tx: mpsc::Sender<(Bytes, SocketAddr)>,
@ -34,80 +102,218 @@ pub(crate) async fn start_dns_transport_server(
tracing::info!("DNS Transport listening on {}", listen_addr);
let _ = ui_event_tx.send(UiEvent::Log(format!("DNS Transport listening on {}", listen_addr)));
let mut buf = vec![0u8; 65535];
loop {
match socket.recv_from(&mut buf).await {
Ok((size, peer)) => {
let packet_bytes = buf[..size].to_vec();
let udp_tx = udp_tx.clone();
let tcp_map = tcp_map.clone();
let socket = socket.clone();
let base_domain = config.domain.clone();
// Per-client state: ClientId → ClientState
// Access is serialised by a single Mutex so fragments from the same client
// are always reassembled atomically.
let clients: Arc<tokio::sync::Mutex<HashMap<ClientId, ClientState>>> =
Arc::new(tokio::sync::Mutex::new(HashMap::new()));
tokio::spawn(async move {
if let Some(dns_req) = DnsPacket::decode(&packet_bytes) {
if dns_req.questions.is_empty() { return; }
let query = &dns_req.questions[0];
// Check if it's our target domain and it's a TXT or NULL query
if (query.qtype == DnsRecordType::TXT || query.qtype == DnsRecordType::NULL) && query.name.ends_with(&base_domain) {
// Decode base32 payload
if let Some(payload) = decode_domain_to_payload(&query.name, &base_domain) {
let (resp_tx, mut resp_rx) = mpsc::channel::<Bytes>(10);
// Insert into tcp_map so Dispatcher routes responses to us
tcp_map.write().await.insert(peer, resp_tx);
// Send payload to dispatcher
if udp_tx.send((Bytes::from(payload), peer)).await.is_ok() {
// Wait up to 50ms for any responses
let mut responses = Vec::new();
while let Ok(Some(resp)) = tokio::time::timeout(Duration::from_millis(50), resp_rx.recv()).await {
responses.push(resp);
if responses.len() >= 3 { break; }
}
// Remove from tcp_map
tcp_map.write().await.remove(&peer);
// Build DNS Answer
let mut dns_resp = DnsPacket::new_response(dns_req.id, &query.name, query.qtype.clone(), vec![]);
dns_resp.answers.clear(); // We'll add our own
if !responses.is_empty() {
for r in responses {
dns_resp.answers.push(ostp_core::dns::DnsAnswer {
name: query.name.clone(),
rtype: query.qtype.clone(),
rclass: 1,
ttl: 0,
rdata: r.to_vec(),
});
}
} else {
// Empty response
dns_resp.answers.push(ostp_core::dns::DnsAnswer {
name: query.name.clone(),
rtype: query.qtype.clone(),
rclass: 1,
ttl: 0,
rdata: vec![],
});
}
let resp_encoded = dns_resp.encode();
let _ = socket.send_to(&resp_encoded, peer).await;
}
}
}
}
});
// Cleanup task: evict stale client state
{
let clients_gc = clients.clone();
tokio::spawn(async move {
loop {
tokio::time::sleep(Duration::from_secs(15)).await;
let mut map = clients_gc.lock().await;
map.retain(|_, v| v.last_seen.elapsed() < CLIENT_EXPIRY);
}
});
}
let base_domain = config.domain.clone();
let mut buf = vec![0u8; 65535];
loop {
let (size, peer) = match socket.recv_from(&mut buf).await {
Ok(v) => v,
Err(e) => {
tracing::warn!("DNS Transport recv error: {}", e);
continue;
}
};
let packet_bytes = buf[..size].to_vec();
let udp_tx = udp_tx.clone();
let tcp_map = tcp_map.clone();
let socket = socket.clone();
let clients = clients.clone();
let base_domain = base_domain.clone();
tokio::spawn(async move {
handle_dns_query(
packet_bytes, peer,
udp_tx, tcp_map, socket, clients, base_domain,
).await;
});
}
}
async fn handle_dns_query(
packet_bytes: Vec<u8>,
peer: SocketAddr,
udp_tx: mpsc::Sender<(Bytes, SocketAddr)>,
tcp_map: Arc<RwLock<HashMap<SocketAddr, mpsc::Sender<Bytes>>>>,
socket: Arc<UdpSocket>,
clients: Arc<tokio::sync::Mutex<HashMap<ClientId, ClientState>>>,
base_domain: String,
) {
let dns_req = match DnsPacket::decode(&packet_bytes) {
Some(p) => p,
None => {
tracing::debug!("DNS: failed to decode packet from {}", peer);
return;
}
};
if dns_req.questions.is_empty() { return; }
let query = &dns_req.questions[0];
// Must be TXT query for our subdomain
if query.qtype != DnsRecordType::TXT && query.qtype != DnsRecordType::NULL { return; }
if !query.name.ends_with(&base_domain) { return; }
// Strip base domain and labels separator to get base32 subdomain
let subdomain = {
let name_lower = query.name.to_lowercase();
let suffix = format!(".{}", base_domain.to_lowercase());
let suffix_bare = base_domain.to_lowercase();
let stripped = if name_lower.ends_with(&suffix) {
&query.name[..name_lower.len() - suffix.len()]
} else if name_lower == suffix_bare {
""
} else {
return;
};
// Remove dots (label separators) to get contiguous base32
stripped.replace('.', "")
};
if subdomain.is_empty() {
// Pure poll — no payload
let resp = build_dns_response(&dns_req, &query.name, query.qtype.clone(), vec![]);
let _ = socket.send_to(&resp, peer).await;
return;
}
// Base32-decode
let raw = match base32_decode(&subdomain) {
Some(b) => b,
None => {
tracing::debug!("DNS: base32 decode failed from {}", peer);
return;
}
};
if raw.len() < HEADER_LEN {
tracing::debug!("DNS: frame too short ({} bytes) from {}", raw.len(), peer);
return;
}
// Parse header
let client_id = ClientId(raw[..CLIENT_ID_LEN].try_into().unwrap());
let msg_id = u16::from_be_bytes([raw[8], raw[9]]);
let total_frags = raw[10];
let frag_idx = raw[11];
let payload = raw[HEADER_LEN..].to_vec();
tracing::trace!("DNS: client={} msg={} frag={}/{} payload={}B",
hex::encode(&client_id.0), msg_id, frag_idx + 1, total_frags, payload.len());
// ── Reassembly ────────────────────────────────────────────────────────────
let complete_packet: Option<Vec<u8>> = {
let mut map = clients.lock().await;
let state = map.entry(client_id).or_insert_with(|| {
let (resp_tx, _) = mpsc::channel(64); // placeholder, will be replaced below
ClientState {
reassembly: HashMap::new(),
resp_tx,
last_seen: std::time::Instant::now(),
}
});
state.last_seen = std::time::Instant::now();
if total_frags == 0 {
// Empty poll — no data
None
} else if total_frags == 1 && payload.is_empty() {
// Poll with empty payload
None
} else {
let asm = state.reassembly
.entry(msg_id)
.or_insert_with(|| ReassemblyState::new(total_frags));
if asm.insert(frag_idx, payload) {
// All fragments received — assemble and remove
let complete = state.reassembly.remove(&msg_id)
.and_then(|s| s.assemble());
complete
} else {
None
}
}
};
// ── Create per-query response channel ────────────────────────────────────
// We use the peer SocketAddr as the routing key in tcp_map.
// For each query we create a fresh one-shot channel.
let (resp_tx, mut resp_rx) = mpsc::channel::<Bytes>(MAX_RESPONSE_PACKETS);
tcp_map.write().await.insert(peer, resp_tx);
// ── Forward complete OSTP packet to dispatcher ────────────────────────────
if let Some(ostp_pkt) = complete_packet {
tracing::debug!("DNS: forwarding {}B OSTP packet from client={} to dispatcher",
ostp_pkt.len(), hex::encode(&client_id.0));
let _ = udp_tx.send((Bytes::from(ostp_pkt), peer)).await;
}
// ── Wait for OSTP response(s) ─────────────────────────────────────────────
let mut responses: Vec<Bytes> = Vec::new();
let deadline = tokio::time::sleep(MAX_RESPONSE_DELAY);
tokio::pin!(deadline);
loop {
tokio::select! {
_ = &mut deadline => break,
resp = resp_rx.recv() => {
match resp {
Some(r) => {
responses.push(r);
if responses.len() >= MAX_RESPONSE_PACKETS { break; }
}
None => break,
}
}
}
}
tcp_map.write().await.remove(&peer);
// ── Build DNS TXT response ────────────────────────────────────────────────
// Bundle all response packets as length-prefixed data in TXT rdata:
// [len_hi][len_lo][data...]...
let mut rdata: Vec<u8> = Vec::new();
for r in &responses {
let len = r.len() as u16;
rdata.push((len >> 8) as u8);
rdata.push((len & 0xFF) as u8);
rdata.extend_from_slice(r);
}
tracing::trace!("DNS: responding to {} with {} OSTP packets ({} bytes rdata)",
peer, responses.len(), rdata.len());
let resp = build_dns_response(&dns_req, &query.name, query.qtype.clone(), rdata);
let _ = socket.send_to(&resp, peer).await;
}
/// Build a DNS response packet with the given TXT rdata.
fn build_dns_response(
req: &DnsPacket,
name: &str,
rtype: DnsRecordType,
rdata: Vec<u8>,
) -> Vec<u8> {
let resp = DnsPacket::new_response(req.id, name, rtype, rdata);
resp.encode()
}