diff --git a/ostp-server/src/lib.rs b/ostp-server/src/lib.rs index 13bae98..2c57726 100644 --- a/ostp-server/src/lib.rs +++ b/ostp-server/src/lib.rs @@ -554,6 +554,7 @@ async fn run_server_loop( stream_tx.clone(), connect_tx.clone(), outbound.clone(), + dns_server.clone(), debug, ).await?; } diff --git a/ostp-server/src/relay.rs b/ostp-server/src/relay.rs index d61c3bd..7736603 100644 --- a/ostp-server/src/relay.rs +++ b/ostp-server/src/relay.rs @@ -23,11 +23,43 @@ pub async fn handle_relay_message( stream_tx: mpsc::UnboundedSender<(u32, u16, Vec)>, connect_tx: mpsc::UnboundedSender<(u32, u16, String, Result<(tokio::net::tcp::OwnedWriteHalf, mpsc::Sender<()>), String>)>, outbound_cfg: Option, + dns_server: std::sync::Arc, debug: bool, ) -> Result<()> { match RelayMessage::decode(&payload)? { RelayMessage::Connect(target) => { - let _ = ui_event_tx.send(UiEvent::Log(format!("Relay CONNECT start for [{session_id}:{stream_id}] -> {target}"))); + // Intercept DNS queries directed at the TUN gateway if our internal DNS is enabled + let is_internal_dns = { + target == "10.1.0.1:53" && dns_server.config.read().await.enabled + }; + + if is_internal_dns { + let client_ip = peer_addr.ip(); + let dns_srv = dns_server.clone(); + let stream_tx_dns = stream_tx.clone(); + let (cancel_tx, _) = mpsc::channel::<()>(1); + + let (dns_query_tx, mut dns_query_rx) = mpsc::unbounded_channel::(); + + tokio::spawn(async move { + if let Some(query_bytes) = dns_query_rx.recv().await { + if let Some(resp_bytes) = dns_srv.resolve(&query_bytes, client_ip).await { + let _ = stream_tx_dns.send((session_id, stream_id, resp_bytes)); + } + } + let _ = stream_tx_dns.send((session_id, stream_id, Vec::new())); + }); + + remotes.insert((session_id, stream_id), RemoteState { + data_tx: dns_query_tx, + cancel_tx, + is_dns: true, + }); + + send_relay_to_stream(session_id, stream_id, RelayMessage::ConnectOk, dispatcher, socket, ui_event_tx).await?; + return Ok(()); + } + let target_clone = target.clone(); let connect_tx_clone = connect_tx.clone(); let stream_tx_clone = stream_tx.clone();