mirror of https://github.com/ospab/ostp.git
fix: correctly handle payload buffering during http handshake in uot
This commit is contained in:
parent
1c98bf9a51
commit
a81625d721
|
|
@ -106,36 +106,56 @@ pub async fn connect_xhttp(
|
||||||
tls_stream.write_all(req.as_bytes()).await?;
|
tls_stream.write_all(req.as_bytes()).await?;
|
||||||
tls_stream.flush().await?;
|
tls_stream.flush().await?;
|
||||||
|
|
||||||
let mut buf = [0u8; 1024];
|
let mut buf = vec![0u8; 4096];
|
||||||
let n = tls_stream.read(&mut buf).await?;
|
let mut header_len = 0;
|
||||||
let resp = String::from_utf8_lossy(&buf[..n]);
|
loop {
|
||||||
|
let n = tls_stream.read(&mut buf[header_len..]).await?;
|
||||||
|
if n == 0 { anyhow::bail!("connection closed before handshake complete"); }
|
||||||
|
header_len += n;
|
||||||
|
if buf[..header_len].windows(4).any(|w| w == b"\r\n\r\n") { break; }
|
||||||
|
}
|
||||||
|
let resp = String::from_utf8_lossy(&buf[..header_len]);
|
||||||
if !resp.contains("200 OK") {
|
if !resp.contains("200 OK") {
|
||||||
anyhow::bail!("xHTTP handshake failed: expected 200 OK, got: {}", resp.lines().next().unwrap_or(""));
|
anyhow::bail!("xHTTP handshake failed: expected 200 OK, got: {}", resp.lines().next().unwrap_or(""));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extract leftover payload if any
|
||||||
|
let headers_end = buf[..header_len].windows(4).position(|w| w == b"\r\n\r\n").unwrap() + 4;
|
||||||
|
let leftover = buf[headers_end..header_len].to_vec();
|
||||||
|
|
||||||
// Split stream
|
// Split stream
|
||||||
let (rx, tx) = tokio::io::split(tls_stream);
|
let (rx, tx) = tokio::io::split(tls_stream);
|
||||||
start_uot_loops(rx, tx)
|
start_uot_loops(rx, tx, leftover)
|
||||||
} else {
|
} else {
|
||||||
let mut tcp_stream = tcp_stream;
|
let mut tcp_stream = tcp_stream;
|
||||||
tcp_stream.write_all(req.as_bytes()).await?;
|
tcp_stream.write_all(req.as_bytes()).await?;
|
||||||
tcp_stream.flush().await?;
|
tcp_stream.flush().await?;
|
||||||
|
|
||||||
let mut buf = [0u8; 1024];
|
let mut buf = vec![0u8; 4096];
|
||||||
let n = tcp_stream.read(&mut buf).await?;
|
let mut header_len = 0;
|
||||||
let resp = String::from_utf8_lossy(&buf[..n]);
|
loop {
|
||||||
|
let n = tcp_stream.read(&mut buf[header_len..]).await?;
|
||||||
|
if n == 0 { anyhow::bail!("connection closed before handshake complete"); }
|
||||||
|
header_len += n;
|
||||||
|
if buf[..header_len].windows(4).any(|w| w == b"\r\n\r\n") { break; }
|
||||||
|
}
|
||||||
|
let resp = String::from_utf8_lossy(&buf[..header_len]);
|
||||||
if !resp.contains("200 OK") {
|
if !resp.contains("200 OK") {
|
||||||
anyhow::bail!("xHTTP handshake failed: expected 200 OK, got: {}", resp.lines().next().unwrap_or(""));
|
anyhow::bail!("xHTTP handshake failed: expected 200 OK, got: {}", resp.lines().next().unwrap_or(""));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let headers_end = buf[..header_len].windows(4).position(|w| w == b"\r\n\r\n").unwrap() + 4;
|
||||||
|
let leftover = buf[headers_end..header_len].to_vec();
|
||||||
|
|
||||||
let (rx, tx) = tcp_stream.into_split();
|
let (rx, tx) = tcp_stream.into_split();
|
||||||
start_uot_loops(rx, tx)
|
start_uot_loops(rx, tx, leftover)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn start_uot_loops<R, W>(
|
fn start_uot_loops<R, W>(
|
||||||
mut net_rx: R,
|
mut net_rx: R,
|
||||||
mut net_tx: W
|
mut net_tx: W,
|
||||||
|
leftover: Vec<u8>
|
||||||
) -> Result<(mpsc::Sender<Bytes>, Arc<tokio::sync::Mutex<mpsc::Receiver<Bytes>>>)>
|
) -> Result<(mpsc::Sender<Bytes>, Arc<tokio::sync::Mutex<mpsc::Receiver<Bytes>>>)>
|
||||||
where
|
where
|
||||||
R: tokio::io::AsyncRead + Unpin + Send + 'static,
|
R: tokio::io::AsyncRead + Unpin + Send + 'static,
|
||||||
|
|
@ -156,16 +176,28 @@ where
|
||||||
|
|
||||||
// RX Loop (Network -> UoT -> App)
|
// RX Loop (Network -> UoT -> App)
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
|
let mut buffer = BytesMut::from(&leftover[..]);
|
||||||
loop {
|
loop {
|
||||||
let len = match net_rx.read_u16().await {
|
// Read more data if buffer has less than 2 bytes
|
||||||
Ok(l) => l,
|
while buffer.len() < 2 {
|
||||||
Err(_) => break,
|
let mut temp = [0u8; 1024];
|
||||||
};
|
match net_rx.read(&mut temp).await {
|
||||||
let mut buf = vec![0u8; len as usize];
|
Ok(0) | Err(_) => return,
|
||||||
if net_rx.read_exact(&mut buf).await.is_err() {
|
Ok(n) => buffer.extend_from_slice(&temp[..n]),
|
||||||
break;
|
}
|
||||||
}
|
}
|
||||||
if app_tx.send(Bytes::from(buf)).await.is_err() {
|
let len = u16::from_be_bytes([buffer[0], buffer[1]]) as usize;
|
||||||
|
|
||||||
|
while buffer.len() < 2 + len {
|
||||||
|
let mut temp = [0u8; 1024];
|
||||||
|
match net_rx.read(&mut temp).await {
|
||||||
|
Ok(0) | Err(_) => return,
|
||||||
|
Ok(n) => buffer.extend_from_slice(&temp[..n]),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let packet = buffer.split_to(2 + len);
|
||||||
|
if app_tx.send(Bytes::from(packet[2..].to_vec())).await.is_err() {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -119,6 +119,9 @@ pub async fn handle_tcp_connection(
|
||||||
tcp_map.write().await.insert(peer_addr, tx);
|
tcp_map.write().await.insert(peer_addr, tx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let headers_end = buf[..header_len].windows(4).position(|w| w == b"\r\n\r\n").unwrap() + 4;
|
||||||
|
let leftover = &buf[headers_end..header_len];
|
||||||
|
|
||||||
// Process streams
|
// Process streams
|
||||||
let (mut read_half, mut write_half) = stream.into_split();
|
let (mut read_half, mut write_half) = stream.into_split();
|
||||||
|
|
||||||
|
|
@ -139,17 +142,36 @@ pub async fn handle_tcp_connection(
|
||||||
});
|
});
|
||||||
|
|
||||||
// Reader loop
|
// Reader loop
|
||||||
let mut len_buf = [0u8; 2];
|
let mut buffer = BytesMut::from(leftover);
|
||||||
loop {
|
loop {
|
||||||
if read_half.read_exact(&mut len_buf).await.is_err() {
|
while buffer.len() < 2 {
|
||||||
break;
|
let mut temp = [0u8; 1024];
|
||||||
|
match read_half.read(&mut temp).await {
|
||||||
|
Ok(0) | Err(_) => {
|
||||||
|
writer_task.abort();
|
||||||
|
tcp_map.write().await.remove(&peer_addr);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
Ok(n) => buffer.extend_from_slice(&temp[..n]),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
let len = u16::from_be_bytes(len_buf) as usize;
|
|
||||||
let mut packet_buf = vec![0u8; len];
|
let len = u16::from_be_bytes([buffer[0], buffer[1]]) as usize;
|
||||||
if read_half.read_exact(&mut packet_buf).await.is_err() {
|
|
||||||
break;
|
while buffer.len() < 2 + len {
|
||||||
|
let mut temp = [0u8; 1024];
|
||||||
|
match read_half.read(&mut temp).await {
|
||||||
|
Ok(0) | Err(_) => {
|
||||||
|
writer_task.abort();
|
||||||
|
tcp_map.write().await.remove(&peer_addr);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
Ok(n) => buffer.extend_from_slice(&temp[..n]),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if udp_tx.send((Bytes::from(packet_buf), peer_addr)).await.is_err() {
|
|
||||||
|
let packet = buffer.split_to(2 + len);
|
||||||
|
if udp_tx.send((Bytes::from(packet[2..].to_vec()), peer_addr)).await.is_err() {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue