fix(xhttp): rewrite RealityStream buffering to prevent packet drops and data loss

This commit is contained in:
ospab 2026-05-30 01:10:29 +03:00
parent 7257da174a
commit 902e762c91
3 changed files with 109 additions and 58 deletions

View File

@ -108,6 +108,8 @@ struct RealityStream {
rx_nonce: u64, rx_nonce: u64,
tx_nonce: u64, tx_nonce: u64,
rx_buf: BytesMut, rx_buf: BytesMut,
plaintext_buf: BytesMut,
tx_buf: BytesMut,
} }
impl RealityStream { impl RealityStream {
@ -118,6 +120,8 @@ impl RealityStream {
rx_nonce: 0, rx_nonce: 0,
tx_nonce: 0, tx_nonce: 0,
rx_buf: BytesMut::with_capacity(16384), rx_buf: BytesMut::with_capacity(16384),
plaintext_buf: BytesMut::new(),
tx_buf: BytesMut::new(),
} }
} }
@ -131,11 +135,16 @@ impl RealityStream {
impl tokio::io::AsyncRead for RealityStream { impl tokio::io::AsyncRead for RealityStream {
fn poll_read(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll<std::io::Result<()>> { fn poll_read(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll<std::io::Result<()>> {
loop { loop {
// Try to decode a full record if !self.plaintext_buf.is_empty() {
let out_len = std::cmp::min(buf.remaining(), self.plaintext_buf.len());
buf.put_slice(&self.plaintext_buf[..out_len]);
self.plaintext_buf.advance(out_len);
return Poll::Ready(Ok(()));
}
if self.rx_buf.len() >= 5 { if self.rx_buf.len() >= 5 {
let len = u16::from_be_bytes([self.rx_buf[3], self.rx_buf[4]]) as usize; let len = u16::from_be_bytes([self.rx_buf[3], self.rx_buf[4]]) as usize;
if self.rx_buf.len() >= 5 + len { if self.rx_buf.len() >= 5 + len {
// We have a full record
if self.rx_buf[0] != 0x17 { if self.rx_buf[0] != 0x17 {
return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected application data record"))); return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected application data record")));
} }
@ -147,17 +156,9 @@ impl tokio::io::AsyncRead for RealityStream {
match self.data_key.decrypt(nonce, ciphertext) { match self.data_key.decrypt(nonce, ciphertext) {
Ok(plaintext) => { Ok(plaintext) => {
self.rx_nonce += 1; self.rx_nonce += 1;
let out_len = std::cmp::min(buf.remaining(), plaintext.len()); self.plaintext_buf.put_slice(&plaintext);
buf.put_slice(&plaintext[..out_len]);
if out_len < plaintext.len() {
// RealityStream doesn't buffer remaining plaintext if user buffer is too small.
// In xhttp_handshake_and_loop we always use 65535 byte buffers, so it fits.
// If needed, we'd add an internal plaintext_buffer.
}
self.rx_buf.advance(5 + len); self.rx_buf.advance(5 + len);
return Poll::Ready(Ok(())); continue;
} }
Err(_) => { Err(_) => {
return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "reality decrypt failed"))); return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "reality decrypt failed")));
@ -166,13 +167,12 @@ impl tokio::io::AsyncRead for RealityStream {
} }
} }
// Need more data let mut read_buf = [0u8; 8192];
let mut read_buf = [0u8; 4096];
let mut tokio_buf = tokio::io::ReadBuf::new(&mut read_buf); let mut tokio_buf = tokio::io::ReadBuf::new(&mut read_buf);
match Pin::new(&mut self.inner).poll_read(cx, &mut tokio_buf) { match Pin::new(&mut self.inner).poll_read(cx, &mut tokio_buf) {
Poll::Ready(Ok(())) => { Poll::Ready(Ok(())) => {
if tokio_buf.filled().is_empty() { if tokio_buf.filled().is_empty() {
return Poll::Ready(Ok(())); // EOF return Poll::Ready(Ok(()));
} }
self.rx_buf.put_slice(tokio_buf.filled()); self.rx_buf.put_slice(tokio_buf.filled());
} }
@ -185,42 +185,60 @@ impl tokio::io::AsyncRead for RealityStream {
impl tokio::io::AsyncWrite for RealityStream { impl tokio::io::AsyncWrite for RealityStream {
fn poll_write(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> { fn poll_write(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
let nonce_bytes = Self::make_nonce(self.tx_nonce); let this = self.get_mut();
while !this.tx_buf.is_empty() {
match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) {
Poll::Ready(Ok(n)) => this.tx_buf.advance(n),
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
}
}
let nonce_bytes = Self::make_nonce(this.tx_nonce);
let nonce = Nonce::from_slice(&nonce_bytes); let nonce = Nonce::from_slice(&nonce_bytes);
// Encrypt the entire buf as a single record match this.data_key.encrypt(nonce, buf) {
match self.data_key.encrypt(nonce, buf) {
Ok(ciphertext) => { Ok(ciphertext) => {
let mut record = BytesMut::with_capacity(5 + ciphertext.len()); this.tx_nonce += 1;
record.put_u8(0x17); // Application Data this.tx_buf.reserve(5 + ciphertext.len());
record.put_u16(0x0303); // TLS 1.2/1.3 this.tx_buf.put_u8(0x17);
record.put_u16(ciphertext.len() as u16); this.tx_buf.put_u16(0x0303);
record.put_slice(&ciphertext); this.tx_buf.put_u16(ciphertext.len() as u16);
this.tx_buf.put_slice(&ciphertext);
// Write the full record to the inner stream match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) {
match tokio::io::AsyncWrite::poll_write(Pin::new(&mut self.inner), cx, &record) { Poll::Ready(Ok(n)) => this.tx_buf.advance(n),
Poll::Ready(Ok(n)) if n == record.len() => { Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
self.tx_nonce += 1; Poll::Pending => {}
Poll::Ready(Ok(buf.len()))
}
Poll::Ready(Ok(_n)) => {
// Partial writes of a single TLS record are not supported by this simple wrapper
Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, "partial write not supported")))
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => Poll::Pending,
} }
Poll::Ready(Ok(buf.len()))
} }
Err(_) => Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, "reality encrypt failed"))), Err(_) => Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, "reality encrypt failed"))),
} }
} }
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<std::io::Result<()>> { fn poll_flush(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.inner).poll_flush(cx) let this = self.get_mut();
while !this.tx_buf.is_empty() {
match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) {
Poll::Ready(Ok(n)) => this.tx_buf.advance(n),
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
}
}
Pin::new(&mut this.inner).poll_flush(cx)
} }
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<std::io::Result<()>> { fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(cx) let this = self.get_mut();
while !this.tx_buf.is_empty() {
match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) {
Poll::Ready(Ok(n)) => this.tx_buf.advance(n),
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
}
}
Pin::new(&mut this.inner).poll_shutdown(cx)
} }
} }

View File

@ -433,6 +433,8 @@ struct RealityStream<S> {
rx_nonce: u64, rx_nonce: u64,
tx_nonce: u64, tx_nonce: u64,
rx_buf: BytesMut, rx_buf: BytesMut,
plaintext_buf: BytesMut,
tx_buf: BytesMut,
} }
impl<S> RealityStream<S> { impl<S> RealityStream<S> {
@ -443,6 +445,8 @@ impl<S> RealityStream<S> {
rx_nonce: 0, rx_nonce: 0,
tx_nonce: 0, tx_nonce: 0,
rx_buf: BytesMut::with_capacity(16384), rx_buf: BytesMut::with_capacity(16384),
plaintext_buf: BytesMut::new(),
tx_buf: BytesMut::new(),
} }
} }
@ -456,6 +460,13 @@ impl<S> RealityStream<S> {
impl<S: tokio::io::AsyncRead + Unpin> tokio::io::AsyncRead for RealityStream<S> { impl<S: tokio::io::AsyncRead + Unpin> tokio::io::AsyncRead for RealityStream<S> {
fn poll_read(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll<std::io::Result<()>> { fn poll_read(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll<std::io::Result<()>> {
loop { loop {
if !self.plaintext_buf.is_empty() {
let out_len = std::cmp::min(buf.remaining(), self.plaintext_buf.len());
buf.put_slice(&self.plaintext_buf[..out_len]);
self.plaintext_buf.advance(out_len);
return Poll::Ready(Ok(()));
}
if self.rx_buf.len() >= 5 { if self.rx_buf.len() >= 5 {
let len = u16::from_be_bytes([self.rx_buf[3], self.rx_buf[4]]) as usize; let len = u16::from_be_bytes([self.rx_buf[3], self.rx_buf[4]]) as usize;
if self.rx_buf.len() >= 5 + len { if self.rx_buf.len() >= 5 + len {
@ -470,17 +481,16 @@ impl<S: tokio::io::AsyncRead + Unpin> tokio::io::AsyncRead for RealityStream<S>
match self.data_key.decrypt(nonce, ciphertext) { match self.data_key.decrypt(nonce, ciphertext) {
Ok(plaintext) => { Ok(plaintext) => {
self.rx_nonce += 1; self.rx_nonce += 1;
let out_len = std::cmp::min::<usize>(buf.remaining(), plaintext.len()); self.plaintext_buf.put_slice(&plaintext);
buf.put_slice(&plaintext[..out_len]);
self.rx_buf.advance(5 + len); self.rx_buf.advance(5 + len);
return Poll::Ready(Ok(())); continue;
} }
Err(_) => return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "reality decrypt failed"))), Err(_) => return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "reality decrypt failed"))),
} }
} }
} }
let mut read_buf = [0u8; 4096]; let mut read_buf = [0u8; 8192];
let mut tokio_buf = tokio::io::ReadBuf::new(&mut read_buf); let mut tokio_buf = tokio::io::ReadBuf::new(&mut read_buf);
match Pin::new(&mut self.inner).poll_read(cx, &mut tokio_buf) { match Pin::new(&mut self.inner).poll_read(cx, &mut tokio_buf) {
Poll::Ready(Ok(())) => { Poll::Ready(Ok(())) => {
@ -496,36 +506,59 @@ impl<S: tokio::io::AsyncRead + Unpin> tokio::io::AsyncRead for RealityStream<S>
impl<S: tokio::io::AsyncWrite + Unpin> tokio::io::AsyncWrite for RealityStream<S> { impl<S: tokio::io::AsyncWrite + Unpin> tokio::io::AsyncWrite for RealityStream<S> {
fn poll_write(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> { fn poll_write(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
let nonce_bytes = Self::make_nonce(self.tx_nonce); let this = self.get_mut();
while !this.tx_buf.is_empty() {
match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) {
Poll::Ready(Ok(n)) => this.tx_buf.advance(n),
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
}
}
let nonce_bytes = Self::make_nonce(this.tx_nonce);
let nonce = Nonce::from_slice(&nonce_bytes); let nonce = Nonce::from_slice(&nonce_bytes);
match self.data_key.encrypt(nonce, buf) { match this.data_key.encrypt(nonce, buf) {
Ok(ciphertext) => { Ok(ciphertext) => {
let mut record: BytesMut = BytesMut::with_capacity(5 + ciphertext.len()); this.tx_nonce += 1;
record.put_u8(0x17); this.tx_buf.reserve(5 + ciphertext.len());
record.put_u16(0x0303); this.tx_buf.put_u8(0x17);
record.put_u16(ciphertext.len() as u16); this.tx_buf.put_u16(0x0303);
record.put_slice(&ciphertext); this.tx_buf.put_u16(ciphertext.len() as u16);
this.tx_buf.put_slice(&ciphertext);
match tokio::io::AsyncWrite::poll_write(Pin::new(&mut self.inner), cx, &record) { match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) {
Poll::Ready(Ok(n)) if n == record.len() => { Poll::Ready(Ok(n)) => this.tx_buf.advance(n),
self.tx_nonce += 1; Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Ready(Ok(buf.len())) Poll::Pending => {}
}
Poll::Ready(Ok(_n)) => Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, "partial write not supported"))),
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => Poll::Pending,
} }
Poll::Ready(Ok(buf.len()))
} }
Err(_) => Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, "reality encrypt failed"))), Err(_) => Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, "reality encrypt failed"))),
} }
} }
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<std::io::Result<()>> { fn poll_flush(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.inner).poll_flush(cx) let this = self.get_mut();
while !this.tx_buf.is_empty() {
match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) {
Poll::Ready(Ok(n)) => this.tx_buf.advance(n),
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
}
}
Pin::new(&mut this.inner).poll_flush(cx)
} }
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<std::io::Result<()>> { fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(cx) let this = self.get_mut();
while !this.tx_buf.is_empty() {
match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) {
Poll::Ready(Ok(n)) => this.tx_buf.advance(n),
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
}
}
Pin::new(&mut this.inner).poll_shutdown(cx)
} }
} }