mirror of https://github.com/ospab/ostp.git
fix: correctly handle payload buffering during http handshake in uot
This commit is contained in:
parent
1c98bf9a51
commit
a81625d721
|
|
@ -106,36 +106,56 @@ pub async fn connect_xhttp(
|
|||
tls_stream.write_all(req.as_bytes()).await?;
|
||||
tls_stream.flush().await?;
|
||||
|
||||
let mut buf = [0u8; 1024];
|
||||
let n = tls_stream.read(&mut buf).await?;
|
||||
let resp = String::from_utf8_lossy(&buf[..n]);
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let mut header_len = 0;
|
||||
loop {
|
||||
let n = tls_stream.read(&mut buf[header_len..]).await?;
|
||||
if n == 0 { anyhow::bail!("connection closed before handshake complete"); }
|
||||
header_len += n;
|
||||
if buf[..header_len].windows(4).any(|w| w == b"\r\n\r\n") { break; }
|
||||
}
|
||||
let resp = String::from_utf8_lossy(&buf[..header_len]);
|
||||
if !resp.contains("200 OK") {
|
||||
anyhow::bail!("xHTTP handshake failed: expected 200 OK, got: {}", resp.lines().next().unwrap_or(""));
|
||||
}
|
||||
|
||||
// Extract leftover payload if any
|
||||
let headers_end = buf[..header_len].windows(4).position(|w| w == b"\r\n\r\n").unwrap() + 4;
|
||||
let leftover = buf[headers_end..header_len].to_vec();
|
||||
|
||||
// Split stream
|
||||
let (rx, tx) = tokio::io::split(tls_stream);
|
||||
start_uot_loops(rx, tx)
|
||||
start_uot_loops(rx, tx, leftover)
|
||||
} else {
|
||||
let mut tcp_stream = tcp_stream;
|
||||
tcp_stream.write_all(req.as_bytes()).await?;
|
||||
tcp_stream.flush().await?;
|
||||
|
||||
let mut buf = [0u8; 1024];
|
||||
let n = tcp_stream.read(&mut buf).await?;
|
||||
let resp = String::from_utf8_lossy(&buf[..n]);
|
||||
let mut buf = vec![0u8; 4096];
|
||||
let mut header_len = 0;
|
||||
loop {
|
||||
let n = tcp_stream.read(&mut buf[header_len..]).await?;
|
||||
if n == 0 { anyhow::bail!("connection closed before handshake complete"); }
|
||||
header_len += n;
|
||||
if buf[..header_len].windows(4).any(|w| w == b"\r\n\r\n") { break; }
|
||||
}
|
||||
let resp = String::from_utf8_lossy(&buf[..header_len]);
|
||||
if !resp.contains("200 OK") {
|
||||
anyhow::bail!("xHTTP handshake failed: expected 200 OK, got: {}", resp.lines().next().unwrap_or(""));
|
||||
}
|
||||
|
||||
let headers_end = buf[..header_len].windows(4).position(|w| w == b"\r\n\r\n").unwrap() + 4;
|
||||
let leftover = buf[headers_end..header_len].to_vec();
|
||||
|
||||
let (rx, tx) = tcp_stream.into_split();
|
||||
start_uot_loops(rx, tx)
|
||||
start_uot_loops(rx, tx, leftover)
|
||||
}
|
||||
}
|
||||
|
||||
fn start_uot_loops<R, W>(
|
||||
mut net_rx: R,
|
||||
mut net_tx: W
|
||||
mut net_tx: W,
|
||||
leftover: Vec<u8>
|
||||
) -> Result<(mpsc::Sender<Bytes>, Arc<tokio::sync::Mutex<mpsc::Receiver<Bytes>>>)>
|
||||
where
|
||||
R: tokio::io::AsyncRead + Unpin + Send + 'static,
|
||||
|
|
@ -156,16 +176,28 @@ where
|
|||
|
||||
// RX Loop (Network -> UoT -> App)
|
||||
tokio::spawn(async move {
|
||||
let mut buffer = BytesMut::from(&leftover[..]);
|
||||
loop {
|
||||
let len = match net_rx.read_u16().await {
|
||||
Ok(l) => l,
|
||||
Err(_) => break,
|
||||
};
|
||||
let mut buf = vec![0u8; len as usize];
|
||||
if net_rx.read_exact(&mut buf).await.is_err() {
|
||||
break;
|
||||
// Read more data if buffer has less than 2 bytes
|
||||
while buffer.len() < 2 {
|
||||
let mut temp = [0u8; 1024];
|
||||
match net_rx.read(&mut temp).await {
|
||||
Ok(0) | Err(_) => return,
|
||||
Ok(n) => buffer.extend_from_slice(&temp[..n]),
|
||||
}
|
||||
}
|
||||
if app_tx.send(Bytes::from(buf)).await.is_err() {
|
||||
let len = u16::from_be_bytes([buffer[0], buffer[1]]) as usize;
|
||||
|
||||
while buffer.len() < 2 + len {
|
||||
let mut temp = [0u8; 1024];
|
||||
match net_rx.read(&mut temp).await {
|
||||
Ok(0) | Err(_) => return,
|
||||
Ok(n) => buffer.extend_from_slice(&temp[..n]),
|
||||
}
|
||||
}
|
||||
|
||||
let packet = buffer.split_to(2 + len);
|
||||
if app_tx.send(Bytes::from(packet[2..].to_vec())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -119,6 +119,9 @@ pub async fn handle_tcp_connection(
|
|||
tcp_map.write().await.insert(peer_addr, tx);
|
||||
}
|
||||
|
||||
let headers_end = buf[..header_len].windows(4).position(|w| w == b"\r\n\r\n").unwrap() + 4;
|
||||
let leftover = &buf[headers_end..header_len];
|
||||
|
||||
// Process streams
|
||||
let (mut read_half, mut write_half) = stream.into_split();
|
||||
|
||||
|
|
@ -139,17 +142,36 @@ pub async fn handle_tcp_connection(
|
|||
});
|
||||
|
||||
// Reader loop
|
||||
let mut len_buf = [0u8; 2];
|
||||
let mut buffer = BytesMut::from(leftover);
|
||||
loop {
|
||||
if read_half.read_exact(&mut len_buf).await.is_err() {
|
||||
break;
|
||||
while buffer.len() < 2 {
|
||||
let mut temp = [0u8; 1024];
|
||||
match read_half.read(&mut temp).await {
|
||||
Ok(0) | Err(_) => {
|
||||
writer_task.abort();
|
||||
tcp_map.write().await.remove(&peer_addr);
|
||||
return Ok(());
|
||||
}
|
||||
Ok(n) => buffer.extend_from_slice(&temp[..n]),
|
||||
}
|
||||
}
|
||||
let len = u16::from_be_bytes(len_buf) as usize;
|
||||
let mut packet_buf = vec![0u8; len];
|
||||
if read_half.read_exact(&mut packet_buf).await.is_err() {
|
||||
break;
|
||||
|
||||
let len = u16::from_be_bytes([buffer[0], buffer[1]]) as usize;
|
||||
|
||||
while buffer.len() < 2 + len {
|
||||
let mut temp = [0u8; 1024];
|
||||
match read_half.read(&mut temp).await {
|
||||
Ok(0) | Err(_) => {
|
||||
writer_task.abort();
|
||||
tcp_map.write().await.remove(&peer_addr);
|
||||
return Ok(());
|
||||
}
|
||||
Ok(n) => buffer.extend_from_slice(&temp[..n]),
|
||||
}
|
||||
}
|
||||
if udp_tx.send((Bytes::from(packet_buf), peer_addr)).await.is_err() {
|
||||
|
||||
let packet = buffer.split_to(2 + len);
|
||||
if udp_tx.send((Bytes::from(packet[2..].to_vec()), peer_addr)).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue