fix: correctly handle payload buffering during http handshake in uot

This commit is contained in:
ospab 2026-05-21 12:43:47 +03:00
parent 1c98bf9a51
commit a81625d721
2 changed files with 79 additions and 25 deletions

View File

@ -106,36 +106,56 @@ pub async fn connect_xhttp(
tls_stream.write_all(req.as_bytes()).await?;
tls_stream.flush().await?;
let mut buf = [0u8; 1024];
let n = tls_stream.read(&mut buf).await?;
let resp = String::from_utf8_lossy(&buf[..n]);
let mut buf = vec![0u8; 4096];
let mut header_len = 0;
loop {
let n = tls_stream.read(&mut buf[header_len..]).await?;
if n == 0 { anyhow::bail!("connection closed before handshake complete"); }
header_len += n;
if buf[..header_len].windows(4).any(|w| w == b"\r\n\r\n") { break; }
}
let resp = String::from_utf8_lossy(&buf[..header_len]);
if !resp.contains("200 OK") {
anyhow::bail!("xHTTP handshake failed: expected 200 OK, got: {}", resp.lines().next().unwrap_or(""));
}
// Extract leftover payload if any
let headers_end = buf[..header_len].windows(4).position(|w| w == b"\r\n\r\n").unwrap() + 4;
let leftover = buf[headers_end..header_len].to_vec();
// Split stream
let (rx, tx) = tokio::io::split(tls_stream);
start_uot_loops(rx, tx)
start_uot_loops(rx, tx, leftover)
} else {
let mut tcp_stream = tcp_stream;
tcp_stream.write_all(req.as_bytes()).await?;
tcp_stream.flush().await?;
let mut buf = [0u8; 1024];
let n = tcp_stream.read(&mut buf).await?;
let resp = String::from_utf8_lossy(&buf[..n]);
let mut buf = vec![0u8; 4096];
let mut header_len = 0;
loop {
let n = tcp_stream.read(&mut buf[header_len..]).await?;
if n == 0 { anyhow::bail!("connection closed before handshake complete"); }
header_len += n;
if buf[..header_len].windows(4).any(|w| w == b"\r\n\r\n") { break; }
}
let resp = String::from_utf8_lossy(&buf[..header_len]);
if !resp.contains("200 OK") {
anyhow::bail!("xHTTP handshake failed: expected 200 OK, got: {}", resp.lines().next().unwrap_or(""));
}
let headers_end = buf[..header_len].windows(4).position(|w| w == b"\r\n\r\n").unwrap() + 4;
let leftover = buf[headers_end..header_len].to_vec();
let (rx, tx) = tcp_stream.into_split();
start_uot_loops(rx, tx)
start_uot_loops(rx, tx, leftover)
}
}
fn start_uot_loops<R, W>(
mut net_rx: R,
mut net_tx: W
mut net_tx: W,
leftover: Vec<u8>
) -> Result<(mpsc::Sender<Bytes>, Arc<tokio::sync::Mutex<mpsc::Receiver<Bytes>>>)>
where
R: tokio::io::AsyncRead + Unpin + Send + 'static,
@ -156,16 +176,28 @@ where
// RX Loop (Network -> UoT -> App)
tokio::spawn(async move {
let mut buffer = BytesMut::from(&leftover[..]);
loop {
let len = match net_rx.read_u16().await {
Ok(l) => l,
Err(_) => break,
};
let mut buf = vec![0u8; len as usize];
if net_rx.read_exact(&mut buf).await.is_err() {
break;
// Read more data if buffer has less than 2 bytes
while buffer.len() < 2 {
let mut temp = [0u8; 1024];
match net_rx.read(&mut temp).await {
Ok(0) | Err(_) => return,
Ok(n) => buffer.extend_from_slice(&temp[..n]),
}
}
if app_tx.send(Bytes::from(buf)).await.is_err() {
let len = u16::from_be_bytes([buffer[0], buffer[1]]) as usize;
while buffer.len() < 2 + len {
let mut temp = [0u8; 1024];
match net_rx.read(&mut temp).await {
Ok(0) | Err(_) => return,
Ok(n) => buffer.extend_from_slice(&temp[..n]),
}
}
let packet = buffer.split_to(2 + len);
if app_tx.send(Bytes::from(packet[2..].to_vec())).await.is_err() {
break;
}
}

View File

@ -119,6 +119,9 @@ pub async fn handle_tcp_connection(
tcp_map.write().await.insert(peer_addr, tx);
}
let headers_end = buf[..header_len].windows(4).position(|w| w == b"\r\n\r\n").unwrap() + 4;
let leftover = &buf[headers_end..header_len];
// Process streams
let (mut read_half, mut write_half) = stream.into_split();
@ -139,17 +142,36 @@ pub async fn handle_tcp_connection(
});
// Reader loop
let mut len_buf = [0u8; 2];
let mut buffer = BytesMut::from(leftover);
loop {
if read_half.read_exact(&mut len_buf).await.is_err() {
break;
while buffer.len() < 2 {
let mut temp = [0u8; 1024];
match read_half.read(&mut temp).await {
Ok(0) | Err(_) => {
writer_task.abort();
tcp_map.write().await.remove(&peer_addr);
return Ok(());
}
Ok(n) => buffer.extend_from_slice(&temp[..n]),
}
}
let len = u16::from_be_bytes(len_buf) as usize;
let mut packet_buf = vec![0u8; len];
if read_half.read_exact(&mut packet_buf).await.is_err() {
break;
let len = u16::from_be_bytes([buffer[0], buffer[1]]) as usize;
while buffer.len() < 2 + len {
let mut temp = [0u8; 1024];
match read_half.read(&mut temp).await {
Ok(0) | Err(_) => {
writer_task.abort();
tcp_map.write().await.remove(&peer_addr);
return Ok(());
}
Ok(n) => buffer.extend_from_slice(&temp[..n]),
}
}
if udp_tx.send((Bytes::from(packet_buf), peer_addr)).await.is_err() {
let packet = buffer.split_to(2 + len);
if udp_tx.send((Bytes::from(packet[2..].to_vec()), peer_addr)).await.is_err() {
break;
}
}