fix: correctly handle payload buffering during http handshake in uot

This commit is contained in:
ospab 2026-05-21 12:43:47 +03:00
parent 1c98bf9a51
commit a81625d721
2 changed files with 79 additions and 25 deletions

View File

@ -106,36 +106,56 @@ pub async fn connect_xhttp(
tls_stream.write_all(req.as_bytes()).await?; tls_stream.write_all(req.as_bytes()).await?;
tls_stream.flush().await?; tls_stream.flush().await?;
let mut buf = [0u8; 1024]; let mut buf = vec![0u8; 4096];
let n = tls_stream.read(&mut buf).await?; let mut header_len = 0;
let resp = String::from_utf8_lossy(&buf[..n]); loop {
let n = tls_stream.read(&mut buf[header_len..]).await?;
if n == 0 { anyhow::bail!("connection closed before handshake complete"); }
header_len += n;
if buf[..header_len].windows(4).any(|w| w == b"\r\n\r\n") { break; }
}
let resp = String::from_utf8_lossy(&buf[..header_len]);
if !resp.contains("200 OK") { if !resp.contains("200 OK") {
anyhow::bail!("xHTTP handshake failed: expected 200 OK, got: {}", resp.lines().next().unwrap_or("")); anyhow::bail!("xHTTP handshake failed: expected 200 OK, got: {}", resp.lines().next().unwrap_or(""));
} }
// Extract leftover payload if any
let headers_end = buf[..header_len].windows(4).position(|w| w == b"\r\n\r\n").unwrap() + 4;
let leftover = buf[headers_end..header_len].to_vec();
// Split stream // Split stream
let (rx, tx) = tokio::io::split(tls_stream); let (rx, tx) = tokio::io::split(tls_stream);
start_uot_loops(rx, tx) start_uot_loops(rx, tx, leftover)
} else { } else {
let mut tcp_stream = tcp_stream; let mut tcp_stream = tcp_stream;
tcp_stream.write_all(req.as_bytes()).await?; tcp_stream.write_all(req.as_bytes()).await?;
tcp_stream.flush().await?; tcp_stream.flush().await?;
let mut buf = [0u8; 1024]; let mut buf = vec![0u8; 4096];
let n = tcp_stream.read(&mut buf).await?; let mut header_len = 0;
let resp = String::from_utf8_lossy(&buf[..n]); loop {
let n = tcp_stream.read(&mut buf[header_len..]).await?;
if n == 0 { anyhow::bail!("connection closed before handshake complete"); }
header_len += n;
if buf[..header_len].windows(4).any(|w| w == b"\r\n\r\n") { break; }
}
let resp = String::from_utf8_lossy(&buf[..header_len]);
if !resp.contains("200 OK") { if !resp.contains("200 OK") {
anyhow::bail!("xHTTP handshake failed: expected 200 OK, got: {}", resp.lines().next().unwrap_or("")); anyhow::bail!("xHTTP handshake failed: expected 200 OK, got: {}", resp.lines().next().unwrap_or(""));
} }
let headers_end = buf[..header_len].windows(4).position(|w| w == b"\r\n\r\n").unwrap() + 4;
let leftover = buf[headers_end..header_len].to_vec();
let (rx, tx) = tcp_stream.into_split(); let (rx, tx) = tcp_stream.into_split();
start_uot_loops(rx, tx) start_uot_loops(rx, tx, leftover)
} }
} }
fn start_uot_loops<R, W>( fn start_uot_loops<R, W>(
mut net_rx: R, mut net_rx: R,
mut net_tx: W mut net_tx: W,
leftover: Vec<u8>
) -> Result<(mpsc::Sender<Bytes>, Arc<tokio::sync::Mutex<mpsc::Receiver<Bytes>>>)> ) -> Result<(mpsc::Sender<Bytes>, Arc<tokio::sync::Mutex<mpsc::Receiver<Bytes>>>)>
where where
R: tokio::io::AsyncRead + Unpin + Send + 'static, R: tokio::io::AsyncRead + Unpin + Send + 'static,
@ -156,16 +176,28 @@ where
// RX Loop (Network -> UoT -> App) // RX Loop (Network -> UoT -> App)
tokio::spawn(async move { tokio::spawn(async move {
let mut buffer = BytesMut::from(&leftover[..]);
loop { loop {
let len = match net_rx.read_u16().await { // Read more data if buffer has less than 2 bytes
Ok(l) => l, while buffer.len() < 2 {
Err(_) => break, let mut temp = [0u8; 1024];
}; match net_rx.read(&mut temp).await {
let mut buf = vec![0u8; len as usize]; Ok(0) | Err(_) => return,
if net_rx.read_exact(&mut buf).await.is_err() { Ok(n) => buffer.extend_from_slice(&temp[..n]),
break; }
} }
if app_tx.send(Bytes::from(buf)).await.is_err() { let len = u16::from_be_bytes([buffer[0], buffer[1]]) as usize;
while buffer.len() < 2 + len {
let mut temp = [0u8; 1024];
match net_rx.read(&mut temp).await {
Ok(0) | Err(_) => return,
Ok(n) => buffer.extend_from_slice(&temp[..n]),
}
}
let packet = buffer.split_to(2 + len);
if app_tx.send(Bytes::from(packet[2..].to_vec())).await.is_err() {
break; break;
} }
} }

View File

@ -119,6 +119,9 @@ pub async fn handle_tcp_connection(
tcp_map.write().await.insert(peer_addr, tx); tcp_map.write().await.insert(peer_addr, tx);
} }
let headers_end = buf[..header_len].windows(4).position(|w| w == b"\r\n\r\n").unwrap() + 4;
let leftover = &buf[headers_end..header_len];
// Process streams // Process streams
let (mut read_half, mut write_half) = stream.into_split(); let (mut read_half, mut write_half) = stream.into_split();
@ -139,17 +142,36 @@ pub async fn handle_tcp_connection(
}); });
// Reader loop // Reader loop
let mut len_buf = [0u8; 2]; let mut buffer = BytesMut::from(leftover);
loop { loop {
if read_half.read_exact(&mut len_buf).await.is_err() { while buffer.len() < 2 {
break; let mut temp = [0u8; 1024];
match read_half.read(&mut temp).await {
Ok(0) | Err(_) => {
writer_task.abort();
tcp_map.write().await.remove(&peer_addr);
return Ok(());
}
Ok(n) => buffer.extend_from_slice(&temp[..n]),
}
} }
let len = u16::from_be_bytes(len_buf) as usize;
let mut packet_buf = vec![0u8; len]; let len = u16::from_be_bytes([buffer[0], buffer[1]]) as usize;
if read_half.read_exact(&mut packet_buf).await.is_err() {
break; while buffer.len() < 2 + len {
let mut temp = [0u8; 1024];
match read_half.read(&mut temp).await {
Ok(0) | Err(_) => {
writer_task.abort();
tcp_map.write().await.remove(&peer_addr);
return Ok(());
}
Ok(n) => buffer.extend_from_slice(&temp[..n]),
}
} }
if udp_tx.send((Bytes::from(packet_buf), peer_addr)).await.is_err() {
let packet = buffer.split_to(2 + len);
if udp_tx.send((Bytes::from(packet[2..].to_vec()), peer_addr)).await.is_err() {
break; break;
} }
} }