mirror of https://github.com/ospab/ostp.git
fix(xhttp): rewrite RealityStream buffering to prevent packet drops and data loss
This commit is contained in:
parent
7257da174a
commit
902e762c91
|
|
@ -108,6 +108,8 @@ struct RealityStream {
|
|||
rx_nonce: u64,
|
||||
tx_nonce: u64,
|
||||
rx_buf: BytesMut,
|
||||
plaintext_buf: BytesMut,
|
||||
tx_buf: BytesMut,
|
||||
}
|
||||
|
||||
impl RealityStream {
|
||||
|
|
@ -118,6 +120,8 @@ impl RealityStream {
|
|||
rx_nonce: 0,
|
||||
tx_nonce: 0,
|
||||
rx_buf: BytesMut::with_capacity(16384),
|
||||
plaintext_buf: BytesMut::new(),
|
||||
tx_buf: BytesMut::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -131,11 +135,16 @@ impl RealityStream {
|
|||
impl tokio::io::AsyncRead for RealityStream {
|
||||
fn poll_read(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll<std::io::Result<()>> {
|
||||
loop {
|
||||
// Try to decode a full record
|
||||
if !self.plaintext_buf.is_empty() {
|
||||
let out_len = std::cmp::min(buf.remaining(), self.plaintext_buf.len());
|
||||
buf.put_slice(&self.plaintext_buf[..out_len]);
|
||||
self.plaintext_buf.advance(out_len);
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
if self.rx_buf.len() >= 5 {
|
||||
let len = u16::from_be_bytes([self.rx_buf[3], self.rx_buf[4]]) as usize;
|
||||
if self.rx_buf.len() >= 5 + len {
|
||||
// We have a full record
|
||||
if self.rx_buf[0] != 0x17 {
|
||||
return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected application data record")));
|
||||
}
|
||||
|
|
@ -147,17 +156,9 @@ impl tokio::io::AsyncRead for RealityStream {
|
|||
match self.data_key.decrypt(nonce, ciphertext) {
|
||||
Ok(plaintext) => {
|
||||
self.rx_nonce += 1;
|
||||
let out_len = std::cmp::min(buf.remaining(), plaintext.len());
|
||||
buf.put_slice(&plaintext[..out_len]);
|
||||
|
||||
if out_len < plaintext.len() {
|
||||
// RealityStream doesn't buffer remaining plaintext if user buffer is too small.
|
||||
// In xhttp_handshake_and_loop we always use 65535 byte buffers, so it fits.
|
||||
// If needed, we'd add an internal plaintext_buffer.
|
||||
}
|
||||
|
||||
self.plaintext_buf.put_slice(&plaintext);
|
||||
self.rx_buf.advance(5 + len);
|
||||
return Poll::Ready(Ok(()));
|
||||
continue;
|
||||
}
|
||||
Err(_) => {
|
||||
return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "reality decrypt failed")));
|
||||
|
|
@ -166,13 +167,12 @@ impl tokio::io::AsyncRead for RealityStream {
|
|||
}
|
||||
}
|
||||
|
||||
// Need more data
|
||||
let mut read_buf = [0u8; 4096];
|
||||
let mut read_buf = [0u8; 8192];
|
||||
let mut tokio_buf = tokio::io::ReadBuf::new(&mut read_buf);
|
||||
match Pin::new(&mut self.inner).poll_read(cx, &mut tokio_buf) {
|
||||
Poll::Ready(Ok(())) => {
|
||||
if tokio_buf.filled().is_empty() {
|
||||
return Poll::Ready(Ok(())); // EOF
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
self.rx_buf.put_slice(tokio_buf.filled());
|
||||
}
|
||||
|
|
@ -185,42 +185,60 @@ impl tokio::io::AsyncRead for RealityStream {
|
|||
|
||||
impl tokio::io::AsyncWrite for RealityStream {
|
||||
fn poll_write(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
|
||||
let nonce_bytes = Self::make_nonce(self.tx_nonce);
|
||||
let this = self.get_mut();
|
||||
while !this.tx_buf.is_empty() {
|
||||
match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) {
|
||||
Poll::Ready(Ok(n)) => this.tx_buf.advance(n),
|
||||
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
}
|
||||
|
||||
let nonce_bytes = Self::make_nonce(this.tx_nonce);
|
||||
let nonce = Nonce::from_slice(&nonce_bytes);
|
||||
|
||||
// Encrypt the entire buf as a single record
|
||||
match self.data_key.encrypt(nonce, buf) {
|
||||
match this.data_key.encrypt(nonce, buf) {
|
||||
Ok(ciphertext) => {
|
||||
let mut record = BytesMut::with_capacity(5 + ciphertext.len());
|
||||
record.put_u8(0x17); // Application Data
|
||||
record.put_u16(0x0303); // TLS 1.2/1.3
|
||||
record.put_u16(ciphertext.len() as u16);
|
||||
record.put_slice(&ciphertext);
|
||||
this.tx_nonce += 1;
|
||||
this.tx_buf.reserve(5 + ciphertext.len());
|
||||
this.tx_buf.put_u8(0x17);
|
||||
this.tx_buf.put_u16(0x0303);
|
||||
this.tx_buf.put_u16(ciphertext.len() as u16);
|
||||
this.tx_buf.put_slice(&ciphertext);
|
||||
|
||||
// Write the full record to the inner stream
|
||||
match tokio::io::AsyncWrite::poll_write(Pin::new(&mut self.inner), cx, &record) {
|
||||
Poll::Ready(Ok(n)) if n == record.len() => {
|
||||
self.tx_nonce += 1;
|
||||
match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) {
|
||||
Poll::Ready(Ok(n)) => this.tx_buf.advance(n),
|
||||
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
||||
Poll::Pending => {}
|
||||
}
|
||||
Poll::Ready(Ok(buf.len()))
|
||||
}
|
||||
Poll::Ready(Ok(_n)) => {
|
||||
// Partial writes of a single TLS record are not supported by this simple wrapper
|
||||
Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, "partial write not supported")))
|
||||
}
|
||||
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
Err(_) => Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, "reality encrypt failed"))),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<std::io::Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_flush(cx)
|
||||
let this = self.get_mut();
|
||||
while !this.tx_buf.is_empty() {
|
||||
match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) {
|
||||
Poll::Ready(Ok(n)) => this.tx_buf.advance(n),
|
||||
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
}
|
||||
Pin::new(&mut this.inner).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<std::io::Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_shutdown(cx)
|
||||
let this = self.get_mut();
|
||||
while !this.tx_buf.is_empty() {
|
||||
match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) {
|
||||
Poll::Ready(Ok(n)) => this.tx_buf.advance(n),
|
||||
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
}
|
||||
Pin::new(&mut this.inner).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -433,6 +433,8 @@ struct RealityStream<S> {
|
|||
rx_nonce: u64,
|
||||
tx_nonce: u64,
|
||||
rx_buf: BytesMut,
|
||||
plaintext_buf: BytesMut,
|
||||
tx_buf: BytesMut,
|
||||
}
|
||||
|
||||
impl<S> RealityStream<S> {
|
||||
|
|
@ -443,6 +445,8 @@ impl<S> RealityStream<S> {
|
|||
rx_nonce: 0,
|
||||
tx_nonce: 0,
|
||||
rx_buf: BytesMut::with_capacity(16384),
|
||||
plaintext_buf: BytesMut::new(),
|
||||
tx_buf: BytesMut::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -456,6 +460,13 @@ impl<S> RealityStream<S> {
|
|||
impl<S: tokio::io::AsyncRead + Unpin> tokio::io::AsyncRead for RealityStream<S> {
|
||||
fn poll_read(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll<std::io::Result<()>> {
|
||||
loop {
|
||||
if !self.plaintext_buf.is_empty() {
|
||||
let out_len = std::cmp::min(buf.remaining(), self.plaintext_buf.len());
|
||||
buf.put_slice(&self.plaintext_buf[..out_len]);
|
||||
self.plaintext_buf.advance(out_len);
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
if self.rx_buf.len() >= 5 {
|
||||
let len = u16::from_be_bytes([self.rx_buf[3], self.rx_buf[4]]) as usize;
|
||||
if self.rx_buf.len() >= 5 + len {
|
||||
|
|
@ -470,17 +481,16 @@ impl<S: tokio::io::AsyncRead + Unpin> tokio::io::AsyncRead for RealityStream<S>
|
|||
match self.data_key.decrypt(nonce, ciphertext) {
|
||||
Ok(plaintext) => {
|
||||
self.rx_nonce += 1;
|
||||
let out_len = std::cmp::min::<usize>(buf.remaining(), plaintext.len());
|
||||
buf.put_slice(&plaintext[..out_len]);
|
||||
self.plaintext_buf.put_slice(&plaintext);
|
||||
self.rx_buf.advance(5 + len);
|
||||
return Poll::Ready(Ok(()));
|
||||
continue;
|
||||
}
|
||||
Err(_) => return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "reality decrypt failed"))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut read_buf = [0u8; 4096];
|
||||
let mut read_buf = [0u8; 8192];
|
||||
let mut tokio_buf = tokio::io::ReadBuf::new(&mut read_buf);
|
||||
match Pin::new(&mut self.inner).poll_read(cx, &mut tokio_buf) {
|
||||
Poll::Ready(Ok(())) => {
|
||||
|
|
@ -496,36 +506,59 @@ impl<S: tokio::io::AsyncRead + Unpin> tokio::io::AsyncRead for RealityStream<S>
|
|||
|
||||
impl<S: tokio::io::AsyncWrite + Unpin> tokio::io::AsyncWrite for RealityStream<S> {
|
||||
fn poll_write(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
|
||||
let nonce_bytes = Self::make_nonce(self.tx_nonce);
|
||||
let this = self.get_mut();
|
||||
while !this.tx_buf.is_empty() {
|
||||
match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) {
|
||||
Poll::Ready(Ok(n)) => this.tx_buf.advance(n),
|
||||
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
}
|
||||
|
||||
let nonce_bytes = Self::make_nonce(this.tx_nonce);
|
||||
let nonce = Nonce::from_slice(&nonce_bytes);
|
||||
|
||||
match self.data_key.encrypt(nonce, buf) {
|
||||
match this.data_key.encrypt(nonce, buf) {
|
||||
Ok(ciphertext) => {
|
||||
let mut record: BytesMut = BytesMut::with_capacity(5 + ciphertext.len());
|
||||
record.put_u8(0x17);
|
||||
record.put_u16(0x0303);
|
||||
record.put_u16(ciphertext.len() as u16);
|
||||
record.put_slice(&ciphertext);
|
||||
this.tx_nonce += 1;
|
||||
this.tx_buf.reserve(5 + ciphertext.len());
|
||||
this.tx_buf.put_u8(0x17);
|
||||
this.tx_buf.put_u16(0x0303);
|
||||
this.tx_buf.put_u16(ciphertext.len() as u16);
|
||||
this.tx_buf.put_slice(&ciphertext);
|
||||
|
||||
match tokio::io::AsyncWrite::poll_write(Pin::new(&mut self.inner), cx, &record) {
|
||||
Poll::Ready(Ok(n)) if n == record.len() => {
|
||||
self.tx_nonce += 1;
|
||||
match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) {
|
||||
Poll::Ready(Ok(n)) => this.tx_buf.advance(n),
|
||||
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
||||
Poll::Pending => {}
|
||||
}
|
||||
Poll::Ready(Ok(buf.len()))
|
||||
}
|
||||
Poll::Ready(Ok(_n)) => Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, "partial write not supported"))),
|
||||
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
Err(_) => Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, "reality encrypt failed"))),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<std::io::Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_flush(cx)
|
||||
let this = self.get_mut();
|
||||
while !this.tx_buf.is_empty() {
|
||||
match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) {
|
||||
Poll::Ready(Ok(n)) => this.tx_buf.advance(n),
|
||||
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
}
|
||||
Pin::new(&mut this.inner).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<std::io::Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_shutdown(cx)
|
||||
let this = self.get_mut();
|
||||
while !this.tx_buf.is_empty() {
|
||||
match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) {
|
||||
Poll::Ready(Ok(n)) => this.tx_buf.advance(n),
|
||||
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
}
|
||||
Pin::new(&mut this.inner).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue