mirror of https://github.com/ospab/ostp.git
246 lines
11 KiB
Rust
246 lines
11 KiB
Rust
use anyhow::Result;
|
|
use bytes::Bytes;
|
|
use std::collections::HashMap;
|
|
|
|
use ostp_core::relay::RelayMessage;
|
|
use tokio::io::AsyncReadExt;
|
|
use tokio::net::UdpSocket;
|
|
use tokio::sync::mpsc;
|
|
|
|
use crate::dispatcher::Dispatcher;
|
|
use crate::outbound::{self, OutboundConfig};
|
|
use crate::{RemoteState, UiEvent};
|
|
|
|
pub async fn handle_relay_message(
|
|
peer_addr: std::net::SocketAddr,
|
|
session_id: u32,
|
|
stream_id: u16,
|
|
payload: Bytes,
|
|
dispatcher: &mut Dispatcher,
|
|
socket: &UdpSocket,
|
|
remotes: &mut HashMap<(u32, u16), RemoteState>,
|
|
ui_event_tx: &mpsc::UnboundedSender<UiEvent>,
|
|
stream_tx: mpsc::UnboundedSender<(u32, u16, Vec<u8>)>,
|
|
udp_reply_tx: mpsc::UnboundedSender<(u32, u16, String, Vec<u8>)>,
|
|
connect_tx: mpsc::UnboundedSender<(u32, u16, String, Result<(tokio::net::tcp::OwnedWriteHalf, mpsc::Sender<()>), String>)>,
|
|
outbound_cfg: Option<OutboundConfig>,
|
|
dns_server: std::sync::Arc<crate::dns::DnsServer>,
|
|
debug: bool,
|
|
tcp_map: &std::sync::Arc<tokio::sync::RwLock<HashMap<std::net::SocketAddr, tokio::sync::mpsc::Sender<Bytes>>>>,
|
|
) -> Result<()> {
|
|
match RelayMessage::decode(&payload)? {
|
|
RelayMessage::Connect(target) => {
|
|
// Intercept DNS queries directed at the TUN gateway if our internal DNS is enabled
|
|
let is_internal_dns = {
|
|
target == "10.1.0.1:53" && dns_server.config.read().await.enabled
|
|
};
|
|
|
|
if is_internal_dns {
|
|
let client_ip = peer_addr.ip();
|
|
let dns_srv = dns_server.clone();
|
|
let stream_tx_dns = stream_tx.clone();
|
|
let (cancel_tx, _) = mpsc::channel::<()>(1);
|
|
|
|
let (dns_query_tx, mut dns_query_rx) = mpsc::unbounded_channel::<Bytes>();
|
|
|
|
tokio::spawn(async move {
|
|
if let Some(query_bytes) = dns_query_rx.recv().await {
|
|
if let Some(resp_bytes) = dns_srv.resolve(&query_bytes, client_ip).await {
|
|
let _ = stream_tx_dns.send((session_id, stream_id, resp_bytes));
|
|
}
|
|
}
|
|
let _ = stream_tx_dns.send((session_id, stream_id, Vec::new()));
|
|
});
|
|
|
|
remotes.insert((session_id, stream_id), RemoteState {
|
|
data_tx: dns_query_tx,
|
|
udp_tx: None,
|
|
cancel_tx,
|
|
is_dns: true,
|
|
});
|
|
|
|
send_relay_to_stream(session_id, stream_id, RelayMessage::ConnectOk, dispatcher, socket, ui_event_tx, tcp_map).await?;
|
|
return Ok(());
|
|
}
|
|
|
|
let mut connect_target = target.clone();
|
|
if connect_target.starts_with("10.1.0.1:") {
|
|
connect_target = connect_target.replace("10.1.0.1:", "127.0.0.1:");
|
|
}
|
|
|
|
let target_clone = connect_target.clone();
|
|
let connect_tx_clone = connect_tx.clone();
|
|
let stream_tx_clone = stream_tx.clone();
|
|
let outbound_clone = outbound_cfg.clone();
|
|
tokio::spawn(async move {
|
|
let stream_res = outbound::connect_target(&target_clone, outbound_clone.as_ref(), debug).await;
|
|
match stream_res {
|
|
Ok(stream) => {
|
|
let (mut reader, writer) = stream.into_split();
|
|
let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1);
|
|
tokio::spawn(async move {
|
|
let mut buf = [0_u8; 4096];
|
|
loop {
|
|
tokio::select! {
|
|
_ = cancel_rx.recv() => break,
|
|
read_res = reader.read(&mut buf) => {
|
|
match read_res {
|
|
Ok(0) | Err(_) => {
|
|
let _ = stream_tx_clone.send((session_id, stream_id, Vec::new()));
|
|
break;
|
|
}
|
|
Ok(n) => {
|
|
if stream_tx_clone.send((session_id, stream_id, buf[..n].to_vec())).is_err() {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
});
|
|
let _ = connect_tx_clone.send((session_id, stream_id, target_clone, Ok((writer, cancel_tx))));
|
|
}
|
|
Err(e) => {
|
|
let _ = connect_tx_clone.send((session_id, stream_id, target_clone, Err(e.to_string())));
|
|
}
|
|
}
|
|
});
|
|
}
|
|
RelayMessage::Data(data) => {
|
|
if let Some(remote) = remotes.get_mut(&(session_id, stream_id)) {
|
|
let _ = remote.data_tx.send(bytes::Bytes::from(data));
|
|
} else {
|
|
let _ = ui_event_tx.send(UiEvent::Log(format!("Relay DATA for unknown stream [{session_id}:{stream_id}] ({})", data.len())));
|
|
}
|
|
}
|
|
RelayMessage::KeepAlive => {}
|
|
RelayMessage::Close => {
|
|
if let Some(state) = remotes.remove(&(session_id, stream_id)) {
|
|
let _ = state.cancel_tx.try_send(());
|
|
let _ = ui_event_tx.send(UiEvent::Log(format!("Relay CLOSE [{session_id}:{stream_id}]")));
|
|
}
|
|
}
|
|
RelayMessage::ConnectOk => {}
|
|
RelayMessage::Error(msg) => {
|
|
let _ = ui_event_tx.send(UiEvent::Log(format!("Relay error from [{session_id}:{stream_id}]: {msg}")));
|
|
}
|
|
RelayMessage::Ping(ts) => {
|
|
send_relay_to_stream(session_id, stream_id, RelayMessage::Pong(ts), dispatcher, socket, ui_event_tx, tcp_map).await?;
|
|
}
|
|
RelayMessage::Pong(_) => {}
|
|
RelayMessage::UdpAssociate => {
|
|
if debug {
|
|
let _ = ui_event_tx.send(UiEvent::Log(format!("Relay UDP ASSOCIATE stream_id={stream_id}")));
|
|
}
|
|
let udp_bind_result = match UdpSocket::bind("[::]:0").await {
|
|
Ok(s) => Ok(s),
|
|
Err(_) => UdpSocket::bind("0.0.0.0:0").await,
|
|
};
|
|
let server_udp = match udp_bind_result {
|
|
Ok(s) => std::sync::Arc::new(s),
|
|
Err(e) => {
|
|
let _ = ui_event_tx.send(UiEvent::Log(format!("UDP bind failed: {e}")));
|
|
return Ok(());
|
|
}
|
|
};
|
|
|
|
let (udp_tx, mut udp_rx) = mpsc::unbounded_channel::<(String, Bytes)>();
|
|
let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1);
|
|
let (dummy_data_tx, _) = mpsc::unbounded_channel::<Bytes>();
|
|
|
|
// Outbound UDP loop (tunnel -> target)
|
|
let tx_sock = server_udp.clone();
|
|
let dns_srv = dns_server.clone();
|
|
let udp_reply_clone_dns = udp_reply_tx.clone();
|
|
let client_ip = peer_addr.ip();
|
|
tokio::spawn(async move {
|
|
while let Some((target, data)) = udp_rx.recv().await {
|
|
let is_internal_dns = target == "10.1.0.1:53" && dns_srv.config.read().await.enabled;
|
|
if is_internal_dns {
|
|
if let Some(resp_bytes) = dns_srv.resolve(&data, client_ip).await {
|
|
let _ = udp_reply_clone_dns.send((session_id, stream_id, target, resp_bytes));
|
|
}
|
|
} else {
|
|
let mut forward_target = target.clone();
|
|
if forward_target.starts_with("10.1.0.1:") {
|
|
forward_target = forward_target.replace("10.1.0.1:", "127.0.0.1:");
|
|
}
|
|
let _ = tx_sock.send_to(&data, &forward_target).await;
|
|
}
|
|
}
|
|
});
|
|
|
|
// Inbound UDP loop (target -> tunnel)
|
|
let rx_sock = server_udp.clone();
|
|
let udp_reply_clone = udp_reply_tx.clone();
|
|
tokio::spawn(async move {
|
|
let mut buf = vec![0u8; 65536];
|
|
loop {
|
|
tokio::select! {
|
|
_ = cancel_rx.recv() => break,
|
|
res = rx_sock.recv_from(&mut buf) => {
|
|
match res {
|
|
Ok((len, addr)) => {
|
|
let _ = udp_reply_clone.send((session_id, stream_id, addr.to_string(), buf[..len].to_vec()));
|
|
}
|
|
Err(_) => break,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
remotes.insert((session_id, stream_id), RemoteState {
|
|
data_tx: dummy_data_tx,
|
|
udp_tx: Some(udp_tx),
|
|
cancel_tx,
|
|
is_dns: false,
|
|
});
|
|
|
|
send_relay_to_stream(session_id, stream_id, RelayMessage::ConnectOk, dispatcher, socket, ui_event_tx, tcp_map).await?;
|
|
}
|
|
RelayMessage::UdpData(target, data) => {
|
|
if let Some(remote) = remotes.get_mut(&(session_id, stream_id)) {
|
|
if let Some(ref udp_tx) = remote.udp_tx {
|
|
let _ = udp_tx.send((target, Bytes::from(data)));
|
|
}
|
|
} else {
|
|
let _ = ui_event_tx.send(UiEvent::Log(format!("Relay UDP DATA for unknown stream [{session_id}:{stream_id}]")));
|
|
}
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn send_relay_to_stream(
|
|
session_id: u32,
|
|
stream_id: u16,
|
|
msg: RelayMessage,
|
|
dispatcher: &mut Dispatcher,
|
|
socket: &UdpSocket,
|
|
ui_event_tx: &mpsc::UnboundedSender<UiEvent>,
|
|
tcp_map: &std::sync::Arc<tokio::sync::RwLock<HashMap<std::net::SocketAddr, tokio::sync::mpsc::Sender<Bytes>>>>,
|
|
) -> Result<()> {
|
|
let payload = Bytes::from(msg.encode());
|
|
if let Some((frame, peer_addr)) = dispatcher.outbound_to_session(session_id, stream_id, payload)? {
|
|
let response_len = frame.len();
|
|
let mut sent_tcp = false;
|
|
{
|
|
let map = tcp_map.read().await;
|
|
if let Some(tx) = map.get(&peer_addr) {
|
|
let _ = tx.try_send(frame.clone());
|
|
sent_tcp = true;
|
|
}
|
|
}
|
|
if !sent_tcp {
|
|
let _ = socket.send_to(&frame, peer_addr).await?;
|
|
}
|
|
let _ = ui_event_tx.send(UiEvent::Tx {
|
|
peer: peer_addr.ip(),
|
|
bytes: response_len,
|
|
});
|
|
}
|
|
Ok(())
|
|
}
|