use anyhow::Result; use bytes::Bytes; use std::collections::HashMap; use ostp_core::relay::RelayMessage; use tokio::io::AsyncReadExt; use tokio::net::UdpSocket; use tokio::sync::mpsc; use crate::dispatcher::Dispatcher; use crate::outbound::{self, OutboundConfig}; use crate::{RemoteState, UiEvent}; pub async fn handle_relay_message( peer_addr: std::net::SocketAddr, session_id: u32, stream_id: u16, payload: Bytes, dispatcher: &mut Dispatcher, socket: &UdpSocket, remotes: &mut HashMap<(u32, u16), RemoteState>, ui_event_tx: &mpsc::UnboundedSender, stream_tx: mpsc::UnboundedSender<(u32, u16, Vec)>, udp_reply_tx: mpsc::UnboundedSender<(u32, u16, String, Vec)>, connect_tx: mpsc::UnboundedSender<(u32, u16, String, Result<(tokio::net::tcp::OwnedWriteHalf, mpsc::Sender<()>), String>)>, outbound_cfg: Option, dns_server: std::sync::Arc, debug: bool, tcp_map: &std::sync::Arc>>>, ) -> Result<()> { match RelayMessage::decode(&payload)? { RelayMessage::Connect(target) => { // Intercept DNS queries directed at the TUN gateway if our internal DNS is enabled let is_internal_dns = { target == "10.1.0.1:53" && dns_server.config.read().await.enabled }; if is_internal_dns { let client_ip = peer_addr.ip(); let dns_srv = dns_server.clone(); let stream_tx_dns = stream_tx.clone(); let (cancel_tx, _) = mpsc::channel::<()>(1); let (dns_query_tx, mut dns_query_rx) = mpsc::unbounded_channel::(); tokio::spawn(async move { if let Some(query_bytes) = dns_query_rx.recv().await { if let Some(resp_bytes) = dns_srv.resolve(&query_bytes, client_ip).await { let _ = stream_tx_dns.send((session_id, stream_id, resp_bytes)); } } let _ = stream_tx_dns.send((session_id, stream_id, Vec::new())); }); remotes.insert((session_id, stream_id), RemoteState { data_tx: dns_query_tx, udp_tx: None, cancel_tx, is_dns: true, }); send_relay_to_stream(session_id, stream_id, RelayMessage::ConnectOk, dispatcher, socket, ui_event_tx, tcp_map).await?; return Ok(()); } let mut connect_target = target.clone(); if connect_target.starts_with("10.1.0.1:") { connect_target = connect_target.replace("10.1.0.1:", "127.0.0.1:"); } let target_clone = connect_target.clone(); let connect_tx_clone = connect_tx.clone(); let stream_tx_clone = stream_tx.clone(); let outbound_clone = outbound_cfg.clone(); tokio::spawn(async move { let stream_res = outbound::connect_target(&target_clone, outbound_clone.as_ref(), debug).await; match stream_res { Ok(stream) => { let (mut reader, writer) = stream.into_split(); let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1); tokio::spawn(async move { let mut buf = [0_u8; 4096]; loop { tokio::select! { _ = cancel_rx.recv() => break, read_res = reader.read(&mut buf) => { match read_res { Ok(0) | Err(_) => { let _ = stream_tx_clone.send((session_id, stream_id, Vec::new())); break; } Ok(n) => { if stream_tx_clone.send((session_id, stream_id, buf[..n].to_vec())).is_err() { break; } } } } } } }); let _ = connect_tx_clone.send((session_id, stream_id, target_clone, Ok((writer, cancel_tx)))); } Err(e) => { let _ = connect_tx_clone.send((session_id, stream_id, target_clone, Err(e.to_string()))); } } }); } RelayMessage::Data(data) => { if let Some(remote) = remotes.get_mut(&(session_id, stream_id)) { let _ = remote.data_tx.send(bytes::Bytes::from(data)); } else { let _ = ui_event_tx.send(UiEvent::Log(format!("Relay DATA for unknown stream [{session_id}:{stream_id}] ({})", data.len()))); } } RelayMessage::KeepAlive => {} RelayMessage::Close => { if let Some(state) = remotes.remove(&(session_id, stream_id)) { let _ = state.cancel_tx.try_send(()); let _ = ui_event_tx.send(UiEvent::Log(format!("Relay CLOSE [{session_id}:{stream_id}]"))); } } RelayMessage::ConnectOk => {} RelayMessage::Error(msg) => { let _ = ui_event_tx.send(UiEvent::Log(format!("Relay error from [{session_id}:{stream_id}]: {msg}"))); } RelayMessage::Ping(ts) => { send_relay_to_stream(session_id, stream_id, RelayMessage::Pong(ts), dispatcher, socket, ui_event_tx, tcp_map).await?; } RelayMessage::Pong(_) => {} RelayMessage::UdpAssociate => { if debug { let _ = ui_event_tx.send(UiEvent::Log(format!("Relay UDP ASSOCIATE stream_id={stream_id}"))); } let udp_bind_result = match UdpSocket::bind("[::]:0").await { Ok(s) => Ok(s), Err(_) => UdpSocket::bind("0.0.0.0:0").await, }; let server_udp = match udp_bind_result { Ok(s) => std::sync::Arc::new(s), Err(e) => { let _ = ui_event_tx.send(UiEvent::Log(format!("UDP bind failed: {e}"))); return Ok(()); } }; let (udp_tx, mut udp_rx) = mpsc::unbounded_channel::<(String, Bytes)>(); let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1); let (dummy_data_tx, _) = mpsc::unbounded_channel::(); // Outbound UDP loop (tunnel -> target) let tx_sock = server_udp.clone(); let dns_srv = dns_server.clone(); let udp_reply_clone_dns = udp_reply_tx.clone(); let client_ip = peer_addr.ip(); tokio::spawn(async move { while let Some((target, data)) = udp_rx.recv().await { let is_internal_dns = target == "10.1.0.1:53" && dns_srv.config.read().await.enabled; if is_internal_dns { if let Some(resp_bytes) = dns_srv.resolve(&data, client_ip).await { let _ = udp_reply_clone_dns.send((session_id, stream_id, target, resp_bytes)); } } else { let mut forward_target = target.clone(); if forward_target.starts_with("10.1.0.1:") { forward_target = forward_target.replace("10.1.0.1:", "127.0.0.1:"); } let _ = tx_sock.send_to(&data, &forward_target).await; } } }); // Inbound UDP loop (target -> tunnel) let rx_sock = server_udp.clone(); let udp_reply_clone = udp_reply_tx.clone(); tokio::spawn(async move { let mut buf = vec![0u8; 65536]; loop { tokio::select! { _ = cancel_rx.recv() => break, res = rx_sock.recv_from(&mut buf) => { match res { Ok((len, addr)) => { let _ = udp_reply_clone.send((session_id, stream_id, addr.to_string(), buf[..len].to_vec())); } Err(_) => break, } } } } }); remotes.insert((session_id, stream_id), RemoteState { data_tx: dummy_data_tx, udp_tx: Some(udp_tx), cancel_tx, is_dns: false, }); send_relay_to_stream(session_id, stream_id, RelayMessage::ConnectOk, dispatcher, socket, ui_event_tx, tcp_map).await?; } RelayMessage::UdpData(target, data) => { if let Some(remote) = remotes.get_mut(&(session_id, stream_id)) { if let Some(ref udp_tx) = remote.udp_tx { let _ = udp_tx.send((target, Bytes::from(data))); } } else { let _ = ui_event_tx.send(UiEvent::Log(format!("Relay UDP DATA for unknown stream [{session_id}:{stream_id}]"))); } } } Ok(()) } pub async fn send_relay_to_stream( session_id: u32, stream_id: u16, msg: RelayMessage, dispatcher: &mut Dispatcher, socket: &UdpSocket, ui_event_tx: &mpsc::UnboundedSender, tcp_map: &std::sync::Arc>>>, ) -> Result<()> { let payload = Bytes::from(msg.encode()); if let Some((frame, peer_addr)) = dispatcher.outbound_to_session(session_id, stream_id, payload)? { let response_len = frame.len(); let mut sent_tcp = false; { let map = tcp_map.read().await; if let Some(tx) = map.get(&peer_addr) { let _ = tx.try_send(frame.clone()); sent_tcp = true; } } if !sent_tcp { let _ = socket.send_to(&frame, peer_addr).await?; } let _ = ui_event_tx.send(UiEvent::Tx { peer: peer_addr.ip(), bytes: response_len, }); } Ok(()) }