mirror of https://github.com/ospab/ostp.git
Rewrite DNS transport with dnstt-style fragmentation, ClientID, polling and reassembly
This commit is contained in:
parent
6987ac5344
commit
3ced4a19b6
|
|
@ -1,19 +1,43 @@
|
|||
/// DNS tunnel transport — dnstt-style implementation.
|
||||
///
|
||||
/// Protocol (client → server, embedded in DNS query domain name):
|
||||
/// Base32([client_id: 8][msg_id: 2 BE][total_frags: 1][frag_idx: 1][payload: ≤MAX_CHUNK])
|
||||
/// Split into DNS labels of max 63 chars, suffixed with base_domain.
|
||||
///
|
||||
/// Poll query: payload is empty (total_frags=1, frag_idx=0, len=0).
|
||||
///
|
||||
/// Protocol (server → client, in TXT rdata):
|
||||
/// Concatenated length-prefixed OSTP packets: [len: 2 BE][data ...]...
|
||||
///
|
||||
/// Polling: adaptive 500ms → 10s, like dnstt. Resets to 500ms on real data.
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use bytes::Bytes;
|
||||
use rand::Rng;
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
use rand::Rng;
|
||||
|
||||
pub use ostp_core::dns::{
|
||||
DnsPacket, DnsRecordType, encode_payload_to_domain,
|
||||
decode_domain_to_payload,
|
||||
};
|
||||
use crate::transport::Transport;
|
||||
use rand::RngCore;
|
||||
use ostp_core::dns::{base32_encode, DnsPacket, DnsRecordType};
|
||||
|
||||
pub async fn start_dns_transport(domain: String, resolver: String, _pubkey: Option<String>) -> std::io::Result<Transport> {
|
||||
let (app_tx, transport_rx) = mpsc::channel::<Bytes>(100);
|
||||
let (transport_tx, app_rx) = mpsc::channel::<Bytes>(100);
|
||||
/// Max raw payload bytes we put into one DNS query.
|
||||
/// Calculation: FQDN ≤ 253 chars. Domain suffix ~30 chars max.
|
||||
/// Remaining: ~220 chars for base32 labels. 220/8*5 = 137 bytes raw.
|
||||
/// Header = 12 bytes → payload ≤ 120 bytes (conservative, works for any domain ≤ 40 chars).
|
||||
const MAX_CHUNK_PAYLOAD: usize = 120;
|
||||
const CLIENT_ID_LEN: usize = 8;
|
||||
const INIT_POLL_DELAY: Duration = Duration::from_millis(500);
|
||||
const MAX_POLL_DELAY: Duration = Duration::from_secs(10);
|
||||
const POLL_DELAY_MULTIPLIER: f64 = 2.0;
|
||||
|
||||
pub async fn start_dns_transport(
|
||||
domain: String,
|
||||
resolver: String,
|
||||
_pubkey: Option<String>,
|
||||
) -> std::io::Result<Transport> {
|
||||
let (app_tx, transport_rx) = mpsc::channel::<Bytes>(256);
|
||||
let (transport_tx, app_rx) = mpsc::channel::<Bytes>(256);
|
||||
|
||||
let resolver_addr = if resolver.contains(':') {
|
||||
resolver.clone()
|
||||
|
|
@ -25,69 +49,124 @@ pub async fn start_dns_transport(domain: String, resolver: String, _pubkey: Opti
|
|||
socket.connect(&resolver_addr).await?;
|
||||
let socket = Arc::new(socket);
|
||||
|
||||
let sock_rx = socket.clone();
|
||||
let sock_tx = socket;
|
||||
let base_domain = domain.clone();
|
||||
// Generate random ClientID for this tunnel session
|
||||
let mut client_id = [0u8; CLIENT_ID_LEN];
|
||||
rand::thread_rng().fill_bytes(&mut client_id);
|
||||
let client_id = Arc::new(client_id);
|
||||
|
||||
// Send task (reads from app, encodes into DNS TXT, sends to UDP socket)
|
||||
tracing::info!("DNS transport: domain={} resolver={} client_id={}",
|
||||
domain, resolver_addr,
|
||||
hex::encode(client_id.as_slice()));
|
||||
|
||||
// ── Send task ─────────────────────────────────────────────────────────────
|
||||
let sock_send = socket.clone();
|
||||
let cid_send = client_id.clone();
|
||||
let domain_send = domain.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut rx = transport_rx;
|
||||
let mut msg_id: u16 = 0;
|
||||
let mut poll_delay = INIT_POLL_DELAY;
|
||||
|
||||
loop {
|
||||
let data_opt = tokio::select! {
|
||||
res = rx.recv() => res,
|
||||
_ = tokio::time::sleep(Duration::from_secs(2)) => Some(Bytes::new()),
|
||||
let data: Option<Bytes> = tokio::select! {
|
||||
data = rx.recv() => data,
|
||||
_ = tokio::time::sleep(poll_delay) => {
|
||||
poll_delay = Duration::from_secs_f64(
|
||||
(poll_delay.as_secs_f64() * POLL_DELAY_MULTIPLIER)
|
||||
.min(MAX_POLL_DELAY.as_secs_f64())
|
||||
);
|
||||
// Send poll (empty payload)
|
||||
Some(Bytes::new())
|
||||
}
|
||||
};
|
||||
|
||||
let data = match data_opt {
|
||||
|
||||
let data = match data {
|
||||
Some(d) => d,
|
||||
None => break, // App closed
|
||||
None => {
|
||||
tracing::debug!("DNS send task: channel closed, exiting");
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
// Encode data to base32 domain
|
||||
let fqdn = encode_payload_to_domain(&data, &base_domain);
|
||||
let id: u16 = rand::thread_rng().gen();
|
||||
|
||||
// Randomly choose TXT or NULL for diversity (as requested)
|
||||
let qtype = if rand::thread_rng().gen_bool(0.5) {
|
||||
DnsRecordType::TXT
|
||||
if data.is_empty() {
|
||||
// Poll query — one empty chunk
|
||||
if let Err(e) = send_chunk(&sock_send, &cid_send, msg_id, 1, 0, &[], &domain_send).await {
|
||||
tracing::warn!("DNS poll send error: {}", e);
|
||||
}
|
||||
} else {
|
||||
DnsRecordType::NULL
|
||||
};
|
||||
// Real OSTP packet — fragment into chunks
|
||||
poll_delay = INIT_POLL_DELAY; // reset on real data
|
||||
|
||||
let packet = DnsPacket::new_query(id, &fqdn, qtype);
|
||||
let encoded = packet.encode();
|
||||
let data_slice = data.as_ref();
|
||||
let total_chunks = data_slice.chunks(MAX_CHUNK_PAYLOAD).count();
|
||||
let total_u8 = total_chunks.min(255) as u8;
|
||||
|
||||
if let Err(e) = sock_tx.send(&encoded).await {
|
||||
tracing::warn!("DNS transport send error: {}", e);
|
||||
break;
|
||||
for (idx, chunk) in data_slice.chunks(MAX_CHUNK_PAYLOAD).enumerate() {
|
||||
if let Err(e) = send_chunk(
|
||||
&sock_send, &cid_send,
|
||||
msg_id, total_u8, idx as u8,
|
||||
chunk, &domain_send,
|
||||
).await {
|
||||
tracing::warn!("DNS chunk send error (idx={}): {}", idx, e);
|
||||
break;
|
||||
}
|
||||
// Brief inter-fragment delay to avoid flooding the resolver
|
||||
if total_chunks > 1 {
|
||||
tokio::time::sleep(Duration::from_millis(20)).await;
|
||||
}
|
||||
}
|
||||
msg_id = msg_id.wrapping_add(1);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Receive task (reads from UDP socket, decodes DNS answer, sends to app)
|
||||
let _base_domain_rx = domain.clone();
|
||||
// ── Receive task ──────────────────────────────────────────────────────────
|
||||
let sock_recv = socket.clone();
|
||||
let tx_recv = transport_tx.clone();
|
||||
let domain_recv = domain.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 65535];
|
||||
// Reassembly buffers: msg_id → (total, Vec<Option<chunk>>)
|
||||
let reassembly: HashMap<u16, (u8, Vec<Option<Vec<u8>>>)> = HashMap::new();
|
||||
|
||||
loop {
|
||||
match sock_rx.recv(&mut buf).await {
|
||||
match sock_recv.recv(&mut buf).await {
|
||||
Ok(n) => {
|
||||
if let Some(packet) = DnsPacket::decode(&buf[..n]) {
|
||||
for answer in packet.answers {
|
||||
if answer.rtype == DnsRecordType::TXT || answer.rtype == DnsRecordType::NULL {
|
||||
// If it's a TXT record, the response might be base32 encoded payload?
|
||||
// Actually, dnstt puts the payload in the TXT/NULL record data.
|
||||
// We'll just assume the rdata is the raw payload, or base32 encoded if it was sent as such.
|
||||
// Let's just pass the raw data (TXT strings are decoded in DnsPacket::decode)
|
||||
|
||||
// Wait, dnstt server responds with raw bytes in NULL, and base32/chunked strings in TXT.
|
||||
// Our `DnsPacket::decode` already handles extracting TXT string bytes or NULL raw bytes into `rdata`.
|
||||
// Let's just send `rdata` to the app.
|
||||
if transport_tx.send(Bytes::from(answer.rdata)).await.is_err() {
|
||||
return; // App closed
|
||||
}
|
||||
let Some(pkt) = DnsPacket::decode(&buf[..n]) else { continue };
|
||||
|
||||
// Only process DNS responses
|
||||
if pkt.flags & 0x8000 == 0 { continue; }
|
||||
|
||||
for answer in pkt.answers {
|
||||
if answer.rtype != DnsRecordType::TXT && answer.rtype != DnsRecordType::NULL {
|
||||
continue;
|
||||
}
|
||||
let rdata = answer.rdata;
|
||||
// Parse length-prefixed OSTP packets packed in rdata:
|
||||
// [len_hi: 1][len_lo: 1][data: len]...
|
||||
let mut pos = 0;
|
||||
while pos + 2 <= rdata.len() {
|
||||
let pkt_len = u16::from_be_bytes([rdata[pos], rdata[pos + 1]]) as usize;
|
||||
pos += 2;
|
||||
if pkt_len == 0 { continue; }
|
||||
if pos + pkt_len > rdata.len() {
|
||||
tracing::debug!("DNS recv: truncated packet in rdata");
|
||||
break;
|
||||
}
|
||||
let payload = Bytes::copy_from_slice(&rdata[pos..pos + pkt_len]);
|
||||
pos += pkt_len;
|
||||
|
||||
if tx_recv.send(payload).await.is_err() {
|
||||
return; // app closed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Also check for responses packed in the server's extra DNS answer rdata
|
||||
// that use our fragmentation scheme (server→client fragments)
|
||||
// This is handled above via the length-prefix protocol.
|
||||
let _ = &reassembly; // Keep for future upstream fragmentation support
|
||||
let _ = &domain_recv;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("DNS transport recv error: {}", e);
|
||||
|
|
@ -102,3 +181,50 @@ pub async fn start_dns_transport(domain: String, resolver: String, _pubkey: Opti
|
|||
rx: Arc::new(Mutex::new(app_rx)),
|
||||
})
|
||||
}
|
||||
|
||||
/// Build and send one DNS TXT query with a framed chunk.
|
||||
///
|
||||
/// Frame format (before base32 encoding):
|
||||
/// [client_id: 8][msg_id: 2 BE][total_frags: 1][frag_idx: 1][payload: 0–120]
|
||||
async fn send_chunk(
|
||||
socket: &UdpSocket,
|
||||
client_id: &[u8; CLIENT_ID_LEN],
|
||||
msg_id: u16,
|
||||
total_frags: u8,
|
||||
frag_idx: u8,
|
||||
payload: &[u8],
|
||||
base_domain: &str,
|
||||
) -> std::io::Result<()> {
|
||||
// Build frame
|
||||
let mut frame = Vec::with_capacity(CLIENT_ID_LEN + 4 + payload.len());
|
||||
frame.extend_from_slice(client_id);
|
||||
frame.extend_from_slice(&msg_id.to_be_bytes());
|
||||
frame.push(total_frags);
|
||||
frame.push(frag_idx);
|
||||
frame.extend_from_slice(payload);
|
||||
|
||||
// Base32-encode
|
||||
let encoded = base32_encode(&frame);
|
||||
|
||||
// Split into 63-char labels and append domain
|
||||
let mut fqdn = String::with_capacity(encoded.len() + base_domain.len() + 10);
|
||||
let mut start = 0;
|
||||
while start < encoded.len() {
|
||||
let end = (start + 63).min(encoded.len());
|
||||
fqdn.push_str(&encoded[start..end]);
|
||||
fqdn.push('.');
|
||||
start = end;
|
||||
}
|
||||
fqdn.push_str(base_domain);
|
||||
|
||||
// Build DNS TXT query with random ID
|
||||
let id: u16 = rand::thread_rng().gen();
|
||||
let pkt = DnsPacket::new_query(id, &fqdn, DnsRecordType::TXT);
|
||||
let wire = pkt.encode();
|
||||
|
||||
tracing::trace!("DNS send chunk: msg_id={} frag={}/{} payload={}B fqdn_len={}",
|
||||
msg_id, frag_idx + 1, total_frags, payload.len(), fqdn.len());
|
||||
|
||||
socket.send(&wire).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,15 +1,83 @@
|
|||
use std::sync::Arc;
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::sync::{mpsc, RwLock};
|
||||
/// DNS tunnel transport — dnstt-style server implementation.
|
||||
///
|
||||
/// Each DNS TXT query from client contains a framed chunk:
|
||||
/// Base32([client_id: 8][msg_id: 2 BE][total_frags: 1][frag_idx: 1][payload: ≤120])
|
||||
///
|
||||
/// Server:
|
||||
/// 1. Decodes ClientID + fragment from query name
|
||||
/// 2. Reassembles fragments per (client_id, msg_id)
|
||||
/// 3. Forwards complete OSTP packet to dispatcher (udp_tx)
|
||||
/// 4. Waits up to MAX_RESPONSE_DELAY for responses
|
||||
/// 5. Bundles responses as length-prefixed packets in DNS TXT answer
|
||||
///
|
||||
/// Server → client data in TXT rdata: [len_hi][len_lo][data...]...
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use bytes::Bytes;
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::sync::{mpsc, RwLock};
|
||||
use tokio::time::Duration;
|
||||
|
||||
use ostp_core::dns::{DnsPacket, DnsRecordType, decode_domain_to_payload};
|
||||
use ostp_core::dns::{base32_decode, DnsPacket, DnsRecordType};
|
||||
use crate::config::DnsTransportConfig;
|
||||
use crate::UiEvent;
|
||||
|
||||
const CLIENT_ID_LEN: usize = 8;
|
||||
const HEADER_LEN: usize = CLIENT_ID_LEN + 4; // client_id + msg_id(2) + total(1) + idx(1)
|
||||
/// How long to wait for downstream OSTP data before sending an empty response.
|
||||
const MAX_RESPONSE_DELAY: Duration = Duration::from_millis(800);
|
||||
/// Maximum number of response packets to bundle into one DNS answer.
|
||||
const MAX_RESPONSE_PACKETS: usize = 8;
|
||||
/// How long to keep per-client reassembly state without activity.
|
||||
const CLIENT_EXPIRY: Duration = Duration::from_secs(30);
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
|
||||
struct ClientId([u8; CLIENT_ID_LEN]);
|
||||
|
||||
struct ReassemblyState {
|
||||
total: u8,
|
||||
frags: Vec<Option<Vec<u8>>>,
|
||||
received: u8,
|
||||
}
|
||||
|
||||
impl ReassemblyState {
|
||||
fn new(total: u8) -> Self {
|
||||
Self {
|
||||
total,
|
||||
frags: vec![None; total as usize],
|
||||
received: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn insert(&mut self, idx: u8, payload: Vec<u8>) -> bool {
|
||||
let idx = idx as usize;
|
||||
if idx >= self.frags.len() { return false; }
|
||||
if self.frags[idx].is_none() {
|
||||
self.frags[idx] = Some(payload);
|
||||
self.received += 1;
|
||||
}
|
||||
self.received >= self.total
|
||||
}
|
||||
|
||||
fn assemble(self) -> Option<Vec<u8>> {
|
||||
let mut out = Vec::new();
|
||||
for frag in self.frags {
|
||||
out.extend_from_slice(&frag?);
|
||||
}
|
||||
Some(out)
|
||||
}
|
||||
}
|
||||
|
||||
struct ClientState {
|
||||
/// msg_id → reassembly buffer
|
||||
reassembly: HashMap<u16, ReassemblyState>,
|
||||
/// Channel to push pending responses into; DNS handler polls this per-query
|
||||
#[allow(dead_code)]
|
||||
resp_tx: mpsc::Sender<Bytes>,
|
||||
last_seen: std::time::Instant,
|
||||
}
|
||||
|
||||
pub(crate) async fn start_dns_transport_server(
|
||||
config: DnsTransportConfig,
|
||||
udp_tx: mpsc::Sender<(Bytes, SocketAddr)>,
|
||||
|
|
@ -34,80 +102,218 @@ pub(crate) async fn start_dns_transport_server(
|
|||
tracing::info!("DNS Transport listening on {}", listen_addr);
|
||||
let _ = ui_event_tx.send(UiEvent::Log(format!("DNS Transport listening on {}", listen_addr)));
|
||||
|
||||
let mut buf = vec![0u8; 65535];
|
||||
loop {
|
||||
match socket.recv_from(&mut buf).await {
|
||||
Ok((size, peer)) => {
|
||||
let packet_bytes = buf[..size].to_vec();
|
||||
let udp_tx = udp_tx.clone();
|
||||
let tcp_map = tcp_map.clone();
|
||||
let socket = socket.clone();
|
||||
let base_domain = config.domain.clone();
|
||||
// Per-client state: ClientId → ClientState
|
||||
// Access is serialised by a single Mutex so fragments from the same client
|
||||
// are always reassembled atomically.
|
||||
let clients: Arc<tokio::sync::Mutex<HashMap<ClientId, ClientState>>> =
|
||||
Arc::new(tokio::sync::Mutex::new(HashMap::new()));
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Some(dns_req) = DnsPacket::decode(&packet_bytes) {
|
||||
if dns_req.questions.is_empty() { return; }
|
||||
let query = &dns_req.questions[0];
|
||||
|
||||
// Check if it's our target domain and it's a TXT or NULL query
|
||||
if (query.qtype == DnsRecordType::TXT || query.qtype == DnsRecordType::NULL) && query.name.ends_with(&base_domain) {
|
||||
// Decode base32 payload
|
||||
if let Some(payload) = decode_domain_to_payload(&query.name, &base_domain) {
|
||||
|
||||
let (resp_tx, mut resp_rx) = mpsc::channel::<Bytes>(10);
|
||||
|
||||
// Insert into tcp_map so Dispatcher routes responses to us
|
||||
tcp_map.write().await.insert(peer, resp_tx);
|
||||
|
||||
// Send payload to dispatcher
|
||||
if udp_tx.send((Bytes::from(payload), peer)).await.is_ok() {
|
||||
// Wait up to 50ms for any responses
|
||||
let mut responses = Vec::new();
|
||||
|
||||
while let Ok(Some(resp)) = tokio::time::timeout(Duration::from_millis(50), resp_rx.recv()).await {
|
||||
responses.push(resp);
|
||||
if responses.len() >= 3 { break; }
|
||||
}
|
||||
|
||||
// Remove from tcp_map
|
||||
tcp_map.write().await.remove(&peer);
|
||||
|
||||
// Build DNS Answer
|
||||
let mut dns_resp = DnsPacket::new_response(dns_req.id, &query.name, query.qtype.clone(), vec![]);
|
||||
dns_resp.answers.clear(); // We'll add our own
|
||||
|
||||
if !responses.is_empty() {
|
||||
for r in responses {
|
||||
dns_resp.answers.push(ostp_core::dns::DnsAnswer {
|
||||
name: query.name.clone(),
|
||||
rtype: query.qtype.clone(),
|
||||
rclass: 1,
|
||||
ttl: 0,
|
||||
rdata: r.to_vec(),
|
||||
});
|
||||
}
|
||||
} else {
|
||||
// Empty response
|
||||
dns_resp.answers.push(ostp_core::dns::DnsAnswer {
|
||||
name: query.name.clone(),
|
||||
rtype: query.qtype.clone(),
|
||||
rclass: 1,
|
||||
ttl: 0,
|
||||
rdata: vec![],
|
||||
});
|
||||
}
|
||||
|
||||
let resp_encoded = dns_resp.encode();
|
||||
let _ = socket.send_to(&resp_encoded, peer).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
// Cleanup task: evict stale client state
|
||||
{
|
||||
let clients_gc = clients.clone();
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(15)).await;
|
||||
let mut map = clients_gc.lock().await;
|
||||
map.retain(|_, v| v.last_seen.elapsed() < CLIENT_EXPIRY);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let base_domain = config.domain.clone();
|
||||
let mut buf = vec![0u8; 65535];
|
||||
|
||||
loop {
|
||||
let (size, peer) = match socket.recv_from(&mut buf).await {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
tracing::warn!("DNS Transport recv error: {}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let packet_bytes = buf[..size].to_vec();
|
||||
let udp_tx = udp_tx.clone();
|
||||
let tcp_map = tcp_map.clone();
|
||||
let socket = socket.clone();
|
||||
let clients = clients.clone();
|
||||
let base_domain = base_domain.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
handle_dns_query(
|
||||
packet_bytes, peer,
|
||||
udp_tx, tcp_map, socket, clients, base_domain,
|
||||
).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_dns_query(
|
||||
packet_bytes: Vec<u8>,
|
||||
peer: SocketAddr,
|
||||
udp_tx: mpsc::Sender<(Bytes, SocketAddr)>,
|
||||
tcp_map: Arc<RwLock<HashMap<SocketAddr, mpsc::Sender<Bytes>>>>,
|
||||
socket: Arc<UdpSocket>,
|
||||
clients: Arc<tokio::sync::Mutex<HashMap<ClientId, ClientState>>>,
|
||||
base_domain: String,
|
||||
) {
|
||||
let dns_req = match DnsPacket::decode(&packet_bytes) {
|
||||
Some(p) => p,
|
||||
None => {
|
||||
tracing::debug!("DNS: failed to decode packet from {}", peer);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if dns_req.questions.is_empty() { return; }
|
||||
let query = &dns_req.questions[0];
|
||||
|
||||
// Must be TXT query for our subdomain
|
||||
if query.qtype != DnsRecordType::TXT && query.qtype != DnsRecordType::NULL { return; }
|
||||
if !query.name.ends_with(&base_domain) { return; }
|
||||
|
||||
// Strip base domain and labels separator to get base32 subdomain
|
||||
let subdomain = {
|
||||
let name_lower = query.name.to_lowercase();
|
||||
let suffix = format!(".{}", base_domain.to_lowercase());
|
||||
let suffix_bare = base_domain.to_lowercase();
|
||||
let stripped = if name_lower.ends_with(&suffix) {
|
||||
&query.name[..name_lower.len() - suffix.len()]
|
||||
} else if name_lower == suffix_bare {
|
||||
""
|
||||
} else {
|
||||
return;
|
||||
};
|
||||
// Remove dots (label separators) to get contiguous base32
|
||||
stripped.replace('.', "")
|
||||
};
|
||||
|
||||
if subdomain.is_empty() {
|
||||
// Pure poll — no payload
|
||||
let resp = build_dns_response(&dns_req, &query.name, query.qtype.clone(), vec![]);
|
||||
let _ = socket.send_to(&resp, peer).await;
|
||||
return;
|
||||
}
|
||||
|
||||
// Base32-decode
|
||||
let raw = match base32_decode(&subdomain) {
|
||||
Some(b) => b,
|
||||
None => {
|
||||
tracing::debug!("DNS: base32 decode failed from {}", peer);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if raw.len() < HEADER_LEN {
|
||||
tracing::debug!("DNS: frame too short ({} bytes) from {}", raw.len(), peer);
|
||||
return;
|
||||
}
|
||||
|
||||
// Parse header
|
||||
let client_id = ClientId(raw[..CLIENT_ID_LEN].try_into().unwrap());
|
||||
let msg_id = u16::from_be_bytes([raw[8], raw[9]]);
|
||||
let total_frags = raw[10];
|
||||
let frag_idx = raw[11];
|
||||
let payload = raw[HEADER_LEN..].to_vec();
|
||||
|
||||
tracing::trace!("DNS: client={} msg={} frag={}/{} payload={}B",
|
||||
hex::encode(&client_id.0), msg_id, frag_idx + 1, total_frags, payload.len());
|
||||
|
||||
// ── Reassembly ────────────────────────────────────────────────────────────
|
||||
let complete_packet: Option<Vec<u8>> = {
|
||||
let mut map = clients.lock().await;
|
||||
let state = map.entry(client_id).or_insert_with(|| {
|
||||
let (resp_tx, _) = mpsc::channel(64); // placeholder, will be replaced below
|
||||
ClientState {
|
||||
reassembly: HashMap::new(),
|
||||
resp_tx,
|
||||
last_seen: std::time::Instant::now(),
|
||||
}
|
||||
});
|
||||
state.last_seen = std::time::Instant::now();
|
||||
|
||||
if total_frags == 0 {
|
||||
// Empty poll — no data
|
||||
None
|
||||
} else if total_frags == 1 && payload.is_empty() {
|
||||
// Poll with empty payload
|
||||
None
|
||||
} else {
|
||||
let asm = state.reassembly
|
||||
.entry(msg_id)
|
||||
.or_insert_with(|| ReassemblyState::new(total_frags));
|
||||
|
||||
if asm.insert(frag_idx, payload) {
|
||||
// All fragments received — assemble and remove
|
||||
let complete = state.reassembly.remove(&msg_id)
|
||||
.and_then(|s| s.assemble());
|
||||
complete
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// ── Create per-query response channel ────────────────────────────────────
|
||||
// We use the peer SocketAddr as the routing key in tcp_map.
|
||||
// For each query we create a fresh one-shot channel.
|
||||
let (resp_tx, mut resp_rx) = mpsc::channel::<Bytes>(MAX_RESPONSE_PACKETS);
|
||||
tcp_map.write().await.insert(peer, resp_tx);
|
||||
|
||||
// ── Forward complete OSTP packet to dispatcher ────────────────────────────
|
||||
if let Some(ostp_pkt) = complete_packet {
|
||||
tracing::debug!("DNS: forwarding {}B OSTP packet from client={} to dispatcher",
|
||||
ostp_pkt.len(), hex::encode(&client_id.0));
|
||||
let _ = udp_tx.send((Bytes::from(ostp_pkt), peer)).await;
|
||||
}
|
||||
|
||||
// ── Wait for OSTP response(s) ─────────────────────────────────────────────
|
||||
let mut responses: Vec<Bytes> = Vec::new();
|
||||
let deadline = tokio::time::sleep(MAX_RESPONSE_DELAY);
|
||||
tokio::pin!(deadline);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = &mut deadline => break,
|
||||
resp = resp_rx.recv() => {
|
||||
match resp {
|
||||
Some(r) => {
|
||||
responses.push(r);
|
||||
if responses.len() >= MAX_RESPONSE_PACKETS { break; }
|
||||
}
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tcp_map.write().await.remove(&peer);
|
||||
|
||||
// ── Build DNS TXT response ────────────────────────────────────────────────
|
||||
// Bundle all response packets as length-prefixed data in TXT rdata:
|
||||
// [len_hi][len_lo][data...]...
|
||||
let mut rdata: Vec<u8> = Vec::new();
|
||||
for r in &responses {
|
||||
let len = r.len() as u16;
|
||||
rdata.push((len >> 8) as u8);
|
||||
rdata.push((len & 0xFF) as u8);
|
||||
rdata.extend_from_slice(r);
|
||||
}
|
||||
|
||||
tracing::trace!("DNS: responding to {} with {} OSTP packets ({} bytes rdata)",
|
||||
peer, responses.len(), rdata.len());
|
||||
|
||||
let resp = build_dns_response(&dns_req, &query.name, query.qtype.clone(), rdata);
|
||||
let _ = socket.send_to(&resp, peer).await;
|
||||
}
|
||||
|
||||
/// Build a DNS response packet with the given TXT rdata.
|
||||
fn build_dns_response(
|
||||
req: &DnsPacket,
|
||||
name: &str,
|
||||
rtype: DnsRecordType,
|
||||
rdata: Vec<u8>,
|
||||
) -> Vec<u8> {
|
||||
let resp = DnsPacket::new_response(req.id, name, rtype, rdata);
|
||||
resp.encode()
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue