// Mirror of https://github.com/ospab/ostp.git (172 lines, 5.3 KiB, Rust).
use anyhow::{Context, Result};
|
|
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
|
use hmac::{Hmac, Mac};
|
|
use sha2::Sha256;
|
|
use std::collections::HashMap;
|
|
use std::net::SocketAddr;
|
|
use std::sync::{Arc, RwLock as StdRwLock};
|
|
use std::time::{SystemTime, UNIX_EPOCH};
|
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
|
use tokio::net::TcpStream;
|
|
use tokio::sync::{mpsc, RwLock};
|
|
use tracing::{info, warn};
|
|
|
|
pub async fn handle_tcp_connection(
|
|
mut stream: TcpStream,
|
|
peer_addr: SocketAddr,
|
|
shared_keys: Arc<StdRwLock<HashMap<String, ()>>>,
|
|
udp_tx: mpsc::Sender<(Bytes, SocketAddr)>,
|
|
tcp_map: Arc<RwLock<HashMap<SocketAddr, mpsc::Sender<Bytes>>>>,
|
|
) -> Result<()> {
|
|
// 1. Read HTTP Handshake
|
|
let mut buf = [0u8; 4096];
|
|
let mut header_len = 0;
|
|
loop {
|
|
let n = stream.read(&mut buf[header_len..]).await?;
|
|
if n == 0 {
|
|
anyhow::bail!("connection closed before handshake complete");
|
|
}
|
|
header_len += n;
|
|
if buf[..header_len].windows(4).any(|w| w == b"\r\n\r\n") {
|
|
break;
|
|
}
|
|
if header_len == buf.len() {
|
|
anyhow::bail!("handshake headers too large");
|
|
}
|
|
}
|
|
|
|
let headers_str = String::from_utf8_lossy(&buf[..header_len]);
|
|
|
|
// Fast-fail scanner bots
|
|
if !headers_str.starts_with("GET /stream HTTP/1.1\r\n") {
|
|
send_404(&mut stream).await?;
|
|
anyhow::bail!("invalid request line");
|
|
}
|
|
|
|
// Extract Authorization or Cookie for signature
|
|
let mut signature_base64 = None;
|
|
for line in headers_str.lines() {
|
|
let lower = line.to_ascii_lowercase();
|
|
if lower.starts_with("authorization: bearer ") {
|
|
signature_base64 = Some(line[22..].trim().to_string());
|
|
} else if lower.starts_with("cookie: ostp_token=") {
|
|
signature_base64 = Some(line[19..].trim().to_string());
|
|
}
|
|
}
|
|
|
|
let sig_b64 = match signature_base64 {
|
|
Some(s) => s,
|
|
None => {
|
|
send_404(&mut stream).await?;
|
|
anyhow::bail!("missing authorization");
|
|
}
|
|
};
|
|
|
|
let sig_bytes = match base64::Engine::decode(&base64::engine::general_purpose::STANDARD_NO_PAD, &sig_b64) {
|
|
Ok(b) => b,
|
|
Err(_) => {
|
|
send_404(&mut stream).await?;
|
|
anyhow::bail!("invalid base64 signature");
|
|
}
|
|
};
|
|
|
|
if sig_bytes.len() < 8 {
|
|
send_404(&mut stream).await?;
|
|
anyhow::bail!("signature too short");
|
|
}
|
|
|
|
let ts_bytes: [u8; 8] = sig_bytes[0..8].try_into().unwrap();
|
|
let client_ts = u64::from_be_bytes(ts_bytes);
|
|
let provided_mac = &sig_bytes[8..];
|
|
|
|
let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs();
|
|
if client_ts > now + 30 || client_ts < now.saturating_sub(60) {
|
|
send_404(&mut stream).await?;
|
|
anyhow::bail!("timestamp out of bounds (replay protection)");
|
|
}
|
|
|
|
// Verify HMAC against known keys
|
|
let keys = {
|
|
let guard = shared_keys.read().unwrap();
|
|
guard.keys().cloned().collect::<Vec<_>>()
|
|
};
|
|
|
|
let mut authenticated = false;
|
|
for key in keys {
|
|
let mut mac = Hmac::<Sha256>::new_from_slice(key.as_bytes())
|
|
.unwrap_or_else(|_| Hmac::<Sha256>::new_from_slice(b"default").unwrap());
|
|
mac.update(&ts_bytes);
|
|
if mac.verify_slice(provided_mac).is_ok() {
|
|
authenticated = true;
|
|
break;
|
|
}
|
|
}
|
|
|
|
if !authenticated {
|
|
send_404(&mut stream).await?;
|
|
anyhow::bail!("unauthorized (invalid HMAC)");
|
|
}
|
|
|
|
// Reply 200 OK
|
|
let response = "HTTP/1.1 200 OK\r\nConnection: keep-alive\r\n\r\n";
|
|
stream.write_all(response.as_bytes()).await?;
|
|
|
|
info!("UoT client authenticated from {}", peer_addr);
|
|
|
|
// Register this connection in the map
|
|
let (tx, mut rx) = mpsc::channel::<Bytes>(1024);
|
|
{
|
|
tcp_map.write().await.insert(peer_addr, tx);
|
|
}
|
|
|
|
// Process streams
|
|
let (mut read_half, mut write_half) = stream.into_split();
|
|
|
|
// Spawn writer task
|
|
let peer_clone = peer_addr;
|
|
let tcp_map_clone = tcp_map.clone();
|
|
let writer_task = tokio::spawn(async move {
|
|
while let Some(packet) = rx.recv().await {
|
|
let mut out = BytesMut::with_capacity(2 + packet.len());
|
|
out.put_u16(packet.len() as u16);
|
|
out.put_slice(&packet);
|
|
if write_half.write_all(&out).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
// Cleanup on writer exit
|
|
tcp_map_clone.write().await.remove(&peer_clone);
|
|
});
|
|
|
|
// Reader loop
|
|
let mut len_buf = [0u8; 2];
|
|
loop {
|
|
if read_half.read_exact(&mut len_buf).await.is_err() {
|
|
break;
|
|
}
|
|
let len = u16::from_be_bytes(len_buf) as usize;
|
|
let mut packet_buf = vec![0u8; len];
|
|
if read_half.read_exact(&mut packet_buf).await.is_err() {
|
|
break;
|
|
}
|
|
if udp_tx.send((Bytes::from(packet_buf), peer_addr)).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
|
|
writer_task.abort();
|
|
tcp_map.write().await.remove(&peer_addr);
|
|
Ok(())
|
|
}
|
|
|
|
/// Answers the request with a fixed `404 Not Found` page and `Connection: close`.
///
/// Every rejected handshake in this file calls it, so a failed probe sees a
/// plain web-server error. The write is best-effort: a failure is ignored
/// because each caller drops the connection right afterwards. As a result,
/// this always returns `Ok(())`.
async fn send_404(stream: &mut TcpStream) -> Result<()> {
    const BODY: &str = "Not Found";

    let page = format!(
        "HTTP/1.1 404 Not Found\r\nContent-Type: text/plain\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
        BODY.len(),
        BODY
    );

    // Deliberately ignored: the peer may already be gone, and there is
    // nothing useful to do about it here.
    let _ = stream.write_all(page.as_bytes()).await;

    Ok(())
}
|