mirror of https://github.com/ospab/ostp.git
fix(xhttp): rewrite RealityStream buffering to prevent packet drops and data loss
This commit is contained in:
parent
7257da174a
commit
902e762c91
|
|
@ -108,6 +108,8 @@ struct RealityStream {
|
||||||
rx_nonce: u64,
|
rx_nonce: u64,
|
||||||
tx_nonce: u64,
|
tx_nonce: u64,
|
||||||
rx_buf: BytesMut,
|
rx_buf: BytesMut,
|
||||||
|
plaintext_buf: BytesMut,
|
||||||
|
tx_buf: BytesMut,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RealityStream {
|
impl RealityStream {
|
||||||
|
|
@ -118,6 +120,8 @@ impl RealityStream {
|
||||||
rx_nonce: 0,
|
rx_nonce: 0,
|
||||||
tx_nonce: 0,
|
tx_nonce: 0,
|
||||||
rx_buf: BytesMut::with_capacity(16384),
|
rx_buf: BytesMut::with_capacity(16384),
|
||||||
|
plaintext_buf: BytesMut::new(),
|
||||||
|
tx_buf: BytesMut::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -131,11 +135,16 @@ impl RealityStream {
|
||||||
impl tokio::io::AsyncRead for RealityStream {
|
impl tokio::io::AsyncRead for RealityStream {
|
||||||
fn poll_read(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll<std::io::Result<()>> {
|
fn poll_read(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll<std::io::Result<()>> {
|
||||||
loop {
|
loop {
|
||||||
// Try to decode a full record
|
if !self.plaintext_buf.is_empty() {
|
||||||
|
let out_len = std::cmp::min(buf.remaining(), self.plaintext_buf.len());
|
||||||
|
buf.put_slice(&self.plaintext_buf[..out_len]);
|
||||||
|
self.plaintext_buf.advance(out_len);
|
||||||
|
return Poll::Ready(Ok(()));
|
||||||
|
}
|
||||||
|
|
||||||
if self.rx_buf.len() >= 5 {
|
if self.rx_buf.len() >= 5 {
|
||||||
let len = u16::from_be_bytes([self.rx_buf[3], self.rx_buf[4]]) as usize;
|
let len = u16::from_be_bytes([self.rx_buf[3], self.rx_buf[4]]) as usize;
|
||||||
if self.rx_buf.len() >= 5 + len {
|
if self.rx_buf.len() >= 5 + len {
|
||||||
// We have a full record
|
|
||||||
if self.rx_buf[0] != 0x17 {
|
if self.rx_buf[0] != 0x17 {
|
||||||
return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected application data record")));
|
return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "expected application data record")));
|
||||||
}
|
}
|
||||||
|
|
@ -147,17 +156,9 @@ impl tokio::io::AsyncRead for RealityStream {
|
||||||
match self.data_key.decrypt(nonce, ciphertext) {
|
match self.data_key.decrypt(nonce, ciphertext) {
|
||||||
Ok(plaintext) => {
|
Ok(plaintext) => {
|
||||||
self.rx_nonce += 1;
|
self.rx_nonce += 1;
|
||||||
let out_len = std::cmp::min(buf.remaining(), plaintext.len());
|
self.plaintext_buf.put_slice(&plaintext);
|
||||||
buf.put_slice(&plaintext[..out_len]);
|
|
||||||
|
|
||||||
if out_len < plaintext.len() {
|
|
||||||
// RealityStream doesn't buffer remaining plaintext if user buffer is too small.
|
|
||||||
// In xhttp_handshake_and_loop we always use 65535 byte buffers, so it fits.
|
|
||||||
// If needed, we'd add an internal plaintext_buffer.
|
|
||||||
}
|
|
||||||
|
|
||||||
self.rx_buf.advance(5 + len);
|
self.rx_buf.advance(5 + len);
|
||||||
return Poll::Ready(Ok(()));
|
continue;
|
||||||
}
|
}
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "reality decrypt failed")));
|
return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "reality decrypt failed")));
|
||||||
|
|
@ -166,13 +167,12 @@ impl tokio::io::AsyncRead for RealityStream {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Need more data
|
let mut read_buf = [0u8; 8192];
|
||||||
let mut read_buf = [0u8; 4096];
|
|
||||||
let mut tokio_buf = tokio::io::ReadBuf::new(&mut read_buf);
|
let mut tokio_buf = tokio::io::ReadBuf::new(&mut read_buf);
|
||||||
match Pin::new(&mut self.inner).poll_read(cx, &mut tokio_buf) {
|
match Pin::new(&mut self.inner).poll_read(cx, &mut tokio_buf) {
|
||||||
Poll::Ready(Ok(())) => {
|
Poll::Ready(Ok(())) => {
|
||||||
if tokio_buf.filled().is_empty() {
|
if tokio_buf.filled().is_empty() {
|
||||||
return Poll::Ready(Ok(())); // EOF
|
return Poll::Ready(Ok(()));
|
||||||
}
|
}
|
||||||
self.rx_buf.put_slice(tokio_buf.filled());
|
self.rx_buf.put_slice(tokio_buf.filled());
|
||||||
}
|
}
|
||||||
|
|
@ -185,42 +185,60 @@ impl tokio::io::AsyncRead for RealityStream {
|
||||||
|
|
||||||
impl tokio::io::AsyncWrite for RealityStream {
|
impl tokio::io::AsyncWrite for RealityStream {
|
||||||
fn poll_write(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
|
fn poll_write(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
|
||||||
let nonce_bytes = Self::make_nonce(self.tx_nonce);
|
let this = self.get_mut();
|
||||||
|
while !this.tx_buf.is_empty() {
|
||||||
|
match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) {
|
||||||
|
Poll::Ready(Ok(n)) => this.tx_buf.advance(n),
|
||||||
|
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
||||||
|
Poll::Pending => return Poll::Pending,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let nonce_bytes = Self::make_nonce(this.tx_nonce);
|
||||||
let nonce = Nonce::from_slice(&nonce_bytes);
|
let nonce = Nonce::from_slice(&nonce_bytes);
|
||||||
|
|
||||||
// Encrypt the entire buf as a single record
|
match this.data_key.encrypt(nonce, buf) {
|
||||||
match self.data_key.encrypt(nonce, buf) {
|
|
||||||
Ok(ciphertext) => {
|
Ok(ciphertext) => {
|
||||||
let mut record = BytesMut::with_capacity(5 + ciphertext.len());
|
this.tx_nonce += 1;
|
||||||
record.put_u8(0x17); // Application Data
|
this.tx_buf.reserve(5 + ciphertext.len());
|
||||||
record.put_u16(0x0303); // TLS 1.2/1.3
|
this.tx_buf.put_u8(0x17);
|
||||||
record.put_u16(ciphertext.len() as u16);
|
this.tx_buf.put_u16(0x0303);
|
||||||
record.put_slice(&ciphertext);
|
this.tx_buf.put_u16(ciphertext.len() as u16);
|
||||||
|
this.tx_buf.put_slice(&ciphertext);
|
||||||
|
|
||||||
// Write the full record to the inner stream
|
match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) {
|
||||||
match tokio::io::AsyncWrite::poll_write(Pin::new(&mut self.inner), cx, &record) {
|
Poll::Ready(Ok(n)) => this.tx_buf.advance(n),
|
||||||
Poll::Ready(Ok(n)) if n == record.len() => {
|
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
||||||
self.tx_nonce += 1;
|
Poll::Pending => {}
|
||||||
|
}
|
||||||
Poll::Ready(Ok(buf.len()))
|
Poll::Ready(Ok(buf.len()))
|
||||||
}
|
}
|
||||||
Poll::Ready(Ok(_n)) => {
|
|
||||||
// Partial writes of a single TLS record are not supported by this simple wrapper
|
|
||||||
Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, "partial write not supported")))
|
|
||||||
}
|
|
||||||
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
|
|
||||||
Poll::Pending => Poll::Pending,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(_) => Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, "reality encrypt failed"))),
|
Err(_) => Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, "reality encrypt failed"))),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<std::io::Result<()>> {
|
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<std::io::Result<()>> {
|
||||||
Pin::new(&mut self.inner).poll_flush(cx)
|
let this = self.get_mut();
|
||||||
|
while !this.tx_buf.is_empty() {
|
||||||
|
match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) {
|
||||||
|
Poll::Ready(Ok(n)) => this.tx_buf.advance(n),
|
||||||
|
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
||||||
|
Poll::Pending => return Poll::Pending,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Pin::new(&mut this.inner).poll_flush(cx)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<std::io::Result<()>> {
|
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<std::io::Result<()>> {
|
||||||
Pin::new(&mut self.inner).poll_shutdown(cx)
|
let this = self.get_mut();
|
||||||
|
while !this.tx_buf.is_empty() {
|
||||||
|
match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) {
|
||||||
|
Poll::Ready(Ok(n)) => this.tx_buf.advance(n),
|
||||||
|
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
||||||
|
Poll::Pending => return Poll::Pending,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Pin::new(&mut this.inner).poll_shutdown(cx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -433,6 +433,8 @@ struct RealityStream<S> {
|
||||||
rx_nonce: u64,
|
rx_nonce: u64,
|
||||||
tx_nonce: u64,
|
tx_nonce: u64,
|
||||||
rx_buf: BytesMut,
|
rx_buf: BytesMut,
|
||||||
|
plaintext_buf: BytesMut,
|
||||||
|
tx_buf: BytesMut,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S> RealityStream<S> {
|
impl<S> RealityStream<S> {
|
||||||
|
|
@ -443,6 +445,8 @@ impl<S> RealityStream<S> {
|
||||||
rx_nonce: 0,
|
rx_nonce: 0,
|
||||||
tx_nonce: 0,
|
tx_nonce: 0,
|
||||||
rx_buf: BytesMut::with_capacity(16384),
|
rx_buf: BytesMut::with_capacity(16384),
|
||||||
|
plaintext_buf: BytesMut::new(),
|
||||||
|
tx_buf: BytesMut::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -456,6 +460,13 @@ impl<S> RealityStream<S> {
|
||||||
impl<S: tokio::io::AsyncRead + Unpin> tokio::io::AsyncRead for RealityStream<S> {
|
impl<S: tokio::io::AsyncRead + Unpin> tokio::io::AsyncRead for RealityStream<S> {
|
||||||
fn poll_read(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll<std::io::Result<()>> {
|
fn poll_read(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll<std::io::Result<()>> {
|
||||||
loop {
|
loop {
|
||||||
|
if !self.plaintext_buf.is_empty() {
|
||||||
|
let out_len = std::cmp::min(buf.remaining(), self.plaintext_buf.len());
|
||||||
|
buf.put_slice(&self.plaintext_buf[..out_len]);
|
||||||
|
self.plaintext_buf.advance(out_len);
|
||||||
|
return Poll::Ready(Ok(()));
|
||||||
|
}
|
||||||
|
|
||||||
if self.rx_buf.len() >= 5 {
|
if self.rx_buf.len() >= 5 {
|
||||||
let len = u16::from_be_bytes([self.rx_buf[3], self.rx_buf[4]]) as usize;
|
let len = u16::from_be_bytes([self.rx_buf[3], self.rx_buf[4]]) as usize;
|
||||||
if self.rx_buf.len() >= 5 + len {
|
if self.rx_buf.len() >= 5 + len {
|
||||||
|
|
@ -470,17 +481,16 @@ impl<S: tokio::io::AsyncRead + Unpin> tokio::io::AsyncRead for RealityStream<S>
|
||||||
match self.data_key.decrypt(nonce, ciphertext) {
|
match self.data_key.decrypt(nonce, ciphertext) {
|
||||||
Ok(plaintext) => {
|
Ok(plaintext) => {
|
||||||
self.rx_nonce += 1;
|
self.rx_nonce += 1;
|
||||||
let out_len = std::cmp::min::<usize>(buf.remaining(), plaintext.len());
|
self.plaintext_buf.put_slice(&plaintext);
|
||||||
buf.put_slice(&plaintext[..out_len]);
|
|
||||||
self.rx_buf.advance(5 + len);
|
self.rx_buf.advance(5 + len);
|
||||||
return Poll::Ready(Ok(()));
|
continue;
|
||||||
}
|
}
|
||||||
Err(_) => return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "reality decrypt failed"))),
|
Err(_) => return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "reality decrypt failed"))),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut read_buf = [0u8; 4096];
|
let mut read_buf = [0u8; 8192];
|
||||||
let mut tokio_buf = tokio::io::ReadBuf::new(&mut read_buf);
|
let mut tokio_buf = tokio::io::ReadBuf::new(&mut read_buf);
|
||||||
match Pin::new(&mut self.inner).poll_read(cx, &mut tokio_buf) {
|
match Pin::new(&mut self.inner).poll_read(cx, &mut tokio_buf) {
|
||||||
Poll::Ready(Ok(())) => {
|
Poll::Ready(Ok(())) => {
|
||||||
|
|
@ -496,36 +506,59 @@ impl<S: tokio::io::AsyncRead + Unpin> tokio::io::AsyncRead for RealityStream<S>
|
||||||
|
|
||||||
impl<S: tokio::io::AsyncWrite + Unpin> tokio::io::AsyncWrite for RealityStream<S> {
|
impl<S: tokio::io::AsyncWrite + Unpin> tokio::io::AsyncWrite for RealityStream<S> {
|
||||||
fn poll_write(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
|
fn poll_write(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
|
||||||
let nonce_bytes = Self::make_nonce(self.tx_nonce);
|
let this = self.get_mut();
|
||||||
|
while !this.tx_buf.is_empty() {
|
||||||
|
match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) {
|
||||||
|
Poll::Ready(Ok(n)) => this.tx_buf.advance(n),
|
||||||
|
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
||||||
|
Poll::Pending => return Poll::Pending,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let nonce_bytes = Self::make_nonce(this.tx_nonce);
|
||||||
let nonce = Nonce::from_slice(&nonce_bytes);
|
let nonce = Nonce::from_slice(&nonce_bytes);
|
||||||
|
|
||||||
match self.data_key.encrypt(nonce, buf) {
|
match this.data_key.encrypt(nonce, buf) {
|
||||||
Ok(ciphertext) => {
|
Ok(ciphertext) => {
|
||||||
let mut record: BytesMut = BytesMut::with_capacity(5 + ciphertext.len());
|
this.tx_nonce += 1;
|
||||||
record.put_u8(0x17);
|
this.tx_buf.reserve(5 + ciphertext.len());
|
||||||
record.put_u16(0x0303);
|
this.tx_buf.put_u8(0x17);
|
||||||
record.put_u16(ciphertext.len() as u16);
|
this.tx_buf.put_u16(0x0303);
|
||||||
record.put_slice(&ciphertext);
|
this.tx_buf.put_u16(ciphertext.len() as u16);
|
||||||
|
this.tx_buf.put_slice(&ciphertext);
|
||||||
|
|
||||||
match tokio::io::AsyncWrite::poll_write(Pin::new(&mut self.inner), cx, &record) {
|
match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) {
|
||||||
Poll::Ready(Ok(n)) if n == record.len() => {
|
Poll::Ready(Ok(n)) => this.tx_buf.advance(n),
|
||||||
self.tx_nonce += 1;
|
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
||||||
|
Poll::Pending => {}
|
||||||
|
}
|
||||||
Poll::Ready(Ok(buf.len()))
|
Poll::Ready(Ok(buf.len()))
|
||||||
}
|
}
|
||||||
Poll::Ready(Ok(_n)) => Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, "partial write not supported"))),
|
|
||||||
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
|
|
||||||
Poll::Pending => Poll::Pending,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(_) => Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, "reality encrypt failed"))),
|
Err(_) => Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, "reality encrypt failed"))),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<std::io::Result<()>> {
|
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<std::io::Result<()>> {
|
||||||
Pin::new(&mut self.inner).poll_flush(cx)
|
let this = self.get_mut();
|
||||||
|
while !this.tx_buf.is_empty() {
|
||||||
|
match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) {
|
||||||
|
Poll::Ready(Ok(n)) => this.tx_buf.advance(n),
|
||||||
|
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
||||||
|
Poll::Pending => return Poll::Pending,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Pin::new(&mut this.inner).poll_flush(cx)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<std::io::Result<()>> {
|
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<std::io::Result<()>> {
|
||||||
Pin::new(&mut self.inner).poll_shutdown(cx)
|
let this = self.get_mut();
|
||||||
|
while !this.tx_buf.is_empty() {
|
||||||
|
match Pin::new(&mut this.inner).poll_write(cx, &this.tx_buf) {
|
||||||
|
Poll::Ready(Ok(n)) => this.tx_buf.advance(n),
|
||||||
|
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
||||||
|
Poll::Pending => return Poll::Pending,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Pin::new(&mut this.inner).poll_shutdown(cx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue