From 902e762c915fc734da313ed465481d8fbc770985 Mon Sep 17 00:00:00 2001 From: ospab Date: Sat, 30 May 2026 01:10:29 +0300 Subject: [PATCH] fix(xhttp): rewrite RealityStream buffering to prevent packet drops and data loss --- ostp-client/src/transport/xhttp.rs | 92 +++++++++++-------- ...otlin-compiler-15190969065085357994.salive | 0 ostp-server/src/transport/uot.rs | 75 ++++++++++----- 3 files changed, 109 insertions(+), 58 deletions(-) create mode 100644 ostp-flutter/android/.kotlin/sessions/kotlin-compiler-15190969065085357994.salive diff --git a/ostp-client/src/transport/xhttp.rs b/ostp-client/src/transport/xhttp.rs index c5b7514..a01aa57 100644 --- a/ostp-client/src/transport/xhttp.rs +++ b/ostp-client/src/transport/xhttp.rs @@ -108,6 +108,8 @@ struct RealityStream { rx_nonce: u64, tx_nonce: u64, rx_buf: BytesMut, + plaintext_buf: BytesMut, + tx_buf: BytesMut, } impl RealityStream { @@ -118,6 +120,8 @@ impl RealityStream { rx_nonce: 0, tx_nonce: 0, rx_buf: BytesMut::with_capacity(16384), + plaintext_buf: BytesMut::new(), + tx_buf: BytesMut::new(), } } @@ -131,11 +135,16 @@ impl RealityStream { impl tokio::io::AsyncRead for RealityStream { fn poll_read(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll> { loop { - // Try to decode a full record + if !self.plaintext_buf.is_empty() { + let out_len = std::cmp::min(buf.remaining(), self.plaintext_buf.len()); + buf.put_slice(&self.plaintext_buf[..out_len]); + self.plaintext_buf.advance(out_len); + return Poll::Ready(Ok(())); + } + if self.rx_buf.len() >= 5 { let len = u16::from_be_bytes([self.rx_buf[3], self.rx_buf[4]]) as usize; if self.rx_buf.len() >= 5 + len { - // We have a full record if self.rx_buf[0] != 0x17 { return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected application data record"))); } @@ -147,17 +156,9 @@ impl tokio::io::AsyncRead for RealityStream { match self.data_key.decrypt(nonce, ciphertext) { Ok(plaintext) => { self.rx_nonce += 1; - let out_len = std::cmp::min(buf.remaining(), plaintext.len()); - buf.put_slice(&plaintext[..out_len]); - - if out_len < plaintext.len() { - // RealityStream doesn't buffer remaining plaintext if user buffer is too small. - // In xhttp_handshake_and_loop we always use 65535 byte buffers, so it fits. - // If needed, we'd add an internal plaintext_buffer. - } - + self.plaintext_buf.put_slice(&plaintext); self.rx_buf.advance(5 + len); - return Poll::Ready(Ok(())); + continue; } Err(_) => { return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "reality decrypt failed"))); @@ -166,13 +167,12 @@ impl tokio::io::AsyncRead for RealityStream { } } - // Need more data - let mut read_buf = [0u8; 4096]; + let mut read_buf = [0u8; 8192]; let mut tokio_buf = tokio::io::ReadBuf::new(&mut read_buf); match Pin::new(&mut self.inner).poll_read(cx, &mut tokio_buf) { Poll::Ready(Ok(())) => { if tokio_buf.filled().is_empty() { - return Poll::Ready(Ok(())); // EOF + return Poll::Ready(Ok(())); } self.rx_buf.put_slice(tokio_buf.filled()); } @@ -185,42 +185,60 @@ impl tokio::io::AsyncRead for RealityStream { impl tokio::io::AsyncWrite for RealityStream { fn poll_write(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &[u8]) -> Poll> { - let nonce_bytes = Self::make_nonce(self.tx_nonce); + let this = self.get_mut(); + while !this.tx_buf.is_empty() { + match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) { + Poll::Ready(Ok(n)) => this.tx_buf.advance(n), + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => return Poll::Pending, + } + } + + let nonce_bytes = Self::make_nonce(this.tx_nonce); let nonce = Nonce::from_slice(&nonce_bytes); - // Encrypt the entire buf as a single record - match self.data_key.encrypt(nonce, buf) { + match this.data_key.encrypt(nonce, buf) { Ok(ciphertext) => { - let mut record = BytesMut::with_capacity(5 + ciphertext.len()); - record.put_u8(0x17); // Application Data - record.put_u16(0x0303); // TLS 1.2/1.3 - record.put_u16(ciphertext.len() as u16); - record.put_slice(&ciphertext); + this.tx_nonce += 1; + this.tx_buf.reserve(5 + ciphertext.len()); + this.tx_buf.put_u8(0x17); + this.tx_buf.put_u16(0x0303); + this.tx_buf.put_u16(ciphertext.len() as u16); + this.tx_buf.put_slice(&ciphertext); - // Write the full record to the inner stream - match tokio::io::AsyncWrite::poll_write(Pin::new(&mut self.inner), cx, &record) { - Poll::Ready(Ok(n)) if n == record.len() => { - self.tx_nonce += 1; - Poll::Ready(Ok(buf.len())) - } - Poll::Ready(Ok(_n)) => { - // Partial writes of a single TLS record are not supported by this simple wrapper - Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, "partial write not supported"))) - } - Poll::Ready(Err(e)) => Poll::Ready(Err(e)), - Poll::Pending => Poll::Pending, + match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) { + Poll::Ready(Ok(n)) => this.tx_buf.advance(n), + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => {} } + Poll::Ready(Ok(buf.len())) } Err(_) => Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, "reality encrypt failed"))), } } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_flush(cx) + let this = self.get_mut(); + while !this.tx_buf.is_empty() { + match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) { + Poll::Ready(Ok(n)) => this.tx_buf.advance(n), + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => return Poll::Pending, + } + } + Pin::new(&mut this.inner).poll_flush(cx) } fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_shutdown(cx) + let this = self.get_mut(); + while !this.tx_buf.is_empty() { + match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) { + Poll::Ready(Ok(n)) => this.tx_buf.advance(n), + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => return Poll::Pending, + } + } + Pin::new(&mut this.inner).poll_shutdown(cx) } } diff --git a/ostp-flutter/android/.kotlin/sessions/kotlin-compiler-15190969065085357994.salive b/ostp-flutter/android/.kotlin/sessions/kotlin-compiler-15190969065085357994.salive new file mode 100644 index 0000000..e69de29 diff --git a/ostp-server/src/transport/uot.rs b/ostp-server/src/transport/uot.rs index 1261f81..10b4f32 100644 --- a/ostp-server/src/transport/uot.rs +++ b/ostp-server/src/transport/uot.rs @@ -433,6 +433,8 @@ struct RealityStream { rx_nonce: u64, tx_nonce: u64, rx_buf: BytesMut, + plaintext_buf: BytesMut, + tx_buf: BytesMut, } impl RealityStream { @@ -443,6 +445,8 @@ impl RealityStream { rx_nonce: 0, tx_nonce: 0, rx_buf: BytesMut::with_capacity(16384), + plaintext_buf: BytesMut::new(), + tx_buf: BytesMut::new(), } } @@ -456,6 +460,13 @@ impl RealityStream { impl tokio::io::AsyncRead for RealityStream { fn poll_read(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll> { loop { + if !self.plaintext_buf.is_empty() { + let out_len = std::cmp::min(buf.remaining(), self.plaintext_buf.len()); + buf.put_slice(&self.plaintext_buf[..out_len]); + self.plaintext_buf.advance(out_len); + return Poll::Ready(Ok(())); + } + if self.rx_buf.len() >= 5 { let len = u16::from_be_bytes([self.rx_buf[3], self.rx_buf[4]]) as usize; if self.rx_buf.len() >= 5 + len { @@ -470,17 +481,16 @@ impl tokio::io::AsyncRead for RealityStream match self.data_key.decrypt(nonce, ciphertext) { Ok(plaintext) => { self.rx_nonce += 1; - let out_len = std::cmp::min::(buf.remaining(), plaintext.len()); - buf.put_slice(&plaintext[..out_len]); + self.plaintext_buf.put_slice(&plaintext); self.rx_buf.advance(5 + len); - return Poll::Ready(Ok(())); + continue; } Err(_) => return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "reality decrypt failed"))), } } } - let mut read_buf = [0u8; 4096]; + let mut read_buf = [0u8; 8192]; let mut tokio_buf = tokio::io::ReadBuf::new(&mut read_buf); match Pin::new(&mut self.inner).poll_read(cx, &mut tokio_buf) { Poll::Ready(Ok(())) => { @@ -496,36 +506,59 @@ impl tokio::io::AsyncRead for RealityStream impl tokio::io::AsyncWrite for RealityStream { fn poll_write(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &[u8]) -> Poll> { - let nonce_bytes = Self::make_nonce(self.tx_nonce); + let this = self.get_mut(); + while !this.tx_buf.is_empty() { + match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) { + Poll::Ready(Ok(n)) => this.tx_buf.advance(n), + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => return Poll::Pending, + } + } + + let nonce_bytes = Self::make_nonce(this.tx_nonce); let nonce = Nonce::from_slice(&nonce_bytes); - match self.data_key.encrypt(nonce, buf) { + match this.data_key.encrypt(nonce, buf) { Ok(ciphertext) => { - let mut record: BytesMut = BytesMut::with_capacity(5 + ciphertext.len()); - record.put_u8(0x17); - record.put_u16(0x0303); - record.put_u16(ciphertext.len() as u16); - record.put_slice(&ciphertext); + this.tx_nonce += 1; + this.tx_buf.reserve(5 + ciphertext.len()); + this.tx_buf.put_u8(0x17); + this.tx_buf.put_u16(0x0303); + this.tx_buf.put_u16(ciphertext.len() as u16); + this.tx_buf.put_slice(&ciphertext); - match tokio::io::AsyncWrite::poll_write(Pin::new(&mut self.inner), cx, &record) { - Poll::Ready(Ok(n)) if n == record.len() => { - self.tx_nonce += 1; - Poll::Ready(Ok(buf.len())) - } - Poll::Ready(Ok(_n)) => Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, "partial write not supported"))), - Poll::Ready(Err(e)) => Poll::Ready(Err(e)), - Poll::Pending => Poll::Pending, + match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) { + Poll::Ready(Ok(n)) => this.tx_buf.advance(n), + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => {} } + Poll::Ready(Ok(buf.len())) } Err(_) => Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, "reality encrypt failed"))), } } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_flush(cx) + let this = self.get_mut(); + while !this.tx_buf.is_empty() { + match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) { + Poll::Ready(Ok(n)) => this.tx_buf.advance(n), + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => return Poll::Pending, + } + } + Pin::new(&mut this.inner).poll_flush(cx) } fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_shutdown(cx) + let this = self.get_mut(); + while !this.tx_buf.is_empty() { + match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) { + Poll::Ready(Ok(n)) => this.tx_buf.advance(n), + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => return Poll::Pending, + } + } + Pin::new(&mut this.inner).poll_shutdown(cx) } }