use std::{ net::SocketAddr, pin::Pin, task::{Context, Poll}, }; use etherparse::PacketBuilder; use futures::{ready, Sink, SinkExt, Stream}; use smoltcp::wire::UdpPacket; use tokio::sync::mpsc::{Receiver, Sender}; use tokio_util::sync::PollSender; use tracing::{error, trace}; use crate::packet::{AnyIpPktFrame, IpPacket}; pub type UdpMsg = ( Vec, /* payload */ SocketAddr, /* local */ SocketAddr, /* remote */ ); pub struct UdpSocket { udp_rx: Receiver, stack_tx: PollSender, } impl UdpSocket { pub(super) fn new(udp_rx: Receiver, stack_tx: Sender) -> Self { Self { udp_rx, stack_tx: PollSender::new(stack_tx), } } pub fn split(self) -> (ReadHalf, WriteHalf) { ( ReadHalf { udp_rx: self.udp_rx, }, WriteHalf { stack_tx: self.stack_tx, }, ) } } pub struct ReadHalf { udp_rx: Receiver, } pub struct WriteHalf { stack_tx: PollSender, } impl Stream for ReadHalf { type Item = UdpMsg; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { loop { match ready!(self.udp_rx.poll_recv(cx)) { Some(frame) => { let packet = match IpPacket::new_checked(frame.as_slice()) { Ok(p) => p, Err(err) => { error!("invalid IP packet: {}", err); continue; } }; let src_ip = packet.src_addr(); let dst_ip = packet.dst_addr(); let payload = packet.payload(); let packet = match UdpPacket::new_checked(payload) { Ok(p) => p, Err(err) => { error!("invalid err: {err}, src_ip: {src_ip}, dst_ip: {dst_ip}, payload: {payload:?}"); continue; } }; let src_port = packet.src_port(); let dst_port = packet.dst_port(); let src_addr = SocketAddr::new(src_ip, src_port); let dst_addr = SocketAddr::new(dst_ip, dst_port); trace!("created UDP socket for {} <-> {}", src_addr, dst_addr); return Poll::Ready(Some((packet.payload().to_vec(), src_addr, dst_addr))); } None => return Poll::Ready(None), } } } } impl Sink for WriteHalf { type Error = std::io::Error; fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match ready!(self.stack_tx.poll_ready_unpin(cx)) { Ok(()) => Poll::Ready(Ok(())), Err(err) => Poll::Ready(Err(std::io::Error::other(err))), } } fn start_send(mut self: Pin<&mut Self>, item: UdpMsg) -> Result<(), Self::Error> { use std::io::{Error, ErrorKind::InvalidData}; let (data, src_addr, dst_addr) = item; if data.is_empty() { return Ok(()); } let builder = match (src_addr, dst_addr) { (SocketAddr::V4(src), SocketAddr::V4(dst)) => { PacketBuilder::ipv4(src.ip().octets(), dst.ip().octets(), 20) .udp(src_addr.port(), dst_addr.port()) } (SocketAddr::V6(src), SocketAddr::V6(dst)) => { PacketBuilder::ipv6(src.ip().octets(), dst.ip().octets(), 20) .udp(src_addr.port(), dst_addr.port()) } _ => { return Err(Error::new(InvalidData, "src or destination type unmatch")); } }; let mut ip_packet_writer = Vec::with_capacity(builder.size(data.len())); builder .write(&mut ip_packet_writer, &data) .map_err(|err| Error::other(format!("PacketBuilder::write: {err}")))?; match self.stack_tx.start_send_unpin(ip_packet_writer) { Ok(()) => Ok(()), Err(err) => Err(Error::other(format!("send error: {err}"))), } } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { use std::io::Error; match ready!(self.stack_tx.poll_flush_unpin(cx)) { Ok(()) => Poll::Ready(Ok(())), Err(err) => Poll::Ready(Err(Error::other(format!("flush error: {err}")))), } } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { use std::io::Error; match ready!(self.stack_tx.poll_close_unpin(cx)) { Ok(()) => Poll::Ready(Ok(())), Err(err) => Poll::Ready(Err(Error::other(format!("close error: {err}")))), } } }