mirror of https://github.com/ospab/ostp.git
225 lines
8.3 KiB
Rust
225 lines
8.3 KiB
Rust
use anyhow::{anyhow, Result};
|
|
use std::sync::Arc;
|
|
use crate::config::{ClientConfig, InboundConfig};
|
|
use crate::tunnel::router::{Router, Session};
|
|
use crate::tunnel::outbounds::OutboundManager;
|
|
use tokio::net::TcpListener;
|
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
|
use tokio::sync::watch;
|
|
|
|
pub async fn run_socks_inbound(
|
|
_config: ClientConfig,
|
|
inbound_config: InboundConfig,
|
|
router: Arc<Router>,
|
|
outbound_manager: Arc<OutboundManager>,
|
|
mut shutdown: watch::Receiver<bool>,
|
|
) -> Result<()> {
|
|
let InboundConfig::LocalProxy { tag, protocol, listen, port } = inbound_config else {
|
|
return Err(anyhow!("Invalid config for LocalProxy inbound"));
|
|
};
|
|
|
|
let bind_addr = format!("{}:{}", listen, port);
|
|
tracing::info!("Starting {} proxy inbound on {} (tag: {})", protocol, bind_addr, tag);
|
|
|
|
let listener = TcpListener::bind(&bind_addr).await?;
|
|
|
|
loop {
|
|
tokio::select! {
|
|
_ = shutdown.changed() => {
|
|
tracing::info!("Local proxy inbound {} shutting down", tag);
|
|
break;
|
|
}
|
|
accept_res = listener.accept() => {
|
|
if let Ok((mut stream, client_addr)) = accept_res {
|
|
let rt = router.clone();
|
|
let om = outbound_manager.clone();
|
|
let proto = protocol.clone();
|
|
let inbound_tag = tag.clone();
|
|
|
|
tokio::spawn(async move {
|
|
if proto == "socks" {
|
|
if let Err(e) = handle_socks5_connection(&mut stream, &rt, &om, &inbound_tag, client_addr).await {
|
|
tracing::debug!("SOCKS5 handling error: {}", e);
|
|
}
|
|
} else if proto == "http" {
|
|
if let Err(e) = handle_http_connection(&mut stream, &rt, &om, &inbound_tag, client_addr).await {
|
|
tracing::debug!("HTTP proxy handling error: {}", e);
|
|
}
|
|
} else {
|
|
tracing::error!("Unknown local proxy protocol: {}", proto);
|
|
}
|
|
});
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn handle_socks5_connection(
|
|
stream: &mut tokio::net::TcpStream,
|
|
router: &Arc<Router>,
|
|
outbound_manager: &Arc<OutboundManager>,
|
|
inbound_tag: &str,
|
|
client_addr: std::net::SocketAddr,
|
|
) -> Result<()> {
|
|
let mut buf = [0u8; 256];
|
|
|
|
// Read version and method selection
|
|
stream.read_exact(&mut buf[0..2]).await?;
|
|
if buf[0] != 0x05 {
|
|
return Err(anyhow!("Unsupported SOCKS version: {}", buf[0]));
|
|
}
|
|
|
|
let num_methods = buf[1] as usize;
|
|
stream.read_exact(&mut buf[0..num_methods]).await?;
|
|
|
|
// Reply with NO AUTHENTICATION REQUIRED (0x00)
|
|
stream.write_all(&[0x05, 0x00]).await?;
|
|
|
|
// Read the actual request
|
|
stream.read_exact(&mut buf[0..4]).await?;
|
|
if buf[0] != 0x05 || buf[1] != 0x01 { // Only CONNECT is supported
|
|
return Err(anyhow!("Unsupported SOCKS command"));
|
|
}
|
|
|
|
let atyp = buf[3];
|
|
let (target_host, ip_addr) = match atyp {
|
|
0x01 => { // IPv4
|
|
stream.read_exact(&mut buf[0..4]).await?;
|
|
let ip = std::net::Ipv4Addr::new(buf[0], buf[1], buf[2], buf[3]);
|
|
(ip.to_string(), Some(std::net::IpAddr::V4(ip)))
|
|
}
|
|
0x03 => { // Domain
|
|
stream.read_exact(&mut buf[0..1]).await?;
|
|
let domain_len = buf[0] as usize;
|
|
stream.read_exact(&mut buf[0..domain_len]).await?;
|
|
let domain = String::from_utf8_lossy(&buf[0..domain_len]).to_string();
|
|
(domain, None)
|
|
}
|
|
0x04 => { // IPv6
|
|
stream.read_exact(&mut buf[0..16]).await?;
|
|
let mut ip_bytes = [0u8; 16];
|
|
ip_bytes.copy_from_slice(&buf[0..16]);
|
|
let ip = std::net::Ipv6Addr::from(ip_bytes);
|
|
(ip.to_string(), Some(std::net::IpAddr::V6(ip)))
|
|
}
|
|
_ => return Err(anyhow!("Unsupported SOCKS address type: {}", atyp)),
|
|
};
|
|
|
|
stream.read_exact(&mut buf[0..2]).await?;
|
|
let target_port = u16::from_be_bytes([buf[0], buf[1]]);
|
|
|
|
let process_name = crate::tunnel::process_lookup::get_process_name_from_port(client_addr.port());
|
|
|
|
let session = Session {
|
|
protocol: "tcp".to_string(),
|
|
inbound_tag: inbound_tag.to_string(),
|
|
source_ip: Some(client_addr.ip()),
|
|
destination_ip: ip_addr,
|
|
destination_port: target_port,
|
|
sni: if atyp == 0x03 { Some(target_host.clone()) } else { None },
|
|
process_name,
|
|
};
|
|
|
|
let outbound_tag = router.route(&session);
|
|
tracing::info!("SOCKS5 TCP {} -> {}:{} routed to {}", client_addr, target_host, target_port, outbound_tag);
|
|
|
|
match outbound_manager.dial_tcp(&outbound_tag, &target_host, target_port).await {
|
|
Ok(mut remote_stream) => {
|
|
// Reply success
|
|
stream.write_all(&[0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]).await?;
|
|
|
|
// Forward data
|
|
tokio::io::copy_bidirectional(stream, &mut remote_stream).await?;
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!("SOCKS5 TCP dial failed to {}: {}", outbound_tag, e);
|
|
// Reply host unreachable
|
|
let _ = stream.write_all(&[0x05, 0x04, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]).await;
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn handle_http_connection(
|
|
stream: &mut tokio::net::TcpStream,
|
|
router: &Arc<Router>,
|
|
outbound_manager: &Arc<OutboundManager>,
|
|
inbound_tag: &str,
|
|
client_addr: std::net::SocketAddr,
|
|
) -> Result<()> {
|
|
// Basic HTTP CONNECT implementation
|
|
let mut buf = [0u8; 4096];
|
|
let n = stream.read(&mut buf).await?;
|
|
if n == 0 { return Ok(()); }
|
|
|
|
let request = String::from_utf8_lossy(&buf[0..n]);
|
|
let mut lines = request.lines();
|
|
let first_line = lines.next().ok_or_else(|| anyhow!("Empty HTTP request"))?;
|
|
|
|
let parts: Vec<&str> = first_line.split_whitespace().collect();
|
|
if parts.len() < 3 {
|
|
return Err(anyhow!("Invalid HTTP request line"));
|
|
}
|
|
|
|
let method = parts[0];
|
|
let target = parts[1]; // host:port for CONNECT, http://host:port/... for GET
|
|
|
|
let (target_host, target_port) = if method == "CONNECT" {
|
|
let parts: Vec<&str> = target.split(':').collect();
|
|
let host = parts[0].to_string();
|
|
let port = parts.get(1).unwrap_or(&"443").parse::<u16>().unwrap_or(443);
|
|
(host, port)
|
|
} else {
|
|
// Rudimentary GET parsing, ideally use httparse
|
|
if target.starts_with("http://") {
|
|
let without_scheme = &target[7..];
|
|
let host_part = without_scheme.split('/').next().unwrap_or(without_scheme);
|
|
let parts: Vec<&str> = host_part.split(':').collect();
|
|
let host = parts[0].to_string();
|
|
let port = parts.get(1).unwrap_or(&"80").parse::<u16>().unwrap_or(80);
|
|
(host, port)
|
|
} else {
|
|
return Err(anyhow!("Unsupported HTTP method/target: {} {}", method, target));
|
|
}
|
|
};
|
|
|
|
let process_name = crate::tunnel::process_lookup::get_process_name_from_port(client_addr.port());
|
|
|
|
let session = Session {
|
|
protocol: "tcp".to_string(),
|
|
inbound_tag: inbound_tag.to_string(),
|
|
source_ip: Some(client_addr.ip()),
|
|
destination_ip: None, // Could parse if IP
|
|
destination_port: target_port,
|
|
sni: Some(target_host.clone()),
|
|
process_name,
|
|
};
|
|
|
|
let outbound_tag = router.route(&session);
|
|
tracing::info!("HTTP TCP {} -> {}:{} routed to {}", client_addr, target_host, target_port, outbound_tag);
|
|
|
|
match outbound_manager.dial_tcp(&outbound_tag, &target_host, target_port).await {
|
|
Ok(mut remote_stream) => {
|
|
if method == "CONNECT" {
|
|
stream.write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n").await?;
|
|
} else {
|
|
remote_stream.write_all(&buf[0..n]).await?;
|
|
}
|
|
|
|
tokio::io::copy_bidirectional(stream, &mut remote_stream).await?;
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!("HTTP TCP dial failed to {}: {}", outbound_tag, e);
|
|
if method == "CONNECT" {
|
|
let _ = stream.write_all(b"HTTP/1.1 502 Bad Gateway\r\n\r\n").await;
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|