diff --git a/ostp-client/src/transport/xhttp.rs b/ostp-client/src/transport/xhttp.rs index 3455d8b..1b240ef 100644 --- a/ostp-client/src/transport/xhttp.rs +++ b/ostp-client/src/transport/xhttp.rs @@ -106,36 +106,56 @@ pub async fn connect_xhttp( tls_stream.write_all(req.as_bytes()).await?; tls_stream.flush().await?; - let mut buf = [0u8; 1024]; - let n = tls_stream.read(&mut buf).await?; - let resp = String::from_utf8_lossy(&buf[..n]); + let mut buf = vec![0u8; 4096]; + let mut header_len = 0; + loop { + let n = tls_stream.read(&mut buf[header_len..]).await?; + if n == 0 { anyhow::bail!("connection closed before handshake complete"); } + header_len += n; + if buf[..header_len].windows(4).any(|w| w == b"\r\n\r\n") { break; } + } + let resp = String::from_utf8_lossy(&buf[..header_len]); if !resp.contains("200 OK") { anyhow::bail!("xHTTP handshake failed: expected 200 OK, got: {}", resp.lines().next().unwrap_or("")); } + // Extract leftover payload if any + let headers_end = buf[..header_len].windows(4).position(|w| w == b"\r\n\r\n").unwrap() + 4; + let leftover = buf[headers_end..header_len].to_vec(); + // Split stream let (rx, tx) = tokio::io::split(tls_stream); - start_uot_loops(rx, tx) + start_uot_loops(rx, tx, leftover) } else { let mut tcp_stream = tcp_stream; tcp_stream.write_all(req.as_bytes()).await?; tcp_stream.flush().await?; - let mut buf = [0u8; 1024]; - let n = tcp_stream.read(&mut buf).await?; - let resp = String::from_utf8_lossy(&buf[..n]); + let mut buf = vec![0u8; 4096]; + let mut header_len = 0; + loop { + let n = tcp_stream.read(&mut buf[header_len..]).await?; + if n == 0 { anyhow::bail!("connection closed before handshake complete"); } + header_len += n; + if buf[..header_len].windows(4).any(|w| w == b"\r\n\r\n") { break; } + } + let resp = String::from_utf8_lossy(&buf[..header_len]); if !resp.contains("200 OK") { anyhow::bail!("xHTTP handshake failed: expected 200 OK, got: {}", resp.lines().next().unwrap_or("")); } + let headers_end = buf[..header_len].windows(4).position(|w| w == b"\r\n\r\n").unwrap() + 4; + let leftover = buf[headers_end..header_len].to_vec(); + let (rx, tx) = tcp_stream.into_split(); - start_uot_loops(rx, tx) + start_uot_loops(rx, tx, leftover) } } fn start_uot_loops( mut net_rx: R, - mut net_tx: W + mut net_tx: W, + leftover: Vec ) -> Result<(mpsc::Sender, Arc>>)> where R: tokio::io::AsyncRead + Unpin + Send + 'static, @@ -156,16 +176,28 @@ where // RX Loop (Network -> UoT -> App) tokio::spawn(async move { + let mut buffer = BytesMut::from(&leftover[..]); loop { - let len = match net_rx.read_u16().await { - Ok(l) => l, - Err(_) => break, - }; - let mut buf = vec![0u8; len as usize]; - if net_rx.read_exact(&mut buf).await.is_err() { - break; + // Read more data if buffer has less than 2 bytes + while buffer.len() < 2 { + let mut temp = [0u8; 1024]; + match net_rx.read(&mut temp).await { + Ok(0) | Err(_) => return, + Ok(n) => buffer.extend_from_slice(&temp[..n]), + } } - if app_tx.send(Bytes::from(buf)).await.is_err() { + let len = u16::from_be_bytes([buffer[0], buffer[1]]) as usize; + + while buffer.len() < 2 + len { + let mut temp = [0u8; 1024]; + match net_rx.read(&mut temp).await { + Ok(0) | Err(_) => return, + Ok(n) => buffer.extend_from_slice(&temp[..n]), + } + } + + let packet = buffer.split_to(2 + len); + if app_tx.send(Bytes::from(packet[2..].to_vec())).await.is_err() { break; } } diff --git a/ostp-server/src/transport/uot.rs b/ostp-server/src/transport/uot.rs index eae9105..5db49cf 100644 --- a/ostp-server/src/transport/uot.rs +++ b/ostp-server/src/transport/uot.rs @@ -119,6 +119,9 @@ pub async fn handle_tcp_connection( tcp_map.write().await.insert(peer_addr, tx); } + let headers_end = buf[..header_len].windows(4).position(|w| w == b"\r\n\r\n").unwrap() + 4; + let leftover = &buf[headers_end..header_len]; + // Process streams let (mut read_half, mut write_half) = stream.into_split(); @@ -139,17 +142,36 @@ pub async fn handle_tcp_connection( }); // Reader loop - let mut len_buf = [0u8; 2]; + let mut buffer = BytesMut::from(leftover); loop { - if read_half.read_exact(&mut len_buf).await.is_err() { - break; + while buffer.len() < 2 { + let mut temp = [0u8; 1024]; + match read_half.read(&mut temp).await { + Ok(0) | Err(_) => { + writer_task.abort(); + tcp_map.write().await.remove(&peer_addr); + return Ok(()); + } + Ok(n) => buffer.extend_from_slice(&temp[..n]), + } } - let len = u16::from_be_bytes(len_buf) as usize; - let mut packet_buf = vec![0u8; len]; - if read_half.read_exact(&mut packet_buf).await.is_err() { - break; + + let len = u16::from_be_bytes([buffer[0], buffer[1]]) as usize; + + while buffer.len() < 2 + len { + let mut temp = [0u8; 1024]; + match read_half.read(&mut temp).await { + Ok(0) | Err(_) => { + writer_task.abort(); + tcp_map.write().await.remove(&peer_addr); + return Ok(()); + } + Ok(n) => buffer.extend_from_slice(&temp[..n]), + } } - if udp_tx.send((Bytes::from(packet_buf), peer_addr)).await.is_err() { + + let packet = buffer.split_to(2 + len); + if udp_tx.send((Bytes::from(packet[2..].to_vec()), peer_addr)).await.is_err() { break; } }