use std::{ net::IpAddr, pin::Pin, task::{ready, Context, Poll}, }; use futures::{Sink, Stream}; use smoltcp::wire::IpProtocol; use tokio::sync::mpsc::{channel, Receiver}; use tokio_util::sync::PollSender; use tracing::{debug, trace}; use crate::{ filter::{IpFilter, IpFilters}, packet::{AnyIpPktFrame, IpPacket}, runner::Runner, tcp::TcpListener, udp::UdpSocket, }; pub struct StackBuilder { enable_udp: bool, enable_tcp: bool, enable_icmp: bool, stack_buffer_size: usize, udp_buffer_size: usize, tcp_buffer_size: usize, mtu: usize, ip_filters: IpFilters<'static>, } impl Default for StackBuilder { fn default() -> Self { Self { enable_udp: false, enable_tcp: false, enable_icmp: false, stack_buffer_size: 1024, udp_buffer_size: 512, tcp_buffer_size: 512, mtu: 1504, // 1500 for Ethernet + 4 for VLAN ip_filters: IpFilters::with_non_broadcast(), } } } #[allow(unused)] impl StackBuilder { pub fn enable_udp(mut self, enable: bool) -> Self { self.enable_udp = enable; self } pub fn enable_tcp(mut self, enable: bool) -> Self { self.enable_tcp = enable; self } pub fn enable_icmp(mut self, enable: bool) -> Self { self.enable_icmp = enable; self } pub fn stack_buffer_size(mut self, size: usize) -> Self { self.stack_buffer_size = size; self } pub fn udp_buffer_size(mut self, size: usize) -> Self { self.udp_buffer_size = size; self } pub fn tcp_buffer_size(mut self, size: usize) -> Self { self.tcp_buffer_size = size; self } pub fn set_ip_filters(mut self, filters: IpFilters<'static>) -> Self { self.ip_filters = filters; self } pub fn add_ip_filter(mut self, filter: IpFilter<'static>) -> Self { self.ip_filters.add(filter); self } pub fn add_ip_filter_fn(mut self, filter: F) -> Self where F: Fn(&IpAddr, &IpAddr) -> bool + Send + Sync + 'static, { self.ip_filters.add_fn(filter); self } pub fn mtu(mut self, mtu: usize) -> Self { self.mtu = mtu; self } #[allow(clippy::type_complexity)] pub fn build( self, ) -> std::io::Result<( Stack, Option, Option, Option, )> { let (stack_tx, stack_rx) = channel(self.stack_buffer_size); let (udp_tx, udp_rx) = if self.enable_udp { let (udp_tx, udp_rx) = channel(self.udp_buffer_size); (Some(PollSender::new(udp_tx)), Some(udp_rx)) } else { (None, None) }; let (tcp_tx, tcp_rx) = if self.enable_tcp { let (tcp_tx, tcp_rx) = channel(self.tcp_buffer_size); (Some(PollSender::new(tcp_tx)), Some(tcp_rx)) } else { (None, None) }; // ICMP is handled by TCP's Interface. // smoltcp's interface will always send replies to EchoRequest if self.enable_icmp && !self.enable_tcp { use std::io::{Error, ErrorKind::InvalidInput}; return Err(Error::new(InvalidInput, "ICMP requires TCP")); } let icmp_tx = if self.enable_icmp { tcp_tx.clone() } else { None }; let udp_socket = udp_rx.map(|udp_rx| UdpSocket::new(udp_rx, stack_tx.clone())); let (tcp_runner, tcp_listener) = if let Some(tcp_rx) = tcp_rx { let (tcp_runner, tcp_listener) = TcpListener::new(tcp_rx, stack_tx, self.mtu)?; (Some(tcp_runner), Some(tcp_listener)) } else { (None, None) }; let stack = Stack { ip_filters: self.ip_filters, stack_rx, sink_buf: None, udp_tx, tcp_tx, icmp_tx, }; Ok((stack, tcp_runner, udp_socket, tcp_listener)) } } pub struct Stack { ip_filters: IpFilters<'static>, sink_buf: Option<(AnyIpPktFrame, IpProtocol)>, udp_tx: Option>, tcp_tx: Option>, icmp_tx: Option>, stack_rx: Receiver, } impl Stack { fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll> { let (item, proto) = match self.sink_buf.take() { Some(val) => val, None => return Poll::Ready(Ok(())), }; let tx = match proto { IpProtocol::Tcp => self.tcp_tx.as_mut(), IpProtocol::Udp => self.udp_tx.as_mut(), IpProtocol::Icmp | IpProtocol::Icmpv6 => self.icmp_tx.as_mut(), _ => unreachable!(), }; let Some(tx) = tx else { return Poll::Ready(Ok(())); }; match tx.poll_reserve(cx) { Poll::Pending => { self.sink_buf = Some((item, proto)); Poll::Pending } Poll::Ready(Err(_)) => Poll::Ready(Err(channel_closed_err("channel is closed"))), Poll::Ready(Ok(_)) => match tx.send_item(item) { Ok(()) => Poll::Ready(Ok(())), Err(_) => Poll::Ready(Err(channel_closed_err("channel is closed"))), }, } } } // Recv from stack. impl Stream for Stack { type Item = std::io::Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.stack_rx.poll_recv(cx) { Poll::Ready(Some(pkt)) => Poll::Ready(Some(Ok(pkt))), Poll::Ready(None) => Poll::Ready(None), Poll::Pending => Poll::Pending, } } } // Send to stack. impl Sink for Stack { type Error = std::io::Error; fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { // If a buffered item exists, try to flush it first. This also properly // registers the waker via poll_reserve so we get woken when the channel // has capacity. Without this, returning Pending here with _cx unused // means the task never gets rescheduled. if self.sink_buf.is_some() { ready!(self.poll_send(cx))?; } Poll::Ready(Ok(())) } fn start_send(mut self: Pin<&mut Self>, item: AnyIpPktFrame) -> Result<(), Self::Error> { if item.is_empty() { return Ok(()); } use std::io::{Error, ErrorKind::InvalidInput}; let packet = IpPacket::new_checked(item.as_slice()) .map_err(|err| Error::new(InvalidInput, format!("invalid IP packet: {err}")))?; let src_ip = packet.src_addr(); let dst_ip = packet.dst_addr(); let addr_allowed = self.ip_filters.is_allowed(&src_ip, &dst_ip); if !addr_allowed { trace!("IP packet {src_ip} -> {dst_ip} (allowed? {addr_allowed}) throwing away",); return Ok(()); } let protocol = packet.protocol(); if matches!( protocol, IpProtocol::Tcp | IpProtocol::Udp | IpProtocol::Icmp | IpProtocol::Icmpv6 ) { self.sink_buf.replace((item, protocol)); } else { debug!("tun IP packet ignored (protocol: {:?})", protocol); } Ok(()) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.poll_send(cx) } fn poll_close( mut self: Pin<&mut Self>, _cx: &mut Context<'_>, ) -> Poll> { self.stack_rx.close(); Poll::Ready(Ok(())) } } fn channel_closed_err(err: E) -> std::io::Error where E: Into>, { std::io::Error::new(std::io::ErrorKind::BrokenPipe, err) }