mirror of https://github.com/ospab/ostp.git
156 lines
4.9 KiB
Rust
156 lines
4.9 KiB
Rust
use std::{
|
|
net::SocketAddr,
|
|
pin::Pin,
|
|
task::{Context, Poll},
|
|
};
|
|
|
|
use etherparse::PacketBuilder;
|
|
use futures::{ready, Sink, SinkExt, Stream};
|
|
use smoltcp::wire::UdpPacket;
|
|
use tokio::sync::mpsc::{Receiver, Sender};
|
|
use tokio_util::sync::PollSender;
|
|
use tracing::{error, trace};
|
|
|
|
use crate::packet::{AnyIpPktFrame, IpPacket};
|
|
|
|
/// One UDP datagram as `(payload, source address, destination address)`.
///
/// [`ReadHalf`] yields the source/destination taken from the received IP and
/// UDP headers, and [`WriteHalf`] uses the second element as the IP/UDP source
/// and the third as the destination of the packet it builds. NOTE(review): the
/// original labels were "local"/"remote", which presumably matches the read
/// direction only — a reply written back must put the remote peer first.
pub type UdpMsg = (Vec<u8>, SocketAddr, SocketAddr);
|
|
|
|
pub struct UdpSocket {
|
|
udp_rx: Receiver<AnyIpPktFrame>,
|
|
stack_tx: PollSender<AnyIpPktFrame>,
|
|
}
|
|
|
|
impl UdpSocket {
|
|
pub(super) fn new(udp_rx: Receiver<AnyIpPktFrame>, stack_tx: Sender<AnyIpPktFrame>) -> Self {
|
|
Self {
|
|
udp_rx,
|
|
stack_tx: PollSender::new(stack_tx),
|
|
}
|
|
}
|
|
|
|
pub fn split(self) -> (ReadHalf, WriteHalf) {
|
|
(
|
|
ReadHalf {
|
|
udp_rx: self.udp_rx,
|
|
},
|
|
WriteHalf {
|
|
stack_tx: self.stack_tx,
|
|
},
|
|
)
|
|
}
|
|
}
|
|
|
|
pub struct ReadHalf {
|
|
udp_rx: Receiver<AnyIpPktFrame>,
|
|
}
|
|
|
|
pub struct WriteHalf {
|
|
stack_tx: PollSender<AnyIpPktFrame>,
|
|
}
|
|
|
|
impl Stream for ReadHalf {
|
|
type Item = UdpMsg;
|
|
|
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
|
loop {
|
|
match ready!(self.udp_rx.poll_recv(cx)) {
|
|
Some(frame) => {
|
|
let packet = match IpPacket::new_checked(frame.as_slice()) {
|
|
Ok(p) => p,
|
|
Err(err) => {
|
|
error!("invalid IP packet: {}", err);
|
|
continue;
|
|
}
|
|
};
|
|
|
|
let src_ip = packet.src_addr();
|
|
let dst_ip = packet.dst_addr();
|
|
let payload = packet.payload();
|
|
|
|
let packet = match UdpPacket::new_checked(payload) {
|
|
Ok(p) => p,
|
|
Err(err) => {
|
|
error!("invalid err: {err}, src_ip: {src_ip}, dst_ip: {dst_ip}, payload: {payload:?}");
|
|
continue;
|
|
}
|
|
};
|
|
let src_port = packet.src_port();
|
|
let dst_port = packet.dst_port();
|
|
|
|
let src_addr = SocketAddr::new(src_ip, src_port);
|
|
let dst_addr = SocketAddr::new(dst_ip, dst_port);
|
|
|
|
trace!("created UDP socket for {} <-> {}", src_addr, dst_addr);
|
|
|
|
return Poll::Ready(Some((packet.payload().to_vec(), src_addr, dst_addr)));
|
|
}
|
|
None => return Poll::Ready(None),
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Sink<UdpMsg> for WriteHalf {
|
|
type Error = std::io::Error;
|
|
|
|
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
|
match ready!(self.stack_tx.poll_ready_unpin(cx)) {
|
|
Ok(()) => Poll::Ready(Ok(())),
|
|
Err(err) => Poll::Ready(Err(std::io::Error::other(err))),
|
|
}
|
|
}
|
|
|
|
fn start_send(mut self: Pin<&mut Self>, item: UdpMsg) -> Result<(), Self::Error> {
|
|
use std::io::{Error, ErrorKind::InvalidData};
|
|
let (data, src_addr, dst_addr) = item;
|
|
|
|
if data.is_empty() {
|
|
return Ok(());
|
|
}
|
|
|
|
let builder = match (src_addr, dst_addr) {
|
|
(SocketAddr::V4(src), SocketAddr::V4(dst)) => {
|
|
PacketBuilder::ipv4(src.ip().octets(), dst.ip().octets(), 20)
|
|
.udp(src_addr.port(), dst_addr.port())
|
|
}
|
|
(SocketAddr::V6(src), SocketAddr::V6(dst)) => {
|
|
PacketBuilder::ipv6(src.ip().octets(), dst.ip().octets(), 20)
|
|
.udp(src_addr.port(), dst_addr.port())
|
|
}
|
|
_ => {
|
|
return Err(Error::new(InvalidData, "src or destination type unmatch"));
|
|
}
|
|
};
|
|
|
|
let mut ip_packet_writer = Vec::with_capacity(builder.size(data.len()));
|
|
builder
|
|
.write(&mut ip_packet_writer, &data)
|
|
.map_err(|err| Error::other(format!("PacketBuilder::write: {err}")))?;
|
|
|
|
match self.stack_tx.start_send_unpin(ip_packet_writer) {
|
|
Ok(()) => Ok(()),
|
|
Err(err) => Err(Error::other(format!("send error: {err}"))),
|
|
}
|
|
}
|
|
|
|
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
|
use std::io::Error;
|
|
match ready!(self.stack_tx.poll_flush_unpin(cx)) {
|
|
Ok(()) => Poll::Ready(Ok(())),
|
|
Err(err) => Poll::Ready(Err(Error::other(format!("flush error: {err}")))),
|
|
}
|
|
}
|
|
|
|
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
|
use std::io::Error;
|
|
match ready!(self.stack_tx.poll_close_unpin(cx)) {
|
|
Ok(()) => Poll::Ready(Ok(())),
|
|
Err(err) => Poll::Ready(Err(Error::other(format!("close error: {err}")))),
|
|
}
|
|
}
|
|
}
|