From 05583e189e4a31a7243b47ae2d15739cd0bc5c3f Mon Sep 17 00:00:00 2001 From: ospab Date: Sun, 17 May 2026 21:05:44 +0300 Subject: [PATCH] =?UTF-8?q?feat:=20v0.2.0=20=E2=80=94=20BBR=20congestion?= =?UTF-8?q?=20control,=200-RTT=20session=20resumption,=20management=20REST?= =?UTF-8?q?=20API,=20fallback=20server,=20multi-listener?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Architecture: - BBR-inspired congestion controller (SlowStart/ProbeBandwidth/ProbeRTT phases) - 0-RTT session resumption with anti-replay ticket validation - Management REST API (axum): /api/users CRUD, /api/server/status, Bearer auth - TCP fallback proxy for anti-DPI camouflage (nginx/caddy passthrough) - Multi-listener: bind to multiple UDP addresses simultaneously - Per-user traffic stats with atomic counters and limit enforcement Code quality: - Structured logging: 0 eprintln in server/core/client, all tracing::{info,debug,warn,error} - 35 unit tests across congestion, resumption, relay, outbound, obfuscation - Removed dead code: kex.rs, unused dependencies (async-trait, x25519-dalek, rand_distr) - Modular server: api.rs, fallback.rs, outbound.rs, relay.rs extracted from monolithic lib.rs CLI: - --check: config validation - --generate-key: secure key generation (hex/base64, batch) - --links: share link generation from server config - --init: fallback section in server template Documentation: - README rewritten with architecture diagram, API examples, CLI reference - Wiki: Management-API (EN+RU), Configuration (EN+RU), Home (EN+RU) updated --- .gitignore | Bin 393 -> 393 bytes Cargo.lock | 378 +++++++++++++++-- Cargo.toml | 5 +- README.md | 220 +++++----- ostp-client/src/bridge.rs | 400 +----------------- ostp-client/src/lib.rs | 1 + ostp-client/src/signal.rs | 10 +- ostp-client/src/sysproxy.rs | 14 +- ostp-client/src/tunnel/linux_handler.rs | 2 +- ostp-client/src/tunnel/proxy.rs | 20 +- ostp-client/src/tunnel/wintun_handler.rs | 26 +- ostp-client/src/turn.rs | 397 ++++++++++++++++++ ostp-core/Cargo.toml | 2 - ostp-core/src/congestion.rs | 341 ++++++++++++++++ ostp-core/src/crypto/kex.rs | 66 --- ostp-core/src/crypto/mod.rs | 2 - ostp-core/src/lib.rs | 2 + ostp-core/src/protocol.rs | 55 ++- ostp-core/src/relay.rs | 80 ++++ ostp-core/src/resumption.rs | 307 ++++++++++++++ ostp-server/Cargo.toml | 2 + ostp-server/src/api.rs | 378 +++++++++++++++++ ostp-server/src/dispatcher.rs | 181 +++++++- ostp-server/src/fallback.rs | 87 ++++ ostp-server/src/lib.rs | 500 +++++------------------ ostp-server/src/outbound.rs | 323 +++++++++++++++ ostp-server/src/relay.rs | 114 ++++++ ostp/Cargo.toml | 2 + ostp/src/main.rs | 138 ++++++- 29 files changed, 2981 insertions(+), 1072 deletions(-) create mode 100644 ostp-client/src/turn.rs create mode 100644 ostp-core/src/congestion.rs delete mode 100644 ostp-core/src/crypto/kex.rs create mode 100644 ostp-core/src/resumption.rs create mode 100644 ostp-server/src/api.rs create mode 100644 ostp-server/src/fallback.rs create mode 100644 ostp-server/src/outbound.rs create mode 100644 ostp-server/src/relay.rs diff --git a/.gitignore b/.gitignore index 30795336258a3e0be0c6fc28d51013f38a3f4862..e963be2cf9cfa25208fa5193bb5aa7d8e8d6c392 100644 GIT binary patch delta 12 TcmeBV?quH3%E-dSz{LOn6^8<6 delta 12 TcmeBV?quH3%E-dYz{LOn6_WyL diff --git a/Cargo.lock b/Cargo.lock index 94e956b..4783bcb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -37,6 +37,15 @@ dependencies = [ "subtle", ] +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + [[package]] name = "android_system_properties" version = "0.1.5" @@ -103,15 +112,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" [[package]] -name = "async-trait" -version = "0.1.89" +name = "atomic-waker" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" [[package]] name = "autocfg" @@ -119,12 +123,70 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "axum" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31b698c5f9a010f6573133b09e0de5408834d0c82f8d7475a89fc1867a71cd90" +dependencies = [ + "axum-core", + "bytes", + "form_urlencoded", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "serde_core", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08c78f31d7b1291f7ee735c1c6780ccde7785daae9a9206026862dab7d8792d1" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "base64" version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "bitflags" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4512299f36f043ab09a583e57bceb5a5aab7a73db1805848e8fef3c9e8c78b3" + [[package]] name = "blake2" version = "0.10.6" @@ -395,6 +457,15 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures-channel" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" +dependencies = [ + "futures-core", +] + [[package]] name = "futures-core" version = "0.3.32" @@ -465,6 +536,86 @@ dependencies = [ "digest", ] +[[package]] +name = "http" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +dependencies = [ + "bytes", + "itoa", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + +[[package]] +name = "hyper" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6299f016b246a94207e63da54dbe807655bf9e00044f73ded42c3ac5305fbcca" +dependencies = [ + "atomic-waker", + "bytes", + "futures-channel", + "futures-core", + "http", + "http-body", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", +] + +[[package]] +name = "hyper-util" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" +dependencies = [ + "bytes", + "http", + "http-body", + "hyper", + "pin-project-lite", + "tokio", + "tower-service", +] + [[package]] name = "iana-time-zone" version = "0.1.65" @@ -699,12 +850,33 @@ version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +[[package]] +name = "matchers" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" +dependencies = [ + "regex-automata", +] + +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "memchr" version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + [[package]] name = "mio" version = "1.2.0" @@ -716,6 +888,15 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "nu-ansi-term" +version = "0.50.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" +dependencies = [ + "windows-sys 0.61.2", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -757,6 +938,8 @@ dependencies = [ "serde", "serde_json", "tokio", + "tracing", + "tracing-subscriber", "url", ] @@ -783,7 +966,6 @@ name = "ostp-core" version = "0.1.70" dependencies = [ "anyhow", - "async-trait", "bytes", "chacha20poly1305", "hmac", @@ -792,7 +974,6 @@ dependencies = [ "snow", "thiserror", "tracing", - "x25519-dalek", ] [[package]] @@ -816,6 +997,7 @@ name = "ostp-server" version = "0.1.70" dependencies = [ "anyhow", + "axum", "bytes", "ostp-core", "rand", @@ -823,6 +1005,7 @@ dependencies = [ "serde_json", "socket2", "tokio", + "tower-http", "tracing", ] @@ -947,6 +1130,23 @@ dependencies = [ "getrandom", ] +[[package]] +name = "regex-automata" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" + [[package]] name = "rustc_version" version = "0.4.1" @@ -962,6 +1162,12 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" +[[package]] +name = "ryu" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" + [[package]] name = "same-file" version = "1.0.6" @@ -1020,6 +1226,29 @@ dependencies = [ "zmij", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" +dependencies = [ + "itoa", + "serde", + "serde_core", +] + +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + [[package]] name = "sha2" version = "0.10.9" @@ -1031,6 +1260,15 @@ dependencies = [ "digest", ] +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + [[package]] name = "shlex" version = "1.3.0" @@ -1114,6 +1352,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" + [[package]] name = "synstructure" version = "0.13.2" @@ -1145,6 +1389,15 @@ dependencies = [ "syn", ] +[[package]] +name = "thread_local" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" +dependencies = [ + "cfg-if", +] + [[package]] name = "tinystr" version = "0.8.3" @@ -1191,12 +1444,55 @@ dependencies = [ "serde", ] +[[package]] +name = "tower" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper", + "tokio", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-http" +version = "0.6.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68d6fdd9f81c2819c9a8b0e0cd91660e7746a8e6ea2ba7c6b2b057985f6bcb51" +dependencies = [ + "bitflags", + "bytes", + "http", + "pin-project-lite", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + +[[package]] +name = "tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + [[package]] name = "tracing" version = "0.1.44" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" dependencies = [ + "log", "pin-project-lite", "tracing-attributes", "tracing-core", @@ -1220,6 +1516,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" dependencies = [ "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex-automata", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", ] [[package]] @@ -1268,6 +1594,12 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + [[package]] name = "version_check" version = "0.9.5" @@ -1493,18 +1825,6 @@ version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ffae5123b2d3fc086436f8834ae3ab053a283cfac8fe0a0b8eaae044768a4c4" -[[package]] -name = "x25519-dalek" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7e468321c81fb07fa7f4c636c3972b9100f0346e5b6a9f2bd0603a52f7ed277" -dependencies = [ - "curve25519-dalek", - "rand_core", - "serde", - "zeroize", -] - [[package]] name = "yoke" version = "0.8.2" @@ -1574,20 +1894,6 @@ name = "zeroize" version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" -dependencies = [ - "zeroize_derive", -] - -[[package]] -name = "zeroize_derive" -version = "1.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85a5b4158499876c763cb03bc4e49185d3cccbabb15b33c627f7884f43db852e" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] [[package]] name = "zerotrie" diff --git a/Cargo.toml b/Cargo.toml index 1bf30ee..3975f48 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,20 +12,17 @@ resolver = "2" [workspace.package] edition = "2021" license = "BSL 1.1" -version = "0.1.70" +version = "0.2.0" [workspace.dependencies] anyhow = "1.0" -async-trait = "0.1" bytes = "1.6" chacha20poly1305 = "0.10" rand = "0.8" -rand_distr = "0.4" snow = "0.9" thiserror = "1.0" tokio = { version = "1.37", features = ["rt-multi-thread", "macros", "net", "time", "io-util", "sync", "signal"] } tracing = "0.1" -x25519-dalek = "2" sha2 = "0.10" hmac = "0.12" portable-atomic = "1.10" diff --git a/README.md b/README.md index ceaf95c..33254a9 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,31 @@ # OSTP — Ospab Stealth Transport Protocol -[Русский язык](README.ru.md) +[Русский язык](README.ru.md) · [Wiki](https://github.com/ospab/ostp/wiki) · [Releases](https://github.com/ospab/ostp/releases) ![GitHub Release](https://img.shields.io/github/v/release/ospab/ostp?style=flat-square&color=blue) ![License: BSL 1.1](https://img.shields.io/badge/License-BSL%201.1-orange.svg?style=flat-square) ![Platform: Windows | Linux | macOS | Android](https://img.shields.io/badge/Platform-Windows%20%7C%20Linux%20%7C%20macOS%20%7C%20Android-green.svg?style=flat-square) +![Crypto](https://img.shields.io/badge/Crypto-Noise__NNpsk0-blueviolet?style=flat-square) +![Transport](https://img.shields.io/badge/Transport-UDP%20ARQ-informational?style=flat-square) -OSTP is a high-performance, censorship-resistant transport protocol designed to tunnel TCP traffic over UDP with full traffic obfuscation. It is resistant to Deep Packet Inspection (DPI), active probing, and statistical traffic analysis. +**OSTP** is a high-performance, censorship-resistant transport protocol designed to tunnel TCP traffic over UDP with full traffic obfuscation. Every byte on the wire — including packet headers — is cryptographically indistinguishable from random noise. Resistant to Deep Packet Inspection (DPI), active probing, and statistical traffic analysis. + +--- + +## Quick Install + +### Linux +```bash +bash <(curl -Ls https://raw.githubusercontent.com/ospab/ostp/master/scripts/install.sh) +``` + +### Windows (PowerShell, run as Administrator) +```powershell +irm https://raw.githubusercontent.com/ospab/ostp/master/scripts/install.ps1 | iex +``` + +### Manual Download +Download pre-built binaries for your platform from [GitHub Releases](https://github.com/ospab/ostp/releases). --- @@ -14,15 +33,19 @@ OSTP is a high-performance, censorship-resistant transport protocol designed to | Feature | Description | |---------|-------------| -| **Traffic Obfuscation** | Every packet — including headers — is indistinguishable from random noise on the wire. Session IDs and nonces are masked with per-packet HMAC-derived keys. | -| **Noise Protocol Handshake** | `Noise_NNpsk0_25519_ChaChaPoly_BLAKE2s` — pre-shared key authenticated, forward-secret key exchange with no static identity exposure. | -| **Reliable UDP (ARQ)** | Selective ACK/NACK with rate-limited retransmission, configurable reorder buffer, and exponential backoff. Designed for 10 Gbps throughput. | -| **Multiplexed Streams** | Multiple logical TCP streams over a single encrypted UDP session, with per-stream flow control. | -| **Seamless Roaming** | Clients can switch networks (WiFi ↔ 4G) without session interruption — the server tracks session-ID, not IP address. | -| **TUN Mode** | Full-system VPN via `tun2socks` integration on Windows and Linux. All traffic is transparently routed through the tunnel. | +| **Full Traffic Obfuscation** | Every packet — including headers — is indistinguishable from random noise. Session IDs and nonces are masked with per-packet HMAC-derived keys. | +| **Noise Protocol Handshake** | `Noise_NNpsk0_25519_ChaChaPoly_BLAKE2s` — PSK-authenticated, forward-secret key exchange with no static identity exposure. | +| **Reliable UDP (ARQ)** | Selective ACK/NACK with rate-limited retransmission, configurable reorder buffer, and exponential backoff. | +| **Multiplexed Streams** | Multiple logical TCP streams over a single encrypted UDP session with per-stream flow control. | +| **Seamless Roaming** | Clients can switch networks (WiFi ↔ LTE) without session interruption — tracked by session-ID, not IP. | +| **Management API** | Built-in REST API for third-party panels (3x-ui, custom dashboards). Per-user stats, traffic limits, key CRUD. | +| **Fallback Server** | TCP fallback proxy to a web server — makes OSTP indistinguishable from nginx during active probing. | +| **Multi-Listener** | Bind to multiple addresses simultaneously (dual-stack IPv4/IPv6, multi-port). | +| **TUN Mode** | Full-system VPN via `tun2socks` integration. All traffic transparently routed through the tunnel. | | **TURN Relay** | RFC 5766 TURN support for environments where direct UDP is blocked. | -| **Hot-Reload** | Runtime config reload without restarting the process (access keys, exclusions, mux settings, TURN). | -| **Cross-Platform** | Windows, Linux, macOS, Android. Single binary, no runtime dependencies. | +| **Hot-Reload** | Runtime config reload without restart (access keys, exclusions, mux settings). | +| **Structured Logging** | `tracing`-based logging with `RUST_LOG` filtering. JSON/file/syslog output support. | +| **Cross-Platform** | Windows, Linux, macOS, Android, FreeBSD, MIPS, RISC-V. Single binary, no runtime dependencies. | --- @@ -48,149 +71,156 @@ OSTP is a high-performance, censorship-resistant transport protocol designed to │ Server │ │ │ ┌─────────────────────────────────────────┴───────────┐ │ │ │ Dispatcher │ │ -│ │ (Session lookup, roaming detection, replay guard) │ │ -│ └──────────────┬──────────────────────────────────────┘ │ -│ │ │ -│ ┌──────────────▾──────────────────┐ │ -│ │ Relay Loop (per-stream TCP) │──▸ Internet / Backend │ -│ └─────────────────────────────────┘ │ +│ │ (Session lookup, roaming, replay guard, per-user │ │ +│ │ traffic accounting, limit enforcement) │ │ +│ └──┬──────────────────────┬───────────────────────────┘ │ +│ │ │ │ +│ ┌──▾──────────────────┐ ┌─▾──────────────────────────┐ │ +│ │ Relay Loop │ │ Management API (REST) │ │ +│ │ (per-stream TCP) │ │ /api/users, /api/stats │ │ +│ │ ──▸ Internet │ │ Bearer token auth │ │ +│ └─────────────────────┘ └────────────────────────────┘ │ +│ │ +│ ┌──────────────────────────────────────────────────────┐ │ +│ │ Fallback TCP Proxy ──▸ nginx/caddy (anti-DPI) │ │ +│ └──────────────────────────────────────────────────────┘ │ └─────────────────────────────────────────────────────────────┘ ``` --- -## Installation +## Quick Start + +### 1. Generate config -### Linux ```bash -bash <(curl -Ls https://raw.githubusercontent.com/ospab/ostp/master/scripts/install.sh) +# On your VPS (server): +./ostp --init server + +# On your machine (client): +./ostp --init client ``` -### Windows (PowerShell, Administrator) -```powershell -irm https://raw.githubusercontent.com/ospab/ostp/master/scripts/install.ps1 | iex -``` +### 2. Edit config ---- - -## Configuration - -Generate a default config: -```bash -./ostp --init server # VPS -./ostp --init client # Local machine -``` - -### Server (`config.json`) +**Server** — set your access keys: ```jsonc { "mode": "server", "listen": "0.0.0.0:50000", "access_keys": ["YOUR_SECRET_KEY"], - "debug": false, - // Optional: forward traffic through an upstream proxy - "outbound": { - "enabled": false, - "protocol": "socks5", // "socks5" or "http" - "address": "127.0.0.1", - "port": 9050, - "default_action": "proxy" - } + "api": { "enabled": true, "bind": "127.0.0.1:9090", "token": "admin-token" }, + "fallback": { "enabled": false, "listen": "0.0.0.0:443", "target": "127.0.0.1:8080" } } ``` -### Client (`config.json`) +**Client** — point to your server: ```jsonc { "mode": "client", "server": "YOUR_SERVER_IP:50000", "access_key": "YOUR_SECRET_KEY", "socks5_bind": "127.0.0.1:1088", - "debug": false, - // TUN mode (full-system VPN) - "tun": { - "enable": false, - "dns": "1.1.1.1" - }, - // Multiplexing: spread traffic across multiple UDP sessions - "mux": { - "enabled": false, - "sessions": 2 - }, - // TURN relay for restricted networks - "turn": { - "enabled": false, - "server_addr": "turn.example.com:3478", - "username": "user", - "access_key": "pass" - }, - // Traffic exclusions (bypassed directly) - "exclude": { - "domains": ["example.local"], - "ips": ["192.168.0.0/16"] - } + "tun": { "enable": false, "dns": "1.1.1.1" } } ``` ---- - -## Usage +### 3. Run ```bash -# Start with config -./ostp --config config.json - -# Or just run (looks for config.json in current/binary directory) -./ostp +./ostp # Uses config.json in current directory +./ostp --config /path/to.json # Custom config path +./ostp --check # Validate config without running +./ostp --generate-key # Generate a new access key +./ostp --links # Print client share links ``` -### TUN Mode (Windows) -Requires `tun2socks.exe` in the same directory. Automatically requests Administrator privileges. - -### TUN Mode (Linux) -Requires root. Uses `tun2socks` binary (same directory or in `$PATH`). +### 4. Connect via share link (one-liner) +```bash +./ostp ostp://ACCESS_KEY@server.com:50000 +``` --- -## Protocol Specification +## Management API -See [docs/en/specification.md](docs/en/specification.md) for the full wire format, handshake flow, and ARQ semantics. +Built-in REST API for building panels and dashboards. -### Quick Summary +```bash +# Server status +curl -H "Authorization: Bearer mytoken" http://127.0.0.1:9090/api/server/status + +# List all users with traffic stats +curl -H "Authorization: Bearer mytoken" http://127.0.0.1:9090/api/users + +# Create a user with 10GB traffic limit +curl -X POST -H "Authorization: Bearer mytoken" \ + -H "Content-Type: application/json" \ + -d '{"limit_bytes": 10737418240}' \ + http://127.0.0.1:9090/api/users +``` + +Full API reference: [Management API](https://github.com/ospab/ostp/wiki/Management-API) + +--- + +## CLI Reference + +``` +ostp [OPTIONS] [URL] + +Options: + --config Config file path (default: config.json) + --init Generate template config (server/client) + --check Validate configuration and exit + -g, --generate-key Generate a secure access key + -c, --count Number of keys to generate (default: 1) + --format Key format: hex, base64 (default: hex) + --links Print client share links from server config + +Arguments: + [URL] Connect via share link: ostp://KEY@HOST:PORT +``` + +--- + +## Protocol Summary | Layer | Mechanism | |-------|-----------| | Key Exchange | Noise NNpsk0 (X25519 + ChaChaPoly + BLAKE2s) | | Encryption | ChaCha20-Poly1305 AEAD per-packet | -| Header Obfuscation | HMAC-SHA256 derived per-packet mask over session_id + nonce | +| Header Obfuscation | HMAC-SHA256 derived per-packet mask | | Reliability | Selective ACK with cumulative + SACK ranges | -| Retransmission | Rate-limited NACK (30ms cooldown) + exponential backoff RTO | -| Flow Control | In-flight window (retransmittable frames only) | +| Retransmission | Rate-limited NACK + exponential backoff RTO | | Keepalive | Ping/Pong with RTT measurement every 5s | -| Session Timeout | 60s inactivity on client, 300s on server | --- ## Building from Source ```bash -# Prerequisites: Rust toolchain (1.75+) +# Prerequisites: Rust 1.75+ cargo build --release -# Cross-compile for Linux (from Windows/macOS) +# Cross-compile for Linux cross build --release --target x86_64-unknown-linux-gnu + +# Run tests +cargo test -p ostp-core -p ostp-server ``` --- ## Documentation -- [Architecture Overview](docs/en/architecture.md) -- [Protocol Specification](docs/en/specification.md) -- [Obfuscation Design](docs/en/obfuscation.md) -- [Server Administration](docs/en/server.md) -- [Client Configuration](docs/en/client.md) -- [Integration Guide](docs/en/integrations.md) +- **[Wiki](https://github.com/ospab/ostp/wiki)** — Full documentation +- [Installation](https://github.com/ospab/ostp/wiki/Installation) +- [Configuration Reference](https://github.com/ospab/ostp/wiki/Configuration) +- [Management API](https://github.com/ospab/ostp/wiki/Management-API) +- [Protocol Design](https://github.com/ospab/ostp/wiki/Protocol-Design) +- [Building from Source](https://github.com/ospab/ostp/wiki/Building-from-Source) +- [FAQ](https://github.com/ospab/ostp/wiki/FAQ) --- diff --git a/ostp-client/src/bridge.rs b/ostp-client/src/bridge.rs index 57655cd..3de949e 100644 --- a/ostp-client/src/bridge.rs +++ b/ostp-client/src/bridge.rs @@ -147,7 +147,7 @@ impl Bridge { Ok(a) => a, Err(e) => { let _ = tx.send(UiEvent::Log(format!("Protocol decrypt error: {e}"))).await; - eprintln!("[ostp] Inbound protocol error (session {}): {}", session_index, e); + tracing::warn!("Inbound protocol error (session {}): {}", session_index, e); continue; } }; @@ -508,7 +508,7 @@ impl Bridge { } } Err(e) => { - eprintln!("[ostp] Protocol error packing outbound stream_id={}: {}", stream_id, e); + tracing::warn!("Protocol error packing outbound stream_id={}: {}", stream_id, e); let _ = tx.send(UiEvent::Log(format!("Protocol error packing TCP: {e}"))).await; } } @@ -619,7 +619,7 @@ impl Bridge { let _ = sock.set_send_buffer_size(33554432); // 32MB let actual_recv = sock.recv_buffer_size().unwrap_or(0); let actual_send = sock.send_buffer_size().unwrap_or(0); - eprintln!("[ostp] UDP socket buffers: recv={}KB send={}KB", actual_recv / 1024, actual_send / 1024); + tracing::info!("UDP socket buffers: recv={}KB send={}KB", actual_recv / 1024, actual_send / 1024); sock.bind(&addr.into())?; sock.set_nonblocking(true)?; let socket = UdpSocket::from_std(sock.into())?; @@ -632,7 +632,7 @@ impl Bridge { }; tx.send(UiEvent::Log(format!("Allocating TURN relay via {}", turn_addr))).await.ok(); - match perform_turn_allocation(&socket, &turn_addr, &self.turn_username, &self.turn_password, &self.server_addr).await { + match crate::turn::perform_turn_allocation(&socket, &turn_addr, &self.turn_username, &self.turn_password, &self.server_addr).await { Ok(relay_addr) => { tx.send(UiEvent::Log(format!("TURN relay allocated ({})", relay_addr))).await.ok(); // Re-connect the UDP socket to the TURN server so all sends go through it. @@ -677,7 +677,7 @@ impl Bridge { .await .context("handshake timeout waiting server response")??; self.metrics.bytes_recv.fetch_add(size as u64, Ordering::Relaxed); - eprintln!("[ostp] Handshake response received: {} bytes", size); + tracing::info!("Handshake response received: {} bytes", size); let inbound = if self.turn_enabled && size >= 4 && buf[0] == 0x40 && buf[1] == 0x00 { let len = u16::from_be_bytes([buf[2], buf[3]]) as usize; @@ -691,7 +691,7 @@ impl Bridge { }; machine.on_event(OstpEvent::Inbound(inbound))?; let rtt_ms = start.elapsed().as_secs_f64() * 1000.0; - eprintln!("[ostp] Handshake complete: session={:#010x} rtt={:.1}ms", session_id, rtt_ms); + tracing::info!("Handshake complete: session={:#010x} rtt={:.1}ms", session_id, rtt_ms); Ok((socket, machine, rtt_ms)) } @@ -721,391 +721,3 @@ fn next_profile(current: TrafficProfile) -> TrafficProfile { } } -/// Real RFC-5766 TURN allocation with HMAC-SHA1 long-term credentials. -/// -/// Flow: -/// 1. Send Allocate (unauthenticated) → get 401 with realm + nonce -/// 2. Compute HMAC-SHA1 key = MD5(username:realm:password) -/// 3. Re-send Allocate with MESSAGE-INTEGRITY -/// 4. Extract XOR-RELAYED-ADDRESS from success response -/// 5. Send ChannelBind to bind channel 0x4000 to the OSTP server addr -/// -/// Returns the relay address string like "1.2.3.4:12345". -async fn perform_turn_allocation( - socket: &UdpSocket, - turn_addr: &str, - username: &str, - password: &str, - ostp_server_addr: &str, -) -> anyhow::Result { - use std::net::ToSocketAddrs; - - let turn_sock: std::net::SocketAddr = turn_addr - .to_socket_addrs() - .map_err(|e| anyhow::anyhow!("TURN DNS resolution failed: {e}"))? - .next() - .ok_or_else(|| anyhow::anyhow!("TURN addr resolved to nothing"))?; - - let transaction_id = { - use rand::Rng; - let mut id = [0u8; 12]; - rand::thread_rng().fill(&mut id); - id - }; - - // Helper: build a minimal STUN/TURN message - fn build_stun_msg(msg_type: u16, tx_id: &[u8; 12], attrs: &[u8]) -> Vec { - let mut msg = Vec::with_capacity(20 + attrs.len()); - msg.extend_from_slice(&msg_type.to_be_bytes()); - msg.extend_from_slice(&(attrs.len() as u16).to_be_bytes()); - msg.extend_from_slice(&0x2112A442_u32.to_be_bytes()); // Magic Cookie - msg.extend_from_slice(tx_id); - msg.extend_from_slice(attrs); - msg - } - - // Helper: encode a STUN attribute (type, length-padded value) - fn stun_attr(attr_type: u16, value: &[u8]) -> Vec { - let mut out = Vec::new(); - out.extend_from_slice(&attr_type.to_be_bytes()); - out.extend_from_slice(&(value.len() as u16).to_be_bytes()); - out.extend_from_slice(value); - // Pad to 4-byte boundary - let pad = (4 - (value.len() % 4)) % 4; - out.extend(std::iter::repeat(0u8).take(pad)); - out - } - - // ── Step 1: unauthenticated Allocate ───────────────────────────── - // REQUESTED-TRANSPORT attr: 0x0019, value = 17 (UDP) + 3 reserved bytes - let req_transport = stun_attr(0x0019, &[17u8, 0, 0, 0]); - let alloc_req = build_stun_msg(0x0003, &transaction_id, &req_transport); - - socket.send_to(&alloc_req, turn_sock).await - .map_err(|e| anyhow::anyhow!("TURN send Allocate failed: {e}"))?; - - let mut buf = [0u8; 2048]; - let (n, _) = timeout(Duration::from_millis(3000), socket.recv_from(&mut buf)) - .await - .map_err(|_| anyhow::anyhow!("TURN Allocate response timed out"))? - .map_err(|e| anyhow::anyhow!("TURN recv failed: {e}"))?; - - let resp = &buf[..n]; - if resp.len() < 20 { - anyhow::bail!("TURN response too short"); - } - - let msg_type = u16::from_be_bytes([resp[0], resp[1]]); - - // 0x0113 = Allocate Error Response - if msg_type != 0x0113 { - anyhow::bail!("Expected TURN 401 error response, got type 0x{:04x}", msg_type); - } - - // Parse realm and nonce from the error response attributes - let mut realm: Option = None; - let mut nonce: Option = None; - { - let mut idx = 20usize; - while idx + 4 <= n { - let atype = u16::from_be_bytes([resp[idx], resp[idx + 1]]); - let alen = u16::from_be_bytes([resp[idx + 2], resp[idx + 3]]) as usize; - idx += 4; - if idx + alen > n { break; } - let val = &resp[idx..idx + alen]; - match atype { - 0x0014 => realm = Some(String::from_utf8_lossy(val).to_string()), // REALM - 0x0015 => nonce = Some(String::from_utf8_lossy(val).to_string()), // NONCE - _ => {} - } - idx += alen; - let pad = (4 - (alen % 4)) % 4; - idx += pad; - } - } - - let realm = realm.ok_or_else(|| anyhow::anyhow!("TURN 401: no REALM in response"))?; - let nonce = nonce.ok_or_else(|| anyhow::anyhow!("TURN 401: no NONCE in response"))?; - - // ── Step 2: Compute long-term credential key per RFC 5389 §15.4 ── - // key = MD5(username ":" realm ":" password) - let key_input = format!("{}:{}:{}", username, realm, password); - let key = md5_hash(key_input.as_bytes()); - - // HMAC-SHA1 of the message (MESSAGE-INTEGRITY attribute, RFC 5389 §15.4) - // We build the message without the integrity attr, compute HMAC, then append. - let mut attrs2 = Vec::new(); - attrs2.extend_from_slice(&stun_attr(0x0006, username.as_bytes())); // USERNAME - attrs2.extend_from_slice(&stun_attr(0x0014, realm.as_bytes())); // REALM - attrs2.extend_from_slice(&stun_attr(0x0015, nonce.as_bytes())); // NONCE - attrs2.extend_from_slice(&req_transport); // REQUESTED-TRANSPORT - - // For MESSAGE-INTEGRITY we need the full message length including the MI attr (24 bytes) - let mi_placeholder_len = attrs2.len() + 4 + 20; // +4 header, +20 HMAC-SHA1 - let mut msg_for_hmac = build_stun_msg(0x0003, &transaction_id, &attrs2); - // Set length field to include the upcoming MI attr - let new_len = (mi_placeholder_len - 20) as u16; // total attrs length including MI - msg_for_hmac[2..4].copy_from_slice(&new_len.to_be_bytes()); - // Append MI header (without value) - msg_for_hmac.extend_from_slice(&0x0008_u16.to_be_bytes()); // attr type - msg_for_hmac.extend_from_slice(&20_u16.to_be_bytes()); // attr len - - let hmac = hmac_sha1(&key, &msg_for_hmac); - let mut final_attrs = attrs2.clone(); - final_attrs.extend_from_slice(&stun_attr(0x0008, &hmac)); // MESSAGE-INTEGRITY - - let alloc_req2 = build_stun_msg(0x0003, &transaction_id, &final_attrs); - - socket.send_to(&alloc_req2, turn_sock).await - .map_err(|e| anyhow::anyhow!("TURN authenticated Allocate send failed: {e}"))?; - - let (n2, _) = timeout(Duration::from_millis(5000), socket.recv_from(&mut buf)) - .await - .map_err(|_| anyhow::anyhow!("TURN authenticated Allocate timed out"))? - .map_err(|e| anyhow::anyhow!("TURN recv2 failed: {e}"))?; - - let resp2 = &buf[..n2]; - if resp2.len() < 20 { - anyhow::bail!("TURN auth response too short"); - } - let msg_type2 = u16::from_be_bytes([resp2[0], resp2[1]]); - // 0x0103 = Allocate Success Response - if msg_type2 != 0x0103 { - anyhow::bail!("TURN Allocate auth failed, response type 0x{:04x}", msg_type2); - } - - // ── Step 3: Parse XOR-RELAYED-ADDRESS ──────────────────────────── - let relay_addr_str = { - let mut relayed: Option = None; - let mut idx = 20usize; - while idx + 4 <= n2 { - let atype = u16::from_be_bytes([resp2[idx], resp2[idx + 1]]); - let alen = u16::from_be_bytes([resp2[idx + 2], resp2[idx + 3]]) as usize; - idx += 4; - if idx + alen > n2 { break; } - let val = &resp2[idx..idx + alen]; - if atype == 0x0016 && alen >= 8 { // XOR-RELAYED-ADDRESS - let x_port = u16::from_be_bytes([val[2], val[3]]) ^ 0x2112; - let x_ip = [val[4], val[5], val[6], val[7]]; - let ip = std::net::Ipv4Addr::new( - x_ip[0] ^ 0x21, x_ip[1] ^ 0x12, x_ip[2] ^ 0xA4, x_ip[3] ^ 0x42, - ); - relayed = Some(format!("{}:{}", ip, x_port)); - } - idx += alen; - let pad = (4 - (alen % 4)) % 4; - idx += pad; - } - relayed.ok_or_else(|| anyhow::anyhow!("TURN: no XOR-RELAYED-ADDRESS in response"))? - }; - - // ── Step 4: ChannelBind to the OSTP server ──────────────────────── - // ChannelBind binds channel 0x4000 to the peer (OSTP server). - // After this, all UDP data we send as ChannelData (4 bytes header + payload) - // will be forwarded by the TURN server to the OSTP server transparently. - let ostp_sock: std::net::SocketAddr = ostp_server_addr - .to_socket_addrs() - .map_err(|e| anyhow::anyhow!("OSTP server DNS resolution failed: {e}"))? - .next() - .ok_or_else(|| anyhow::anyhow!("OSTP server addr resolved to nothing"))?; - - let channel_number: u16 = 0x4000; - let mut peer_addr_attr = Vec::new(); - peer_addr_attr.push(0u8); // reserved - peer_addr_attr.push(0x01u8); // family IPv4 - peer_addr_attr.extend_from_slice(&(ostp_sock.port() ^ 0x2112).to_be_bytes()); // XOR port - if let std::net::IpAddr::V4(ipv4) = ostp_sock.ip() { - let octets = ipv4.octets(); - peer_addr_attr.push(octets[0] ^ 0x21); - peer_addr_attr.push(octets[1] ^ 0x12); - peer_addr_attr.push(octets[2] ^ 0xA4); - peer_addr_attr.push(octets[3] ^ 0x42); - } else { - anyhow::bail!("TURN ChannelBind: IPv6 OSTP server not yet supported"); - } - - let mut cb_attrs = Vec::new(); - // CHANNEL-NUMBER attr: 0x000C - cb_attrs.extend_from_slice(&stun_attr(0x000C, &[ - (channel_number >> 8) as u8, channel_number as u8, 0, 0 - ])); - // XOR-PEER-ADDRESS attr: 0x0012 - cb_attrs.extend_from_slice(&stun_attr(0x0012, &peer_addr_attr)); - cb_attrs.extend_from_slice(&stun_attr(0x0006, username.as_bytes())); - cb_attrs.extend_from_slice(&stun_attr(0x0014, realm.as_bytes())); - cb_attrs.extend_from_slice(&stun_attr(0x0015, nonce.as_bytes())); - - // Compute MESSAGE-INTEGRITY for ChannelBind too - let mi_len2 = cb_attrs.len() + 4 + 20; - let mut cb_for_hmac = build_stun_msg(0x0009, &transaction_id, &cb_attrs); - cb_for_hmac[2..4].copy_from_slice(&((mi_len2 - 20) as u16).to_be_bytes()); - cb_for_hmac.extend_from_slice(&0x0008_u16.to_be_bytes()); - cb_for_hmac.extend_from_slice(&20_u16.to_be_bytes()); - let cb_hmac = hmac_sha1(&key, &cb_for_hmac); - cb_attrs.extend_from_slice(&stun_attr(0x0008, &cb_hmac)); - - let cb_req = build_stun_msg(0x0009, &transaction_id, &cb_attrs); - socket.send_to(&cb_req, turn_sock).await - .map_err(|e| anyhow::anyhow!("TURN ChannelBind send failed: {e}"))?; - - let (n3, _) = timeout(Duration::from_millis(3000), socket.recv_from(&mut buf)) - .await - .map_err(|_| anyhow::anyhow!("TURN ChannelBind response timed out"))? - .map_err(|e| anyhow::anyhow!("TURN ChannelBind recv failed: {e}"))?; - - let resp3 = &buf[..n3]; - if resp3.len() < 4 { - anyhow::bail!("TURN ChannelBind response too short"); - } - let cb_resp_type = u16::from_be_bytes([resp3[0], resp3[1]]); - // 0x0109 = ChannelBind Success Response - if cb_resp_type != 0x0109 { - anyhow::bail!("TURN ChannelBind failed, response type 0x{:04x}", cb_resp_type); - } - - Ok(relay_addr_str) -} - -/// Pure-Rust MD5 hash (16 bytes). Used for TURN long-term credential key derivation. -fn md5_hash(input: &[u8]) -> [u8; 16] { - // RFC 1321 MD5 constants - const S: [u32; 64] = [ - 7,12,17,22, 7,12,17,22, 7,12,17,22, 7,12,17,22, - 5, 9,14,20, 5, 9,14,20, 5, 9,14,20, 5, 9,14,20, - 4,11,16,23, 4,11,16,23, 4,11,16,23, 4,11,16,23, - 6,10,15,21, 6,10,15,21, 6,10,15,21, 6,10,15,21, - ]; - const K: [u32; 64] = [ - 0xd76aa478, 0xe8c7b756, 0x242070db, 0xc1bdceee, 0xf57c0faf, 0x4787c62a, - 0xa8304613, 0xfd469501, 0x698098d8, 0x8b44f7af, 0xffff5bb1, 0x895cd7be, - 0x6b901122, 0xfd987193, 0xa679438e, 0x49b40821, 0xf61e2562, 0xc040b340, - 0x265e5a51, 0xe9b6c7aa, 0xd62f105d, 0x02441453, 0xd8a1e681, 0xe7d3fbc8, - 0x21e1cde6, 0xc33707d6, 0xf4d50d87, 0x455a14ed, 0xa9e3e905, 0xfcefa3f8, - 0x676f02d9, 0x8d2a4c8a, 0xfffa3942, 0x8771f681, 0x6d9d6122, 0xfde5380c, - 0xa4beea44, 0x4bdecfa9, 0xf6bb4b60, 0xbebfbc70, 0x289b7ec6, 0xeaa127fa, - 0xd4ef3085, 0x04881d05, 0xd9d4d039, 0xe6db99e5, 0x1fa27cf8, 0xc4ac5665, - 0xf4292244, 0x432aff97, 0xab9423a7, 0xfc93a039, 0x655b59c3, 0x8f0ccc92, - 0xffeff47d, 0x85845dd1, 0x6fa87e4f, 0xfe2ce6e0, 0xa3014314, 0x4e0811a1, - 0xf7537e82, 0xbd3af235, 0x2ad7d2bb, 0xeb86d391, - ]; - - let msg_len = input.len(); - let bit_len = (msg_len as u64) * 8; - - let mut padded = input.to_vec(); - padded.push(0x80); - while padded.len() % 64 != 56 { - padded.push(0); - } - padded.extend_from_slice(&bit_len.to_le_bytes()); - - let mut a0: u32 = 0x67452301; - let mut b0: u32 = 0xefcdab89; - let mut c0: u32 = 0x98badcfe; - let mut d0: u32 = 0x10325476; - - for chunk in padded.chunks(64) { - let mut m = [0u32; 16]; - for (i, item) in m.iter_mut().enumerate() { - *item = u32::from_le_bytes([chunk[i*4], chunk[i*4+1], chunk[i*4+2], chunk[i*4+3]]); - } - let (mut a, mut b, mut c, mut d) = (a0, b0, c0, d0); - for i in 0..64usize { - let (f, g) = match i { - 0..=15 => ((b & c) | (!b & d), i), - 16..=31 => ((d & b) | (!d & c), (5*i + 1) % 16), - 32..=47 => (b ^ c ^ d, (3*i + 5) % 16), - _ => (c ^ (b | !d), (7*i) % 16), - }; - let temp = d; - d = c; - c = b; - b = b.wrapping_add((a.wrapping_add(f).wrapping_add(K[i]).wrapping_add(m[g])).rotate_left(S[i])); - a = temp; - } - a0 = a0.wrapping_add(a); - b0 = b0.wrapping_add(b); - c0 = c0.wrapping_add(c); - d0 = d0.wrapping_add(d); - } - - let mut result = [0u8; 16]; - result[0..4].copy_from_slice(&a0.to_le_bytes()); - result[4..8].copy_from_slice(&b0.to_le_bytes()); - result[8..12].copy_from_slice(&c0.to_le_bytes()); - result[12..16].copy_from_slice(&d0.to_le_bytes()); - result -} - -/// HMAC-SHA1 for TURN MESSAGE-INTEGRITY (RFC 2104 + RFC 5389 §15.4). -fn hmac_sha1(key: &[u8], message: &[u8]) -> [u8; 20] { - const BLOCK_SIZE: usize = 64; - - let mut k = [0u8; BLOCK_SIZE]; - if key.len() > BLOCK_SIZE { - let h = sha1_hash(key); - k[..20].copy_from_slice(&h); - } else { - k[..key.len()].copy_from_slice(key); - } - - let mut ipad = [0u8; BLOCK_SIZE]; - let mut opad = [0u8; BLOCK_SIZE]; - for i in 0..BLOCK_SIZE { - ipad[i] = k[i] ^ 0x36; - opad[i] = k[i] ^ 0x5C; - } - - let mut inner = ipad.to_vec(); - inner.extend_from_slice(message); - let inner_hash = sha1_hash(&inner); - - let mut outer = opad.to_vec(); - outer.extend_from_slice(&inner_hash); - sha1_hash(&outer) -} - -/// Pure-Rust SHA-1 (RFC 3174). -fn sha1_hash(input: &[u8]) -> [u8; 20] { - let msg_len = input.len(); - let bit_len = (msg_len as u64) * 8; - let mut padded = input.to_vec(); - padded.push(0x80); - while padded.len() % 64 != 56 { - padded.push(0); - } - padded.extend_from_slice(&bit_len.to_be_bytes()); - - let mut h: [u32; 5] = [0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0]; - - for chunk in padded.chunks(64) { - let mut w = [0u32; 80]; - for i in 0..16 { - w[i] = u32::from_be_bytes([chunk[i*4], chunk[i*4+1], chunk[i*4+2], chunk[i*4+3]]); - } - for i in 16..80 { - w[i] = (w[i-3] ^ w[i-8] ^ w[i-14] ^ w[i-16]).rotate_left(1); - } - let (mut a, mut b, mut c, mut d, mut e) = (h[0], h[1], h[2], h[3], h[4]); - for i in 0..80usize { - let (f, k) = match i { - 0..=19 => ((b & c) | (!b & d), 0x5A827999u32), - 20..=39 => (b ^ c ^ d, 0x6ED9EBA1), - 40..=59 => ((b & c) | (b & d) | (c & d), 0x8F1BBCDC), - _ => (b ^ c ^ d, 0xCA62C1D6), - }; - let temp = a.rotate_left(5).wrapping_add(f).wrapping_add(e).wrapping_add(k).wrapping_add(w[i]); - e = d; d = c; c = b.rotate_left(30); b = a; a = temp; - } - h[0] = h[0].wrapping_add(a); h[1] = h[1].wrapping_add(b); - h[2] = h[2].wrapping_add(c); h[3] = h[3].wrapping_add(d); - h[4] = h[4].wrapping_add(e); - } - - let mut out = [0u8; 20]; - for (i, &v) in h.iter().enumerate() { - out[i*4..(i+1)*4].copy_from_slice(&v.to_be_bytes()); - } - out -} - diff --git a/ostp-client/src/lib.rs b/ostp-client/src/lib.rs index a29e3f8..70193fc 100644 --- a/ostp-client/src/lib.rs +++ b/ostp-client/src/lib.rs @@ -4,4 +4,5 @@ pub mod config; pub mod signal; pub mod sysproxy; pub mod tunnel; +pub mod turn; pub mod runner; diff --git a/ostp-client/src/signal.rs b/ostp-client/src/signal.rs index d91e3ca..caaf1fc 100644 --- a/ostp-client/src/signal.rs +++ b/ostp-client/src/signal.rs @@ -9,10 +9,10 @@ pub async fn wait_for_shutdown_signal() -> Result<()> { tokio::select! { _ = sigterm.recv() => { - eprintln!("[ostp] Received SIGTERM, shutting down"); + tracing::info!("Received SIGTERM, shutting down"); } _ = sigint.recv() => { - eprintln!("[ostp] Received SIGINT, shutting down"); + tracing::info!("Received SIGINT, shutting down"); } } @@ -30,19 +30,19 @@ pub async fn wait_for_shutdown_signal() -> Result<()> { tokio::select! { res = c_c.recv() => { - eprintln!("[ostp] Received Ctrl+C, shutting down"); + tracing::info!("Received Ctrl+C, shutting down"); if res.is_none() { std::future::pending::<()>().await; } } res = c_close.recv() => { - eprintln!("[ostp] Received console close event, shutting down"); + tracing::info!("Received console close event, shutting down"); if res.is_none() { std::future::pending::<()>().await; } } res = c_break.recv() => { - eprintln!("[ostp] Received Ctrl+Break, shutting down"); + tracing::info!("Received Ctrl+Break, shutting down"); if res.is_none() { std::future::pending::<()>().await; } diff --git a/ostp-client/src/sysproxy.rs b/ostp-client/src/sysproxy.rs index 3c26f24..0cc5a01 100644 --- a/ostp-client/src/sysproxy.rs +++ b/ostp-client/src/sysproxy.rs @@ -24,7 +24,7 @@ const INTERNET_OPTION_REFRESH: u32 = 37; #[cfg(target_os = "windows")] pub fn enable_windows_proxy(proxy_addr: &str) { - eprintln!("[ostp] Enabling Windows system proxy: {}", proxy_addr); + tracing::info!("Enabling Windows system proxy: {}", proxy_addr); let result = Command::new("reg") .creation_flags(CREATE_NO_WINDOW) @@ -39,9 +39,9 @@ pub fn enable_windows_proxy(proxy_addr: &str) { .output(); match result { Ok(out) if !out.status.success() => { - eprintln!("[ostp] Failed to set ProxyEnable: {}", String::from_utf8_lossy(&out.stderr)); + tracing::error!("Failed to set ProxyEnable: {}", String::from_utf8_lossy(&out.stderr)); } - Err(e) => eprintln!("[ostp] Failed to execute reg.exe (ProxyEnable): {}", e), + Err(e) => tracing::error!("Failed to execute reg.exe (ProxyEnable): {}", e), _ => {} } @@ -58,9 +58,9 @@ pub fn enable_windows_proxy(proxy_addr: &str) { .output(); match result { Ok(out) if !out.status.success() => { - eprintln!("[ostp] Failed to set ProxyServer: {}", String::from_utf8_lossy(&out.stderr)); + tracing::error!("Failed to set ProxyServer: {}", String::from_utf8_lossy(&out.stderr)); } - Err(e) => eprintln!("[ostp] Failed to execute reg.exe (ProxyServer): {}", e), + Err(e) => tracing::error!("Failed to execute reg.exe (ProxyServer): {}", e), _ => {} } @@ -78,12 +78,12 @@ pub fn enable_windows_proxy(proxy_addr: &str) { .output(); refresh_wininet(); - eprintln!("[ostp] System proxy enabled successfully"); + tracing::info!("System proxy enabled successfully"); } #[cfg(target_os = "windows")] pub fn disable_windows_proxy() { - eprintln!("[ostp] Disabling Windows system proxy"); + tracing::info!("Disabling Windows system proxy"); let _ = Command::new("reg") .creation_flags(CREATE_NO_WINDOW) .args([ diff --git a/ostp-client/src/tunnel/linux_handler.rs b/ostp-client/src/tunnel/linux_handler.rs index 62c1923..40d6947 100644 --- a/ostp-client/src/tunnel/linux_handler.rs +++ b/ostp-client/src/tunnel/linux_handler.rs @@ -203,7 +203,7 @@ pub async fn run_linux_tunnel( tokio::spawn(async move { let reader = BufReader::new(stderr); for line in reader.lines().map_while(Result::ok) { - eprintln!("[tun2socks-err] {}", line); + tracing::warn!("tun2socks: {}", line); } }); } diff --git a/ostp-client/src/tunnel/proxy.rs b/ostp-client/src/tunnel/proxy.rs index 9809bfa..617642d 100644 --- a/ostp-client/src/tunnel/proxy.rs +++ b/ostp-client/src/tunnel/proxy.rs @@ -23,8 +23,8 @@ pub async fn run_local_socks5_proxy( .with_context(|| format!("failed to bind local HTTP/SOCKS5 proxy at {}", cfg.bind_addr))?; if debug { - eprintln!("[ostp] local HTTP/SOCKS5 proxy listening at {}", cfg.bind_addr); - eprintln!("[ostp] Windows system proxy: set HTTP proxy to {}. tun2socks: SOCKS5 on same address.", cfg.bind_addr); + tracing::info!("local HTTP/SOCKS5 proxy listening at {}", cfg.bind_addr); + tracing::info!("Windows system proxy: set HTTP proxy to {}. tun2socks: SOCKS5 on same address.", cfg.bind_addr); } let matcher = ExclusionMatcher::new(&exclusions); @@ -75,7 +75,7 @@ pub async fn run_local_socks5_proxy( && !msg.contains("unsupported SOCKS5 command") { if debug { - eprintln!("[ostp] proxy client error: {err}"); + tracing::warn!("proxy client error: {err}"); } } } @@ -85,7 +85,7 @@ pub async fn run_local_socks5_proxy( if stream_id == 0 { if let ProxyToClientMsg::Close = msg { if debug { - eprintln!("[ostp] Resetting all active proxy streams on reconnect"); + tracing::info!("Resetting all active proxy streams on reconnect"); } for (_, tx) in active_streams.drain() { let _ = tx.send(ProxyToClientMsg::Close); @@ -200,7 +200,7 @@ async fn handle_proxy_client( }; if debug { - eprintln!("[ostp] proxy CONNECT stream_id={stream_id} target={target}"); + tracing::info!("proxy CONNECT stream_id={stream_id} target={target}"); } if matcher.should_bypass(&target, connect_timeout).await { return direct_connect_socks5(client, stream_id, &target, close_tx, debug).await; @@ -277,7 +277,7 @@ async fn handle_proxy_client( }; if debug { - eprintln!("[ostp] proxy CONNECT stream_id={stream_id} target={target}"); + tracing::info!("proxy CONNECT stream_id={stream_id} target={target}"); } if matcher.should_bypass(&target, connect_timeout).await { return direct_connect_http( @@ -333,7 +333,7 @@ async fn handle_proxy_client( Ok(0) => { let _ = event_tx.send(ProxyEvent::Close { stream_id }).await; if debug { - eprintln!("[ostp] proxy CLOSE stream_id={stream_id}"); + tracing::info!("proxy CLOSE stream_id={stream_id}"); } break; } @@ -346,7 +346,7 @@ async fn handle_proxy_client( Err(_) => { let _ = event_tx.send(ProxyEvent::Close { stream_id }).await; if debug { - eprintln!("[ostp] proxy CLOSE stream_id={stream_id}"); + tracing::info!("proxy CLOSE stream_id={stream_id}"); } break; } @@ -513,7 +513,7 @@ async fn direct_connect_socks5( debug: bool, ) -> Result<()> { if debug { - eprintln!("[ostp] proxy BYPASS stream_id={stream_id} target={target}"); + tracing::info!("proxy BYPASS stream_id={stream_id} target={target}"); } let mut remote = TcpStream::connect(target).await .with_context(|| format!("direct connect failed: {target}"))?; @@ -534,7 +534,7 @@ async fn direct_connect_http( debug: bool, ) -> Result<()> { if debug { - eprintln!("[ostp] proxy BYPASS stream_id={stream_id} target={target}"); + tracing::info!("proxy BYPASS stream_id={stream_id} target={target}"); } let mut remote = TcpStream::connect(target).await .with_context(|| format!("direct connect failed: {target}"))?; diff --git a/ostp-client/src/tunnel/wintun_handler.rs b/ostp-client/src/tunnel/wintun_handler.rs index 1ba616b..01a141d 100644 --- a/ostp-client/src/tunnel/wintun_handler.rs +++ b/ostp-client/src/tunnel/wintun_handler.rs @@ -41,7 +41,7 @@ pub async fn run_wintun_tunnel( let debug = config.debug; - eprintln!("[ostp] Initializing TUN tunnel..."); + tracing::info!("Initializing TUN tunnel..."); let exe = std::env::current_exe()?; let dir = exe.parent().ok_or_else(|| anyhow!("failed to get binary directory"))?; @@ -59,7 +59,7 @@ pub async fn run_wintun_tunnel( // 1. Delete stale TUN adapter if it exists from a previous run. // This prevents wintun from creating "ostp_tun 2", "ostp_tun 3", etc. - eprintln!("[ostp] Cleaning up stale TUN adapter..."); + tracing::info!("Cleaning up stale TUN adapter..."); let _ = Command::new("powershell") .creation_flags(CREATE_NO_WINDOW) .args(["-NoProfile", "-Command", &format!( @@ -79,7 +79,7 @@ pub async fn run_wintun_tunnel( .ok_or_else(|| anyhow!("Could not resolve host IP for routing exclusion"))?; let server_ip_str = server_ip.to_string(); - eprintln!("[ostp] Resolved server IP: {}", server_ip_str); + tracing::info!("Resolved server IP: {}", server_ip_str); // 3. Prepare routing and firewall setup script let current_exe = std::env::current_exe()?.to_string_lossy().into_owned(); @@ -105,7 +105,7 @@ pub async fn run_wintun_tunnel( // 4. Launch tun2socks + route setup IN PARALLEL to save ~3 seconds let proxy_url = format!("http://{}", config.local_proxy.bind_addr); - eprintln!("[ostp] Starting tun2socks (proxy={})", proxy_url); + tracing::info!("Starting tun2socks (proxy={})", proxy_url); // Spawn tun2socks immediately — it creates the adapter on its own let mut child = Command::new(&tun2socks_exe) @@ -151,7 +151,7 @@ pub async fn run_wintun_tunnel( if let Ok(out) = check { let status = String::from_utf8_lossy(&out.stdout).trim().to_string(); if debug { - eprintln!("[ostp] Adapter status: '{}'", status); + tracing::info!("Adapter status: '{}'", status); } if status == "Up" || status == "Disconnected" || !status.is_empty() { adapter_ready = true; @@ -161,14 +161,14 @@ pub async fn run_wintun_tunnel( } if !adapter_ready { - eprintln!("[ostp] WARNING: TUN adapter did not appear within timeout. Proceeding anyway."); + tracing::warn!("WARNING: TUN adapter did not appear within timeout. Proceeding anyway."); } // Wait for route setup to finish (should already be done by now) let _ = route_handle.await; // 6. Configure the adapter (IP, metric, MTU, DNS) - eprintln!("[ostp] Applying network configuration..."); + tracing::info!("Applying network configuration..."); let mut net_setup = format!( "netsh interface ipv4 set address name=\"{TUN_NAME}\" static 10.1.0.2 255.255.255.0 10.1.0.1\n\ netsh interface ipv4 set subinterface \"{TUN_NAME}\" mtu=1300 store=persistent\n\ @@ -177,7 +177,7 @@ pub async fn run_wintun_tunnel( if let Some(ref dns) = config.dns_server { if !dns.is_empty() { - eprintln!("[ostp] DNS server: {}", dns); + tracing::info!("DNS server: {}", dns); net_setup.push_str(&format!( "netsh interface ipv4 set dnsservers name=\"{TUN_NAME}\" static {} primary\n", dns )); @@ -189,7 +189,7 @@ pub async fn run_wintun_tunnel( .args(["-NoProfile", "-Command", &net_setup]) .output()?; - eprintln!("[ostp] TUN tunnel active. All traffic is routed through OSTP."); + tracing::info!("TUN tunnel active. All traffic is routed through OSTP."); // 7. Spawn debug log readers for tun2socks output let mut stdout = child.stdout.take(); @@ -202,7 +202,7 @@ pub async fn run_wintun_tunnel( if let Some(out) = stdout.take() { let reader = BufReader::new(out); for line in reader.lines().map_while(Result::ok) { - eprintln!("[tun2socks] {}", line); + tracing::debug!("tun2socks: {}", line); } } }); @@ -211,7 +211,7 @@ pub async fn run_wintun_tunnel( if let Some(err) = stderr.take() { let reader = BufReader::new(err); for line in reader.lines().map_while(Result::ok) { - eprintln!("[tun2socks err] {}", line); + tracing::warn!("tun2socks: {}", line); } } }); @@ -220,9 +220,9 @@ pub async fn run_wintun_tunnel( // 8. Wait for shutdown signal let _ = shutdown.changed().await; - eprintln!("[ostp] Deactivating TUN tunnel..."); + tracing::info!("Deactivating TUN tunnel..."); drop(_guard); - eprintln!("[ostp] TUN tunnel stopped."); + tracing::info!("TUN tunnel stopped."); Ok(()) } diff --git a/ostp-client/src/turn.rs b/ostp-client/src/turn.rs new file mode 100644 index 0000000..04f54a8 --- /dev/null +++ b/ostp-client/src/turn.rs @@ -0,0 +1,397 @@ +//! TURN (RFC 5766) allocation and channel binding for NAT traversal. +//! +//! Implements the minimal STUN/TURN message flow needed to allocate a relay +//! address and bind a channel to the OSTP server. All crypto (MD5, SHA-1, +//! HMAC-SHA1) is implemented inline to avoid external dependencies. + +use std::time::Duration; + +use anyhow::Result; +use tokio::net::UdpSocket; +use tokio::time::timeout; + +/// Real RFC-5766 TURN allocation with HMAC-SHA1 long-term credentials. +/// +/// Flow: +/// 1. Send Allocate (unauthenticated) -> get 401 with realm + nonce +/// 2. Compute HMAC-SHA1 key = MD5(username:realm:password) +/// 3. Re-send Allocate with MESSAGE-INTEGRITY +/// 4. Extract XOR-RELAYED-ADDRESS from success response +/// 5. Send ChannelBind to bind channel 0x4000 to the OSTP server addr +/// +/// Returns the relay address string like "1.2.3.4:12345". +pub async fn perform_turn_allocation( + socket: &UdpSocket, + turn_addr: &str, + username: &str, + password: &str, + ostp_server_addr: &str, +) -> Result { + use std::net::ToSocketAddrs; + + let turn_sock: std::net::SocketAddr = turn_addr + .to_socket_addrs() + .map_err(|e| anyhow::anyhow!("TURN DNS resolution failed: {e}"))? + .next() + .ok_or_else(|| anyhow::anyhow!("TURN addr resolved to nothing"))?; + + let transaction_id = { + use rand::Rng; + let mut id = [0u8; 12]; + rand::thread_rng().fill(&mut id); + id + }; + + // ── Step 1: unauthenticated Allocate ───────────────────────────── + // REQUESTED-TRANSPORT attr: 0x0019, value = 17 (UDP) + 3 reserved bytes + let req_transport = stun_attr(0x0019, &[17u8, 0, 0, 0]); + let alloc_req = build_stun_msg(0x0003, &transaction_id, &req_transport); + + socket.send_to(&alloc_req, turn_sock).await + .map_err(|e| anyhow::anyhow!("TURN send Allocate failed: {e}"))?; + + let mut buf = [0u8; 2048]; + let (n, _) = timeout(Duration::from_millis(3000), socket.recv_from(&mut buf)) + .await + .map_err(|_| anyhow::anyhow!("TURN Allocate response timed out"))? + .map_err(|e| anyhow::anyhow!("TURN recv failed: {e}"))?; + + let resp = &buf[..n]; + if resp.len() < 20 { + anyhow::bail!("TURN response too short"); + } + + let msg_type = u16::from_be_bytes([resp[0], resp[1]]); + + // 0x0113 = Allocate Error Response + if msg_type != 0x0113 { + anyhow::bail!("Expected TURN 401 error response, got type 0x{:04x}", msg_type); + } + + // Parse realm and nonce from the error response attributes + let mut realm: Option = None; + let mut nonce: Option = None; + { + let mut idx = 20usize; + while idx + 4 <= n { + let atype = u16::from_be_bytes([resp[idx], resp[idx + 1]]); + let alen = u16::from_be_bytes([resp[idx + 2], resp[idx + 3]]) as usize; + idx += 4; + if idx + alen > n { break; } + let val = &resp[idx..idx + alen]; + match atype { + 0x0014 => realm = Some(String::from_utf8_lossy(val).to_string()), // REALM + 0x0015 => nonce = Some(String::from_utf8_lossy(val).to_string()), // NONCE + _ => {} + } + idx += alen; + let pad = (4 - (alen % 4)) % 4; + idx += pad; + } + } + + let realm = realm.ok_or_else(|| anyhow::anyhow!("TURN 401: no REALM in response"))?; + let nonce = nonce.ok_or_else(|| anyhow::anyhow!("TURN 401: no NONCE in response"))?; + + // ── Step 2: Compute long-term credential key per RFC 5389 §15.4 ── + // key = MD5(username ":" realm ":" password) + let key_input = format!("{}:{}:{}", username, realm, password); + let key = md5_hash(key_input.as_bytes()); + + // HMAC-SHA1 of the message (MESSAGE-INTEGRITY attribute, RFC 5389 §15.4) + let mut attrs2 = Vec::new(); + attrs2.extend_from_slice(&stun_attr(0x0006, username.as_bytes())); // USERNAME + attrs2.extend_from_slice(&stun_attr(0x0014, realm.as_bytes())); // REALM + attrs2.extend_from_slice(&stun_attr(0x0015, nonce.as_bytes())); // NONCE + attrs2.extend_from_slice(&req_transport); // REQUESTED-TRANSPORT + + // For MESSAGE-INTEGRITY we need the full message length including the MI attr (24 bytes) + let mi_placeholder_len = attrs2.len() + 4 + 20; // +4 header, +20 HMAC-SHA1 + let mut msg_for_hmac = build_stun_msg(0x0003, &transaction_id, &attrs2); + // Set length field to include the upcoming MI attr + let new_len = (mi_placeholder_len - 20) as u16; // total attrs length including MI + msg_for_hmac[2..4].copy_from_slice(&new_len.to_be_bytes()); + // Append MI header (without value) + msg_for_hmac.extend_from_slice(&0x0008_u16.to_be_bytes()); // attr type + msg_for_hmac.extend_from_slice(&20_u16.to_be_bytes()); // attr len + + let hmac = hmac_sha1(&key, &msg_for_hmac); + let mut final_attrs = attrs2.clone(); + final_attrs.extend_from_slice(&stun_attr(0x0008, &hmac)); // MESSAGE-INTEGRITY + + let alloc_req2 = build_stun_msg(0x0003, &transaction_id, &final_attrs); + + socket.send_to(&alloc_req2, turn_sock).await + .map_err(|e| anyhow::anyhow!("TURN authenticated Allocate send failed: {e}"))?; + + let (n2, _) = timeout(Duration::from_millis(5000), socket.recv_from(&mut buf)) + .await + .map_err(|_| anyhow::anyhow!("TURN authenticated Allocate timed out"))? + .map_err(|e| anyhow::anyhow!("TURN recv2 failed: {e}"))?; + + let resp2 = &buf[..n2]; + if resp2.len() < 20 { + anyhow::bail!("TURN auth response too short"); + } + let msg_type2 = u16::from_be_bytes([resp2[0], resp2[1]]); + // 0x0103 = Allocate Success Response + if msg_type2 != 0x0103 { + anyhow::bail!("TURN Allocate auth failed, response type 0x{:04x}", msg_type2); + } + + // ── Step 3: Parse XOR-RELAYED-ADDRESS ──────────────────────────── + let relay_addr_str = { + let mut relayed: Option = None; + let mut idx = 20usize; + while idx + 4 <= n2 { + let atype = u16::from_be_bytes([resp2[idx], resp2[idx + 1]]); + let alen = u16::from_be_bytes([resp2[idx + 2], resp2[idx + 3]]) as usize; + idx += 4; + if idx + alen > n2 { break; } + let val = &resp2[idx..idx + alen]; + if atype == 0x0016 && alen >= 8 { // XOR-RELAYED-ADDRESS + let x_port = u16::from_be_bytes([val[2], val[3]]) ^ 0x2112; + let x_ip = [val[4], val[5], val[6], val[7]]; + let ip = std::net::Ipv4Addr::new( + x_ip[0] ^ 0x21, x_ip[1] ^ 0x12, x_ip[2] ^ 0xA4, x_ip[3] ^ 0x42, + ); + relayed = Some(format!("{}:{}", ip, x_port)); + } + idx += alen; + let pad = (4 - (alen % 4)) % 4; + idx += pad; + } + relayed.ok_or_else(|| anyhow::anyhow!("TURN: no XOR-RELAYED-ADDRESS in response"))? + }; + + // ── Step 4: ChannelBind to the OSTP server ──────────────────────── + let ostp_sock: std::net::SocketAddr = ostp_server_addr + .to_socket_addrs() + .map_err(|e| anyhow::anyhow!("OSTP server DNS resolution failed: {e}"))? + .next() + .ok_or_else(|| anyhow::anyhow!("OSTP server addr resolved to nothing"))?; + + let channel_number: u16 = 0x4000; + let mut peer_addr_attr = Vec::new(); + peer_addr_attr.push(0u8); // reserved + peer_addr_attr.push(0x01u8); // family IPv4 + peer_addr_attr.extend_from_slice(&(ostp_sock.port() ^ 0x2112).to_be_bytes()); // XOR port + if let std::net::IpAddr::V4(ipv4) = ostp_sock.ip() { + let octets = ipv4.octets(); + peer_addr_attr.push(octets[0] ^ 0x21); + peer_addr_attr.push(octets[1] ^ 0x12); + peer_addr_attr.push(octets[2] ^ 0xA4); + peer_addr_attr.push(octets[3] ^ 0x42); + } else { + anyhow::bail!("TURN ChannelBind: IPv6 OSTP server not yet supported"); + } + + let mut cb_attrs = Vec::new(); + // CHANNEL-NUMBER attr: 0x000C + cb_attrs.extend_from_slice(&stun_attr(0x000C, &[ + (channel_number >> 8) as u8, channel_number as u8, 0, 0 + ])); + // XOR-PEER-ADDRESS attr: 0x0012 + cb_attrs.extend_from_slice(&stun_attr(0x0012, &peer_addr_attr)); + cb_attrs.extend_from_slice(&stun_attr(0x0006, username.as_bytes())); + cb_attrs.extend_from_slice(&stun_attr(0x0014, realm.as_bytes())); + cb_attrs.extend_from_slice(&stun_attr(0x0015, nonce.as_bytes())); + + // Compute MESSAGE-INTEGRITY for ChannelBind too + let mi_len2 = cb_attrs.len() + 4 + 20; + let mut cb_for_hmac = build_stun_msg(0x0009, &transaction_id, &cb_attrs); + cb_for_hmac[2..4].copy_from_slice(&((mi_len2 - 20) as u16).to_be_bytes()); + cb_for_hmac.extend_from_slice(&0x0008_u16.to_be_bytes()); + cb_for_hmac.extend_from_slice(&20_u16.to_be_bytes()); + let cb_hmac = hmac_sha1(&key, &cb_for_hmac); + cb_attrs.extend_from_slice(&stun_attr(0x0008, &cb_hmac)); + + let cb_req = build_stun_msg(0x0009, &transaction_id, &cb_attrs); + socket.send_to(&cb_req, turn_sock).await + .map_err(|e| anyhow::anyhow!("TURN ChannelBind send failed: {e}"))?; + + let (n3, _) = timeout(Duration::from_millis(3000), socket.recv_from(&mut buf)) + .await + .map_err(|_| anyhow::anyhow!("TURN ChannelBind response timed out"))? + .map_err(|e| anyhow::anyhow!("TURN ChannelBind recv failed: {e}"))?; + + let resp3 = &buf[..n3]; + if resp3.len() < 4 { + anyhow::bail!("TURN ChannelBind response too short"); + } + let cb_resp_type = u16::from_be_bytes([resp3[0], resp3[1]]); + // 0x0109 = ChannelBind Success Response + if cb_resp_type != 0x0109 { + anyhow::bail!("TURN ChannelBind failed, response type 0x{:04x}", cb_resp_type); + } + + Ok(relay_addr_str) +} + +// ── STUN message helpers ───────────────────────────────────────────────────── + +fn build_stun_msg(msg_type: u16, tx_id: &[u8; 12], attrs: &[u8]) -> Vec { + let mut msg = Vec::with_capacity(20 + attrs.len()); + msg.extend_from_slice(&msg_type.to_be_bytes()); + msg.extend_from_slice(&(attrs.len() as u16).to_be_bytes()); + msg.extend_from_slice(&0x2112A442_u32.to_be_bytes()); // Magic Cookie + msg.extend_from_slice(tx_id); + msg.extend_from_slice(attrs); + msg +} + +fn stun_attr(attr_type: u16, value: &[u8]) -> Vec { + let mut out = Vec::new(); + out.extend_from_slice(&attr_type.to_be_bytes()); + out.extend_from_slice(&(value.len() as u16).to_be_bytes()); + out.extend_from_slice(value); + // Pad to 4-byte boundary + let pad = (4 - (value.len() % 4)) % 4; + out.extend(std::iter::repeat(0u8).take(pad)); + out +} + +// ── Cryptographic primitives (inline, zero external deps) ──────────────────── + +/// Pure-Rust MD5 hash (16 bytes). Used for TURN long-term credential key derivation. +fn md5_hash(input: &[u8]) -> [u8; 16] { + // RFC 1321 MD5 constants + const S: [u32; 64] = [ + 7,12,17,22, 7,12,17,22, 7,12,17,22, 7,12,17,22, + 5, 9,14,20, 5, 9,14,20, 5, 9,14,20, 5, 9,14,20, + 4,11,16,23, 4,11,16,23, 4,11,16,23, 4,11,16,23, + 6,10,15,21, 6,10,15,21, 6,10,15,21, 6,10,15,21, + ]; + const K: [u32; 64] = [ + 0xd76aa478, 0xe8c7b756, 0x242070db, 0xc1bdceee, 0xf57c0faf, 0x4787c62a, + 0xa8304613, 0xfd469501, 0x698098d8, 0x8b44f7af, 0xffff5bb1, 0x895cd7be, + 0x6b901122, 0xfd987193, 0xa679438e, 0x49b40821, 0xf61e2562, 0xc040b340, + 0x265e5a51, 0xe9b6c7aa, 0xd62f105d, 0x02441453, 0xd8a1e681, 0xe7d3fbc8, + 0x21e1cde6, 0xc33707d6, 0xf4d50d87, 0x455a14ed, 0xa9e3e905, 0xfcefa3f8, + 0x676f02d9, 0x8d2a4c8a, 0xfffa3942, 0x8771f681, 0x6d9d6122, 0xfde5380c, + 0xa4beea44, 0x4bdecfa9, 0xf6bb4b60, 0xbebfbc70, 0x289b7ec6, 0xeaa127fa, + 0xd4ef3085, 0x04881d05, 0xd9d4d039, 0xe6db99e5, 0x1fa27cf8, 0xc4ac5665, + 0xf4292244, 0x432aff97, 0xab9423a7, 0xfc93a039, 0x655b59c3, 0x8f0ccc92, + 0xffeff47d, 0x85845dd1, 0x6fa87e4f, 0xfe2ce6e0, 0xa3014314, 0x4e0811a1, + 0xf7537e82, 0xbd3af235, 0x2ad7d2bb, 0xeb86d391, + ]; + + let msg_len = input.len(); + let bit_len = (msg_len as u64) * 8; + + let mut padded = input.to_vec(); + padded.push(0x80); + while padded.len() % 64 != 56 { + padded.push(0); + } + padded.extend_from_slice(&bit_len.to_le_bytes()); + + let mut a0: u32 = 0x67452301; + let mut b0: u32 = 0xefcdab89; + let mut c0: u32 = 0x98badcfe; + let mut d0: u32 = 0x10325476; + + for chunk in padded.chunks(64) { + let mut m = [0u32; 16]; + for (i, item) in m.iter_mut().enumerate() { + *item = u32::from_le_bytes([chunk[i*4], chunk[i*4+1], chunk[i*4+2], chunk[i*4+3]]); + } + let (mut a, mut b, mut c, mut d) = (a0, b0, c0, d0); + for i in 0..64usize { + let (f, g) = match i { + 0..=15 => ((b & c) | (!b & d), i), + 16..=31 => ((d & b) | (!d & c), (5*i + 1) % 16), + 32..=47 => (b ^ c ^ d, (3*i + 5) % 16), + _ => (c ^ (b | !d), (7*i) % 16), + }; + let temp = d; + d = c; + c = b; + b = b.wrapping_add((a.wrapping_add(f).wrapping_add(K[i]).wrapping_add(m[g])).rotate_left(S[i])); + a = temp; + } + a0 = a0.wrapping_add(a); + b0 = b0.wrapping_add(b); + c0 = c0.wrapping_add(c); + d0 = d0.wrapping_add(d); + } + + let mut result = [0u8; 16]; + result[0..4].copy_from_slice(&a0.to_le_bytes()); + result[4..8].copy_from_slice(&b0.to_le_bytes()); + result[8..12].copy_from_slice(&c0.to_le_bytes()); + result[12..16].copy_from_slice(&d0.to_le_bytes()); + result +} + +/// HMAC-SHA1 for TURN MESSAGE-INTEGRITY (RFC 2104 + RFC 5389 §15.4). +fn hmac_sha1(key: &[u8], message: &[u8]) -> [u8; 20] { + const BLOCK_SIZE: usize = 64; + + let mut k = [0u8; BLOCK_SIZE]; + if key.len() > BLOCK_SIZE { + let h = sha1_hash(key); + k[..20].copy_from_slice(&h); + } else { + k[..key.len()].copy_from_slice(key); + } + + let mut ipad = [0u8; BLOCK_SIZE]; + let mut opad = [0u8; BLOCK_SIZE]; + for i in 0..BLOCK_SIZE { + ipad[i] = k[i] ^ 0x36; + opad[i] = k[i] ^ 0x5C; + } + + let mut inner = ipad.to_vec(); + inner.extend_from_slice(message); + let inner_hash = sha1_hash(&inner); + + let mut outer = opad.to_vec(); + outer.extend_from_slice(&inner_hash); + sha1_hash(&outer) +} + +/// Pure-Rust SHA-1 (RFC 3174). +fn sha1_hash(input: &[u8]) -> [u8; 20] { + let msg_len = input.len(); + let bit_len = (msg_len as u64) * 8; + let mut padded = input.to_vec(); + padded.push(0x80); + while padded.len() % 64 != 56 { + padded.push(0); + } + padded.extend_from_slice(&bit_len.to_be_bytes()); + + let mut h: [u32; 5] = [0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0]; + + for chunk in padded.chunks(64) { + let mut w = [0u32; 80]; + for i in 0..16 { + w[i] = u32::from_be_bytes([chunk[i*4], chunk[i*4+1], chunk[i*4+2], chunk[i*4+3]]); + } + for i in 16..80 { + w[i] = (w[i-3] ^ w[i-8] ^ w[i-14] ^ w[i-16]).rotate_left(1); + } + let (mut a, mut b, mut c, mut d, mut e) = (h[0], h[1], h[2], h[3], h[4]); + for i in 0..80usize { + let (f, k) = match i { + 0..=19 => ((b & c) | (!b & d), 0x5A827999u32), + 20..=39 => (b ^ c ^ d, 0x6ED9EBA1), + 40..=59 => ((b & c) | (b & d) | (c & d), 0x8F1BBCDC), + _ => (b ^ c ^ d, 0xCA62C1D6), + }; + let temp = a.rotate_left(5).wrapping_add(f).wrapping_add(e).wrapping_add(k).wrapping_add(w[i]); + e = d; d = c; c = b.rotate_left(30); b = a; a = temp; + } + h[0] = h[0].wrapping_add(a); h[1] = h[1].wrapping_add(b); + h[2] = h[2].wrapping_add(c); h[3] = h[3].wrapping_add(d); + h[4] = h[4].wrapping_add(e); + } + + let mut out = [0u8; 20]; + for (i, &v) in h.iter().enumerate() { + out[i*4..(i+1)*4].copy_from_slice(&v.to_be_bytes()); + } + out +} diff --git a/ostp-core/Cargo.toml b/ostp-core/Cargo.toml index 2589466..d5022d4 100644 --- a/ostp-core/Cargo.toml +++ b/ostp-core/Cargo.toml @@ -6,13 +6,11 @@ license.workspace = true [dependencies] anyhow.workspace = true -async-trait.workspace = true bytes.workspace = true chacha20poly1305.workspace = true rand.workspace = true snow.workspace = true thiserror.workspace = true tracing.workspace = true -x25519-dalek.workspace = true sha2.workspace = true hmac.workspace = true diff --git a/ostp-core/src/congestion.rs b/ostp-core/src/congestion.rs new file mode 100644 index 0000000..c25eedc --- /dev/null +++ b/ostp-core/src/congestion.rs @@ -0,0 +1,341 @@ +//! Congestion control for the OSTP protocol. +//! +//! Implements a simplified BBR-inspired algorithm that estimates bottleneck +//! bandwidth and minimum RTT to determine the optimal sending rate. +//! This replaces the fixed `retransmit_budget = 8` with an adaptive +//! congestion window that responds to network conditions. + +use std::collections::VecDeque; +use std::time::{Duration, Instant}; + +/// Congestion control state for a single OSTP session. +pub struct CongestionController { + /// Current congestion window in bytes (how much can be in-flight) + cwnd: u64, + /// Slow-start threshold in bytes + ssthresh: u64, + /// Current phase + phase: Phase, + /// Minimum RTT observed (used for BDP calculation) + min_rtt: Duration, + /// Maximum bandwidth observed (bytes/sec) + max_bandwidth: u64, + /// RTT samples for smoothing + rtt_samples: VecDeque, + /// Bandwidth samples + bw_samples: VecDeque, + /// Bytes currently in flight (unacknowledged) + bytes_in_flight: u64, + /// Total bytes acknowledged (for bandwidth estimation) + total_acked: u64, + /// Last time we received an ACK + last_ack_time: Instant, + /// Number of loss events in the current window + loss_count: u32, + /// Pacing rate: bytes per second + pacing_rate: u64, + /// MTU estimate (used for cwnd → packet count conversion) + mtu: u64, + /// Probe RTT phase timer + probe_rtt_timer: Option, + /// Min RTT expiry: re-probe after 10 seconds + min_rtt_stamp: Instant, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Phase { + /// Exponential growth until loss or ssthresh + SlowStart, + /// Probe bandwidth: cycle through pacing gains + ProbeBandwidth, + /// Periodically drain the queue to measure true min RTT + ProbeRtt, +} + +#[derive(Debug, Clone)] +#[allow(dead_code)] +struct RttSample { + rtt: Duration, + time: Instant, +} + +#[derive(Debug, Clone)] +#[allow(dead_code)] +struct BwSample { + bytes_per_sec: u64, + time: Instant, +} + +/// Maximum number of samples to keep for windowed min/max +const MAX_SAMPLES: usize = 32; +/// Initial congestion window: 10 packets × MTU +const INITIAL_CWND_PACKETS: u64 = 10; +/// Minimum cwnd: 2 packets +const MIN_CWND_PACKETS: u64 = 2; +/// Min RTT expiry window (after which we re-probe) +const MIN_RTT_EXPIRY: Duration = Duration::from_secs(10); +/// ProbeRTT drain duration +const PROBE_RTT_DURATION: Duration = Duration::from_millis(200); + +impl CongestionController { + pub fn new(mtu: u64) -> Self { + let now = Instant::now(); + let initial_cwnd = INITIAL_CWND_PACKETS * mtu; + Self { + cwnd: initial_cwnd, + ssthresh: u64::MAX, + phase: Phase::SlowStart, + min_rtt: Duration::from_millis(100), // Conservative initial estimate + max_bandwidth: 0, + rtt_samples: VecDeque::with_capacity(MAX_SAMPLES), + bw_samples: VecDeque::with_capacity(MAX_SAMPLES), + bytes_in_flight: 0, + total_acked: 0, + last_ack_time: now, + loss_count: 0, + pacing_rate: initial_cwnd * 10, // initial: ~10 windows/sec + mtu, + probe_rtt_timer: None, + min_rtt_stamp: now, + } + } + + /// Returns the current congestion window in bytes. + pub fn cwnd(&self) -> u64 { + self.cwnd + } + + /// Returns the current congestion window in packets. + pub fn cwnd_packets(&self) -> usize { + (self.cwnd / self.mtu).max(MIN_CWND_PACKETS) as usize + } + + /// Returns the current pacing rate in bytes/sec. + pub fn pacing_rate(&self) -> u64 { + self.pacing_rate + } + + /// Returns the smoothed RTT estimate. + pub fn smoothed_rtt(&self) -> Duration { + self.min_rtt + } + + /// Returns how many bytes can still be sent. + pub fn available_cwnd(&self) -> u64 { + self.cwnd.saturating_sub(self.bytes_in_flight) + } + + /// Returns the recommended retransmit budget per tick. + pub fn retransmit_budget(&self) -> usize { + // Allow retransmitting up to 1/4 of the cwnd in packets per tick + let budget = (self.cwnd_packets() / 4).max(2); + budget.min(64) // cap at 64 to prevent burst + } + + /// Check whether we can send more data. + pub fn can_send(&self) -> bool { + self.bytes_in_flight < self.cwnd + } + + /// Record that we sent `bytes` of data. + pub fn on_send(&mut self, bytes: u64) { + self.bytes_in_flight = self.bytes_in_flight.saturating_add(bytes); + } + + /// Record that `bytes` were acknowledged with the given RTT sample. + pub fn on_ack(&mut self, bytes: u64, rtt: Duration) { + let now = Instant::now(); + self.bytes_in_flight = self.bytes_in_flight.saturating_sub(bytes); + self.total_acked = self.total_acked.saturating_add(bytes); + + // Update RTT + self.update_rtt(rtt, now); + + // Update bandwidth estimate + self.update_bandwidth(bytes, now); + + // State machine + match self.phase { + Phase::SlowStart => { + // Exponential growth: increase cwnd by acked bytes + self.cwnd = self.cwnd.saturating_add(bytes); + if self.cwnd >= self.ssthresh { + self.phase = Phase::ProbeBandwidth; + tracing::debug!(cwnd = self.cwnd, "congestion: exiting slow start"); + } + } + Phase::ProbeBandwidth => { + // BBR-style: target cwnd = BDP * gain + let bdp = self.bandwidth_delay_product(); + // Apply gain of 1.25 during probe bandwidth + let target = (bdp * 5 / 4).max(MIN_CWND_PACKETS * self.mtu); + // Smooth transition + if self.cwnd < target { + self.cwnd = self.cwnd.saturating_add(bytes * self.mtu / self.cwnd.max(1)); + } else { + self.cwnd = target; + } + } + Phase::ProbeRtt => { + // Drain down to 4 packets to measure true min RTT + self.cwnd = MIN_CWND_PACKETS * self.mtu * 2; + if let Some(timer) = self.probe_rtt_timer { + if now.duration_since(timer) >= PROBE_RTT_DURATION { + // ProbeRTT complete, return to ProbeBandwidth + self.phase = Phase::ProbeBandwidth; + self.probe_rtt_timer = None; + let bdp = self.bandwidth_delay_product(); + self.cwnd = bdp.max(MIN_CWND_PACKETS * self.mtu); + tracing::debug!(cwnd = self.cwnd, min_rtt = ?self.min_rtt, "congestion: probe RTT complete"); + } + } + } + } + + // Periodically enter ProbeRTT to refresh min_rtt + if now.duration_since(self.min_rtt_stamp) >= MIN_RTT_EXPIRY && self.phase != Phase::ProbeRtt { + self.phase = Phase::ProbeRtt; + self.probe_rtt_timer = Some(now); + tracing::debug!("congestion: entering probe RTT phase"); + } + + self.update_pacing_rate(); + self.last_ack_time = now; + } + + /// Record a loss event. + pub fn on_loss(&mut self, bytes_lost: u64) { + self.bytes_in_flight = self.bytes_in_flight.saturating_sub(bytes_lost); + self.loss_count += 1; + + match self.phase { + Phase::SlowStart => { + // Exit slow start, set ssthresh to half of cwnd + self.ssthresh = self.cwnd / 2; + self.cwnd = self.ssthresh.max(MIN_CWND_PACKETS * self.mtu); + self.phase = Phase::ProbeBandwidth; + tracing::debug!(cwnd = self.cwnd, ssthresh = self.ssthresh, "congestion: loss during slow start"); + } + Phase::ProbeBandwidth => { + // Multiplicative decrease: cwnd *= 0.7 (BBR-style, less aggressive than Cubic's 0.5) + self.cwnd = (self.cwnd * 7 / 10).max(MIN_CWND_PACKETS * self.mtu); + tracing::debug!(cwnd = self.cwnd, "congestion: loss, cwnd reduced"); + } + Phase::ProbeRtt => { + // Don't react to loss during ProbeRTT + } + } + + self.update_pacing_rate(); + } + + /// Called periodically to update state. + pub fn on_tick(&mut self) { + // Nothing special needed per-tick -- state updates happen on ACK/loss + } + + // ── Private ────────────────────────────────────────────────────────────── + + fn update_rtt(&mut self, rtt: Duration, now: Instant) { + // Track windowed minimum RTT + if rtt < self.min_rtt || now.duration_since(self.min_rtt_stamp) >= MIN_RTT_EXPIRY { + self.min_rtt = rtt; + self.min_rtt_stamp = now; + } + + // Keep sample history + self.rtt_samples.push_back(RttSample { rtt, time: now }); + while self.rtt_samples.len() > MAX_SAMPLES { + self.rtt_samples.pop_front(); + } + } + + fn update_bandwidth(&mut self, acked_bytes: u64, now: Instant) { + let elapsed = now.duration_since(self.last_ack_time); + if elapsed.as_micros() > 0 { + let bw = acked_bytes * 1_000_000 / elapsed.as_micros() as u64; + if bw > self.max_bandwidth { + self.max_bandwidth = bw; + } + self.bw_samples.push_back(BwSample { bytes_per_sec: bw, time: now }); + while self.bw_samples.len() > MAX_SAMPLES { + self.bw_samples.pop_front(); + } + } + } + + fn bandwidth_delay_product(&self) -> u64 { + // BDP = max_bandwidth * min_rtt + let bw = if self.max_bandwidth > 0 { + self.max_bandwidth + } else { + // Fallback: assume 10 Mbps + 1_250_000 + }; + let rtt_secs = self.min_rtt.as_secs_f64(); + (bw as f64 * rtt_secs) as u64 + } + + fn update_pacing_rate(&mut self) { + // Pacing rate = cwnd / min_rtt (with gain) + let rtt_us = self.min_rtt.as_micros().max(1) as u64; + self.pacing_rate = self.cwnd * 1_000_000 / rtt_us; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_initial_state() { + let cc = CongestionController::new(1200); + assert_eq!(cc.cwnd(), 12000); // 10 * 1200 + assert!(cc.can_send()); + assert_eq!(cc.cwnd_packets(), 10); + } + + #[test] + fn test_slow_start_growth() { + let mut cc = CongestionController::new(1200); + // Simulate sending and ACKing + cc.on_send(1200); + cc.on_ack(1200, Duration::from_millis(50)); + // cwnd should grow + assert!(cc.cwnd() > 12000); + } + + #[test] + fn test_loss_reduces_cwnd() { + let mut cc = CongestionController::new(1200); + let initial = cc.cwnd(); + cc.on_loss(1200); + assert!(cc.cwnd() < initial); + } + + #[test] + fn test_can_send_limits() { + let mut cc = CongestionController::new(1200); + // Send until cwnd is exhausted + for _ in 0..10 { + cc.on_send(1200); + } + assert!(!cc.can_send()); // cwnd exhausted + } + + #[test] + fn test_retransmit_budget() { + let cc = CongestionController::new(1200); + let budget = cc.retransmit_budget(); + assert!(budget >= 2); + assert!(budget <= 64); + } + + #[test] + fn test_rtt_tracking() { + let mut cc = CongestionController::new(1200); + cc.on_send(1200); + cc.on_ack(1200, Duration::from_millis(25)); + assert_eq!(cc.smoothed_rtt(), Duration::from_millis(25)); + } +} diff --git a/ostp-core/src/crypto/kex.rs b/ostp-core/src/crypto/kex.rs deleted file mode 100644 index 32c2c11..0000000 --- a/ostp-core/src/crypto/kex.rs +++ /dev/null @@ -1,66 +0,0 @@ -// ============================================================================= -// OSTP Hybrid Key Exchange — STUB / NOT IN USE -// ============================================================================= -// -// This module is a placeholder for future post-quantum key exchange. -// The actual key exchange is handled by the Noise NNpsk0 handshake in noise.rs. -// -// When ML-KEM (CRYSTALS-Kyber) support is added, this module will provide: -// 1. X25519 ephemeral DH (classical security) -// 2. ML-KEM-768 encapsulation (post-quantum security) -// 3. Combined shared secret = SHA-256(x25519_secret || ml_kem_secret) -// -// Until then, DO NOT use this module in production — it provides zero -// post-quantum security. The Noise handshake in noise.rs is the only -// active key exchange mechanism. -// ============================================================================= - -#![allow(dead_code)] - -use sha2::{Digest, Sha256}; - -/// Placeholder shared secret output. -/// NOT USED by the protocol — provided for future API compatibility only. -#[derive(Debug, Clone)] -pub struct HybridSharedSecret { - pub x25519_pubkey: [u8; 32], - pub pq_ciphertext: Vec, - pub combined_secret: [u8; 32], -} - -/// Placeholder hybrid key exchange. -/// The PQ component is a no-op stub. See module-level documentation. -pub struct HybridKex; - -impl HybridKex { - /// Generate a hybrid key exchange offer. - /// - /// # Security Warning - /// The post-quantum component is a **stub** — `pq_ciphertext` is all zeros. - /// This function exists solely for API scaffolding. Do not rely on it for - /// post-quantum security. - pub fn client_offer() -> HybridSharedSecret { - use rand::rngs::OsRng; - use x25519_dalek::{EphemeralSecret, PublicKey}; - - let secret = EphemeralSecret::random_from_rng(OsRng); - let pubkey = PublicKey::from(&secret); - - // TODO: Replace with ML-KEM-768 encapsulation (crate `ml-kem`) - let pq_ciphertext = vec![0_u8; 1088]; - - let mut hasher = Sha256::new(); - hasher.update(pubkey.as_bytes()); - hasher.update(&pq_ciphertext); - let digest = hasher.finalize(); - - let mut combined_secret = [0_u8; 32]; - combined_secret.copy_from_slice(&digest[..32]); - - HybridSharedSecret { - x25519_pubkey: *pubkey.as_bytes(), - pq_ciphertext, - combined_secret, - } - } -} diff --git a/ostp-core/src/crypto/mod.rs b/ostp-core/src/crypto/mod.rs index dfa5496..785d68e 100644 --- a/ostp-core/src/crypto/mod.rs +++ b/ostp-core/src/crypto/mod.rs @@ -1,10 +1,8 @@ pub mod aead; -pub mod kex; pub mod noise; pub mod obfuscation; pub use aead::SessionCipher; -pub use kex::{HybridSharedSecret, HybridKex}; pub use noise::{NoiseRole, NoiseSession}; pub use obfuscation::{ deobfuscate_header_inplace, deobfuscate_packet_inplace, obfuscate_packet_inplace, diff --git a/ostp-core/src/lib.rs b/ostp-core/src/lib.rs index d36fddd..ee31a2a 100644 --- a/ostp-core/src/lib.rs +++ b/ostp-core/src/lib.rs @@ -1,7 +1,9 @@ +pub mod congestion; pub mod crypto; pub mod framing; pub mod protocol; pub mod relay; +pub mod resumption; pub use crypto::NoiseRole; pub use framing::{TrafficProfile, PaddingStrategy}; diff --git a/ostp-core/src/protocol.rs b/ostp-core/src/protocol.rs index 5ca183d..75254a3 100644 --- a/ostp-core/src/protocol.rs +++ b/ostp-core/src/protocol.rs @@ -5,6 +5,7 @@ use thiserror::Error; use std::collections::{BTreeMap, VecDeque}; use std::time::{Duration, Instant}; +use crate::congestion::CongestionController; use crate::crypto::{NoiseRole, NoiseSession, SessionCipher}; use crate::framing::{AdaptivePadder, FrameHeader, FrameKind, FramedPacket, PaddingStrategy}; @@ -93,7 +94,9 @@ pub struct ProtocolMachine { /// evicted from sent_history, this timer detects the deadlock and skips /// the gap to restore liveness. last_recv_advance: Instant, - /// Key-derived handshake padding range + /// Congestion controller (BBR-inspired adaptive window) + cc: CongestionController, + /// Key-derived handshake padding range handshake_pad_min: usize, handshake_pad_max: usize, } @@ -138,7 +141,8 @@ impl ProtocolMachine { last_ack_sent: Instant::now(), last_nack_sent: Instant::now() - Duration::from_secs(1), last_recv_advance: Instant::now(), - handshake_pad_min: config.handshake_pad_min.max(8), + cc: CongestionController::new(1200), + handshake_pad_min: config.handshake_pad_min.max(8), handshake_pad_max: config.handshake_pad_max.max(config.handshake_pad_min + 16), }) } @@ -266,7 +270,7 @@ impl ProtocolMachine { if nonce < self.expected_recv_nonce { // Duplicate — the ACK we sent was likely lost or delayed. - eprintln!("[ostp] Duplicate frame nonce={} (expected {}), forcing ACK", nonce, self.expected_recv_nonce); + tracing::debug!("Duplicate frame nonce={} (expected {}), forcing ACK", nonce, self.expected_recv_nonce); if let Some(ack_frame) = self.force_build_ack()? { return Ok(ProtocolAction::SendDatagram(ack_frame)); } @@ -274,8 +278,7 @@ impl ProtocolMachine { } if nonce > self.expected_recv_nonce + self.max_reorder { - eprintln!( - "[ostp] Frame nonce={} exceeds max reorder window (expected={}, max_gap={}), sending NACK", + tracing::debug!("Frame nonce={} exceeds max reorder window (expected={}, max_gap={}), sending NACK", nonce, self.expected_recv_nonce, self.max_reorder ); if let Ok(nack_frame) = self.build_control_datagram( @@ -305,10 +308,13 @@ impl ProtocolMachine { if packet.payload.len() >= 8 { let req_nonce = u64::from_be_bytes(packet.payload[..8].try_into().unwrap()); if let Some(cached_frame) = self.lookup_sent_frame(req_nonce) { - eprintln!("[ostp] NACK received: retransmitting nonce={}", req_nonce); + tracing::debug!("NACK received: retransmitting nonce={}", req_nonce); + self.cc.on_loss(cached_frame.len() as u64); outbound_actions.push(ProtocolAction::SendDatagram(cached_frame)); } else { - eprintln!("[ostp] NACK received: nonce={} not found in sent_history (evicted)", req_nonce); + tracing::debug!("NACK received: nonce={} not found in sent_history (evicted)", req_nonce); + // Estimate ~1200 bytes lost for evicted frames + self.cc.on_loss(1200); } } } @@ -323,7 +329,7 @@ impl ProtocolMachine { ProtocolAction::DeliverApp(packet.header.stream_id, packet.payload) } FrameKind::Close => { - eprintln!("[ostp] Received Close frame, terminating session"); + tracing::info!("Received Close frame, terminating session"); self.state = OstpState::Closed; ProtocolAction::Noop } @@ -357,8 +363,7 @@ impl ProtocolMachine { if self.reorder_buffer.len() < self.max_reorder_buffer { self.reorder_buffer.insert(nonce, action); } else { - eprintln!( - "[ostp] Reorder buffer full ({}/{}), dropping frame nonce={}", + tracing::warn!("Reorder buffer full ({}/{}), dropping frame nonce={}", self.reorder_buffer.len(), self.max_reorder_buffer, nonce ); } @@ -497,8 +502,7 @@ impl ProtocolMachine { delivered += 1; } self.ack_pending = true; - eprintln!( - "[ostp] Gap recovery: skipped {} lost frames, delivered {} buffered frames (reorder_buf={})", + tracing::debug!("Gap recovery: skipped {} lost frames, delivered {} buffered frames (reorder_buf={})", skipped, delivered, self.reorder_buffer.len() ); } @@ -521,12 +525,12 @@ impl ProtocolMachine { self.sent_history.retain(|f| !f.is_retransmittable || f.retries <= grace); let evicted = before - self.sent_history.len(); if evicted > 0 { - eprintln!("[ostp] Evicted {} zombie frames from sent_history (remaining={})", evicted, self.sent_history.len()); + tracing::debug!("Evicted {} zombie frames from sent_history (remaining={})", evicted, self.sent_history.len()); } // ── Retransmit expired frames ──────────────────────────────── // Limit retransmits per tick to prevent bandwidth saturation - let mut retransmit_budget: usize = 8; + let mut retransmit_budget: usize = self.cc.retransmit_budget(); for frame in self.sent_history.iter_mut() { if retransmit_budget == 0 { break; @@ -657,8 +661,7 @@ impl ProtocolMachine { }); if self.sent_history.len() > self.max_sent_history { let overflow = self.sent_history.len() - self.max_sent_history; - eprintln!( - "[ostp] sent_history overflow: evicting {} oldest frames (cap={})", + tracing::debug!("sent_history overflow: evicting {} oldest frames (cap={})", overflow, self.max_sent_history ); while self.sent_history.len() > self.max_sent_history { @@ -668,7 +671,27 @@ impl ProtocolMachine { } fn drop_acked_frames(&mut self, ranges: &[(u64, u64)]) { + let now = Instant::now(); + let mut acked_bytes = 0u64; + let mut min_rtt = Duration::from_secs(60); + + // Compute RTT from the oldest acked frame's send timestamp + for frame in self.sent_history.iter() { + if nonce_in_ranges(frame.nonce, ranges) { + acked_bytes += frame.bytes.len() as u64; + let rtt = now.duration_since(frame.last_sent); + if rtt < min_rtt { + min_rtt = rtt; + } + } + } + self.sent_history.retain(|frame| !nonce_in_ranges(frame.nonce, ranges)); + + // Notify congestion controller + if acked_bytes > 0 { + self.cc.on_ack(acked_bytes, min_rtt); + } } } diff --git a/ostp-core/src/relay.rs b/ostp-core/src/relay.rs index 1a11462..bc0f1fc 100644 --- a/ostp-core/src/relay.rs +++ b/ostp-core/src/relay.rs @@ -84,3 +84,83 @@ fn decode_with_len(input: &[u8]) -> Result<&[u8]> { } Ok(&input[2..2 + len]) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_connect_roundtrip() { + let msg = RelayMessage::Connect("example.com:443".to_string()); + let encoded = msg.encode(); + let decoded = RelayMessage::decode(&encoded).unwrap(); + match decoded { + RelayMessage::Connect(addr) => assert_eq!(addr, "example.com:443"), + _ => panic!("expected Connect"), + } + } + + #[test] + fn test_data_roundtrip() { + let data = vec![1, 2, 3, 4, 5]; + let msg = RelayMessage::Data(data.clone()); + let encoded = msg.encode(); + let decoded = RelayMessage::decode(&encoded).unwrap(); + match decoded { + RelayMessage::Data(d) => assert_eq!(d, data), + _ => panic!("expected Data"), + } + } + + #[test] + fn test_simple_tags() { + assert_eq!(RelayMessage::KeepAlive.encode(), vec![3]); + assert_eq!(RelayMessage::Close.encode(), vec![4]); + assert_eq!(RelayMessage::ConnectOk.encode(), vec![5]); + + assert!(matches!(RelayMessage::decode(&[3]).unwrap(), RelayMessage::KeepAlive)); + assert!(matches!(RelayMessage::decode(&[4]).unwrap(), RelayMessage::Close)); + assert!(matches!(RelayMessage::decode(&[5]).unwrap(), RelayMessage::ConnectOk)); + } + + #[test] + fn test_error_roundtrip() { + let msg = RelayMessage::Error("connection refused".to_string()); + let encoded = msg.encode(); + match RelayMessage::decode(&encoded).unwrap() { + RelayMessage::Error(e) => assert_eq!(e, "connection refused"), + _ => panic!("expected Error"), + } + } + + #[test] + fn test_ping_pong_roundtrip() { + let ts = 1234567890u64; + match RelayMessage::decode(&RelayMessage::Ping(ts).encode()).unwrap() { + RelayMessage::Ping(t) => assert_eq!(t, ts), + _ => panic!("expected Ping"), + } + match RelayMessage::decode(&RelayMessage::Pong(ts).encode()).unwrap() { + RelayMessage::Pong(t) => assert_eq!(t, ts), + _ => panic!("expected Pong"), + } + } + + #[test] + fn test_error_cases() { + assert!(RelayMessage::decode(&[]).is_err()); + assert!(RelayMessage::decode(&[255]).is_err()); + // Truncated: tag=1, len=5, only 2 bytes + assert!(RelayMessage::decode(&[1, 0, 5, b'a', b'b']).is_err()); + } + + #[test] + fn test_empty_data_roundtrip() { + let encoded = RelayMessage::Data(vec![]).encode(); + match RelayMessage::decode(&encoded).unwrap() { + RelayMessage::Data(d) => assert!(d.is_empty()), + _ => panic!("expected Data"), + } + } +} + diff --git a/ostp-core/src/resumption.rs b/ostp-core/src/resumption.rs new file mode 100644 index 0000000..d3753ad --- /dev/null +++ b/ostp-core/src/resumption.rs @@ -0,0 +1,307 @@ +//! 0-RTT Session Resumption for OSTP. +//! +//! When a client has previously connected to a server, it can cache +//! a "session ticket" that allows it to send encrypted data in the +//! very first packet — eliminating the handshake round-trip entirely. +//! +//! How it works: +//! 1. After a successful handshake, the server issues a SessionTicket +//! containing enough state to resume the session. +//! 2. The client stores the ticket locally (encrypted with the PSK). +//! 3. On reconnection, the client sends a ResumptionRequest with the +//! ticket + early data in the first packet. +//! 4. The server validates the ticket and immediately begins processing +//! data, achieving 0-RTT. +//! +//! Security considerations: +//! - Tickets have a TTL (default 3600s) to limit replay window. +//! - The server maintains a ticket nonce set to prevent replay. +//! - Early data is idempotent by protocol design (relay CONNECT is safe +//! because duplicate CONNECTs to the same target are no-ops). + +use std::collections::HashSet; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use sha2::{Digest, Sha256}; + +/// A session ticket that allows 0-RTT resumption. +#[derive(Debug, Clone)] +pub struct SessionTicket { + /// Unique ticket identifier (prevents replay) + pub ticket_id: [u8; 16], + /// Server session ID to resume + pub session_id: u32, + /// Derived cipher key for early data + pub cipher_key: [u8; 32], + /// Timestamp of issuance (seconds since epoch) + pub issued_at: u64, + /// Time-to-live in seconds + pub ttl: u64, +} + +/// Maximum ticket age (1 hour default) +const DEFAULT_TICKET_TTL: u64 = 3600; +/// Maximum tickets in the anti-replay set +const MAX_REPLAY_SET: usize = 10000; + +impl SessionTicket { + /// Create a new session ticket from the transport key material. + pub fn new(session_id: u32, transport_key: &[u8; 32], psk: &[u8; 32]) -> Self { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + + // Derive ticket ID from key material + timestamp + let mut hasher = Sha256::new(); + hasher.update(transport_key); + hasher.update(&now.to_be_bytes()); + hasher.update(b"ostp-ticket-id"); + let hash = hasher.finalize(); + let mut ticket_id = [0u8; 16]; + ticket_id.copy_from_slice(&hash[..16]); + + // Derive cipher key for early data from PSK + ticket + let mut key_hasher = Sha256::new(); + key_hasher.update(psk); + key_hasher.update(&ticket_id); + key_hasher.update(b"ostp-early-data-key"); + let cipher_key_hash = key_hasher.finalize(); + let mut cipher_key = [0u8; 32]; + cipher_key.copy_from_slice(&cipher_key_hash); + + Self { + ticket_id, + session_id, + cipher_key, + issued_at: now, + ttl: DEFAULT_TICKET_TTL, + } + } + + /// Check if the ticket has expired. + pub fn is_expired(&self) -> bool { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + now > self.issued_at + self.ttl + } + + /// Serialize the ticket to bytes for storage/transmission. + /// Wire format: [ticket_id:16][session_id:4][cipher_key:32][issued_at:8][ttl:8] + pub fn to_bytes(&self) -> Vec { + let mut out = Vec::with_capacity(68); + out.extend_from_slice(&self.ticket_id); + out.extend_from_slice(&self.session_id.to_be_bytes()); + out.extend_from_slice(&self.cipher_key); + out.extend_from_slice(&self.issued_at.to_be_bytes()); + out.extend_from_slice(&self.ttl.to_be_bytes()); + out + } + + /// Deserialize a ticket from bytes. + pub fn from_bytes(data: &[u8]) -> Option { + if data.len() < 68 { + return None; + } + let mut ticket_id = [0u8; 16]; + ticket_id.copy_from_slice(&data[0..16]); + + let session_id = u32::from_be_bytes(data[16..20].try_into().ok()?); + + let mut cipher_key = [0u8; 32]; + cipher_key.copy_from_slice(&data[20..52]); + + let issued_at = u64::from_be_bytes(data[52..60].try_into().ok()?); + let ttl = u64::from_be_bytes(data[60..68].try_into().ok()?); + + Some(Self { + ticket_id, + session_id, + cipher_key, + issued_at, + ttl, + }) + } + + /// Encrypt the ticket with a PSK for client-side storage. + /// Uses a simple XOR cipher with HMAC-SHA256 derived key. + pub fn encrypt(&self, psk: &[u8; 32]) -> Vec { + let raw = self.to_bytes(); + let mut enc_key_hasher = Sha256::new(); + enc_key_hasher.update(psk); + enc_key_hasher.update(b"ostp-ticket-encryption"); + let enc_key = enc_key_hasher.finalize(); + + let mut encrypted = raw.clone(); + for (i, byte) in encrypted.iter_mut().enumerate() { + *byte ^= enc_key[i % 32]; + } + encrypted + } + + /// Decrypt a ticket from encrypted bytes. + pub fn decrypt(encrypted: &[u8], psk: &[u8; 32]) -> Option { + let mut enc_key_hasher = Sha256::new(); + enc_key_hasher.update(psk); + enc_key_hasher.update(b"ostp-ticket-encryption"); + let enc_key = enc_key_hasher.finalize(); + + let mut decrypted = encrypted.to_vec(); + for (i, byte) in decrypted.iter_mut().enumerate() { + *byte ^= enc_key[i % 32]; + } + Self::from_bytes(&decrypted) + } +} + +/// Server-side anti-replay guard for session tickets. +#[allow(dead_code)] +pub struct TicketValidator { + /// Set of consumed ticket IDs (prevents replay) + consumed: HashSet<[u8; 16]>, + /// PSK for ticket validation + psk: [u8; 32], + /// Maximum age for tickets + max_age: Duration, +} + +impl TicketValidator { + pub fn new(psk: [u8; 32]) -> Self { + Self { + consumed: HashSet::new(), + psk, + max_age: Duration::from_secs(DEFAULT_TICKET_TTL), + } + } + + /// Validate a ticket from the client. Returns the ticket if valid, + /// or None if expired, replayed, or invalid. + pub fn validate(&mut self, encrypted_ticket: &[u8]) -> Option { + let ticket = SessionTicket::decrypt(encrypted_ticket, &self.psk)?; + + // Check expiry + if ticket.is_expired() { + tracing::debug!("0-RTT ticket rejected: expired"); + return None; + } + + // Check replay + if self.consumed.contains(&ticket.ticket_id) { + tracing::warn!("0-RTT ticket rejected: replay detected"); + return None; + } + + // Accept and mark as consumed + self.consumed.insert(ticket.ticket_id); + + // Garbage collection: remove old entries when set grows too large + if self.consumed.len() > MAX_REPLAY_SET { + // Simple strategy: clear the entire set. This is safe because + // expired tickets would fail the expiry check anyway. + self.consumed.clear(); + self.consumed.insert(ticket.ticket_id); + tracing::debug!("0-RTT replay set cleared (overflow)"); + } + + tracing::debug!("0-RTT ticket accepted: session_id={}", ticket.session_id); + Some(ticket) + } + + /// Issue a new ticket for a completed session. + pub fn issue_ticket(&self, session_id: u32, transport_key: &[u8; 32]) -> SessionTicket { + SessionTicket::new(session_id, transport_key, &self.psk) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_ticket_serialize_roundtrip() { + let psk = [42u8; 32]; + let key = [1u8; 32]; + let ticket = SessionTicket::new(12345, &key, &psk); + + let bytes = ticket.to_bytes(); + let restored = SessionTicket::from_bytes(&bytes).unwrap(); + + assert_eq!(ticket.ticket_id, restored.ticket_id); + assert_eq!(ticket.session_id, restored.session_id); + assert_eq!(ticket.cipher_key, restored.cipher_key); + assert_eq!(ticket.issued_at, restored.issued_at); + } + + #[test] + fn test_ticket_encrypt_decrypt() { + let psk = [42u8; 32]; + let key = [1u8; 32]; + let ticket = SessionTicket::new(99, &key, &psk); + + let encrypted = ticket.encrypt(&psk); + let decrypted = SessionTicket::decrypt(&encrypted, &psk).unwrap(); + + assert_eq!(ticket.ticket_id, decrypted.ticket_id); + assert_eq!(ticket.session_id, decrypted.session_id); + } + + #[test] + fn test_ticket_wrong_psk_fails() { + let psk = [42u8; 32]; + let wrong_psk = [99u8; 32]; + let key = [1u8; 32]; + let ticket = SessionTicket::new(1, &key, &psk); + let encrypted = ticket.encrypt(&psk); + + // Decrypting with wrong PSK produces garbage, from_bytes should + // still return Some but ticket_id won't match + let decrypted = SessionTicket::decrypt(&encrypted, &wrong_psk); + // It may parse but the data will be wrong + if let Some(d) = decrypted { + assert_ne!(d.ticket_id, ticket.ticket_id); + } + } + + #[test] + fn test_ticket_not_expired() { + let psk = [42u8; 32]; + let key = [1u8; 32]; + let ticket = SessionTicket::new(1, &key, &psk); + assert!(!ticket.is_expired()); + } + + #[test] + fn test_validator_replay_protection() { + let psk = [42u8; 32]; + let key = [1u8; 32]; + let mut validator = TicketValidator::new(psk); + + let ticket = validator.issue_ticket(1, &key); + let encrypted = ticket.encrypt(&psk); + + // First use should succeed + assert!(validator.validate(&encrypted).is_some()); + + // Replay should fail + assert!(validator.validate(&encrypted).is_none()); + } + + #[test] + fn test_validator_different_tickets() { + let psk = [42u8; 32]; + let mut validator = TicketValidator::new(psk); + + let ticket1 = validator.issue_ticket(1, &[1u8; 32]); + let ticket2 = validator.issue_ticket(2, &[2u8; 32]); + + assert!(validator.validate(&ticket1.encrypt(&psk)).is_some()); + assert!(validator.validate(&ticket2.encrypt(&psk)).is_some()); + } + + #[test] + fn test_truncated_ticket_fails() { + assert!(SessionTicket::from_bytes(&[0u8; 10]).is_none()); + } +} diff --git a/ostp-server/Cargo.toml b/ostp-server/Cargo.toml index 33fda7b..7f2d6ae 100644 --- a/ostp-server/Cargo.toml +++ b/ostp-server/Cargo.toml @@ -14,3 +14,5 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" rand.workspace = true socket2 = "0.6.3" +axum = "0.8" +tower-http = { version = "0.6", features = ["cors"] } diff --git a/ostp-server/src/api.rs b/ostp-server/src/api.rs new file mode 100644 index 0000000..5dfbf60 --- /dev/null +++ b/ostp-server/src/api.rs @@ -0,0 +1,378 @@ +//! Management REST API for OSTP server. +//! +//! Provides endpoints for third-party panels (like 3x-ui) to manage users, +//! query traffic statistics, and control the server. +//! +//! ## Endpoints +//! +//! - `GET /api/server/status` -- Server status (uptime, sessions, version) +//! - `GET /api/users` -- List all users with traffic stats +//! - `GET /api/users/:key` -- Single user stats +//! - `POST /api/users` -- Create new access key +//! - `DELETE /api/users/:key` -- Remove access key +//! - `PUT /api/users/:key/limit` -- Set traffic limit for a user +//! - `POST /api/users/:key/reset` -- Reset user traffic counters + +use std::collections::HashMap; +use std::sync::{Arc, RwLock}; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::time::Instant; + +use axum::{ + extract::{Path, State}, + http::StatusCode, + response::IntoResponse, + routing::{delete, get, post, put}, + Json, Router, +}; +use serde::{Deserialize, Serialize}; +use tower_http::cors::{Any, CorsLayer}; + +use crate::dispatcher::{UserStats, UserStatsSnapshot}; + +// ── Shared state for API handlers ──────────────────────────────────────────── + +/// API server shared state. Held behind Arc for axum handlers. +#[derive(Clone)] +pub struct ApiState { + pub access_keys: Arc>>, + pub user_stats: Arc>>>, + pub start_time: Instant, + pub api_token: String, +} + +// ── API configuration ──────────────────────────────────────────────────────── + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ApiConfig { + pub enabled: bool, + pub bind: String, + pub token: String, +} + +impl Default for ApiConfig { + fn default() -> Self { + Self { + enabled: false, + bind: "127.0.0.1:9090".to_string(), + token: String::new(), + } + } +} + +// ── Request/Response types ─────────────────────────────────────────────────── + +#[derive(Serialize)] +struct ServerStatus { + version: &'static str, + uptime_seconds: u64, + active_users: usize, + total_users: usize, +} + +#[derive(Deserialize)] +pub struct CreateUserRequest { + pub access_key: Option, + pub limit_bytes: Option, +} + +#[derive(Deserialize)] +pub struct SetLimitRequest { + pub limit_bytes: Option, +} + +#[derive(Serialize)] +struct ApiResponse { + ok: bool, + #[serde(skip_serializing_if = "Option::is_none")] + data: Option, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, +} + +impl ApiResponse { + fn success(data: T) -> Json { + Json(Self { ok: true, data: Some(data), error: None }) + } +} + +fn api_error(msg: &str) -> (StatusCode, Json>) { + (StatusCode::BAD_REQUEST, Json(ApiResponse { ok: false, data: None, error: Some(msg.to_string()) })) +} + +fn api_unauthorized() -> (StatusCode, Json>) { + (StatusCode::UNAUTHORIZED, Json(ApiResponse { ok: false, data: None, error: Some("unauthorized".to_string()) })) +} + +// ── API router ─────────────────────────────────────────────────────────────── + +pub fn create_api_router(state: ApiState) -> Router { + let cors = CorsLayer::new() + .allow_origin(Any) + .allow_methods(Any) + .allow_headers(Any); + + Router::new() + .route("/api/server/status", get(handle_status)) + .route("/api/users", get(handle_list_users)) + .route("/api/users", post(handle_create_user)) + .route("/api/users/{key}", get(handle_get_user)) + .route("/api/users/{key}", delete(handle_delete_user)) + .route("/api/users/{key}/limit", put(handle_set_limit)) + .route("/api/users/{key}/reset", post(handle_reset_stats)) + .layer(cors) + .with_state(state) +} + +/// Start the Management API server on the configured bind address. +pub async fn start_api_server( + config: ApiConfig, + access_keys: Arc>>, + user_stats: Arc>>>, +) { + let state = ApiState { + access_keys, + user_stats, + start_time: Instant::now(), + api_token: config.token.clone(), + }; + + let app = create_api_router(state); + + let listener = match tokio::net::TcpListener::bind(&config.bind).await { + Ok(l) => l, + Err(e) => { + tracing::error!("Management API failed to bind on {}: {}", config.bind, e); + return; + } + }; + + tracing::info!("Management API listening on {}", config.bind); + + if let Err(e) = axum::serve(listener, app).await { + tracing::error!("Management API error: {}", e); + } +} + +// ── Middleware: token check ────────────────────────────────────────────────── + +fn check_token(state: &ApiState, headers: &axum::http::HeaderMap) -> bool { + if state.api_token.is_empty() { + return true; // No auth required if token is empty + } + match headers.get("authorization") { + Some(value) => { + let val = value.to_str().unwrap_or(""); + val == format!("Bearer {}", state.api_token) || val == state.api_token + } + None => false, + } +} + +// ── Handlers ───────────────────────────────────────────────────────────────── + +async fn handle_status( + State(state): State, + headers: axum::http::HeaderMap, +) -> impl IntoResponse { + if !check_token(&state, &headers) { + return api_unauthorized::(); + } + + let keys = state.access_keys.read().unwrap(); + let stats = state.user_stats.read().unwrap(); + let online = stats.values() + .filter(|us| { + let total = us.bytes_up.load(Ordering::Relaxed) + us.bytes_down.load(Ordering::Relaxed); + total > 0 + }) + .count(); + + let status = ServerStatus { + version: env!("CARGO_PKG_VERSION"), + uptime_seconds: state.start_time.elapsed().as_secs(), + active_users: online, + total_users: keys.len(), + }; + + (StatusCode::OK, ApiResponse::success(status)) +} + +async fn handle_list_users( + State(state): State, + headers: axum::http::HeaderMap, +) -> impl IntoResponse { + if !check_token(&state, &headers) { + return api_unauthorized::>(); + } + + let keys = state.access_keys.read().unwrap(); + let stats = state.user_stats.read().unwrap(); + + let mut users: Vec = keys.keys().map(|key| { + if let Some(us) = stats.get(key) { + UserStatsSnapshot { + access_key: key.clone(), + bytes_up: us.bytes_up.load(Ordering::Relaxed), + bytes_down: us.bytes_down.load(Ordering::Relaxed), + connections: us.connections.load(Ordering::Relaxed), + limit_bytes: us.limit_bytes, + online: true, // Simplified; real check requires session map + } + } else { + UserStatsSnapshot { + access_key: key.clone(), + bytes_up: 0, + bytes_down: 0, + connections: 0, + limit_bytes: None, + online: false, + } + } + }).collect(); + + users.sort_by(|a, b| b.bytes_down.cmp(&a.bytes_down)); + + (StatusCode::OK, ApiResponse::success(users)) +} + +async fn handle_get_user( + State(state): State, + headers: axum::http::HeaderMap, + Path(key): Path, +) -> impl IntoResponse { + if !check_token(&state, &headers) { + return api_unauthorized::(); + } + + let keys = state.access_keys.read().unwrap(); + if !keys.contains_key(&key) { + return api_error("user not found"); + } + + let stats = state.user_stats.read().unwrap(); + let snapshot = if let Some(us) = stats.get(&key) { + UserStatsSnapshot { + access_key: key.clone(), + bytes_up: us.bytes_up.load(Ordering::Relaxed), + bytes_down: us.bytes_down.load(Ordering::Relaxed), + connections: us.connections.load(Ordering::Relaxed), + limit_bytes: us.limit_bytes, + online: true, + } + } else { + UserStatsSnapshot { + access_key: key.clone(), + bytes_up: 0, + bytes_down: 0, + connections: 0, + limit_bytes: None, + online: false, + } + }; + + (StatusCode::OK, ApiResponse::success(snapshot)) +} + +async fn handle_create_user( + State(state): State, + headers: axum::http::HeaderMap, + Json(body): Json, +) -> impl IntoResponse { + if !check_token(&state, &headers) { + return api_unauthorized::(); + } + + let key = body.access_key.unwrap_or_else(|| { + use rand::RngCore; + let mut buf = [0u8; 16]; + rand::thread_rng().fill_bytes(&mut buf); + buf.iter().map(|b| format!("{:02x}", b)).collect() + }); + + { + let mut keys = state.access_keys.write().unwrap(); + keys.insert(key.clone(), ()); + } + + if let Some(limit) = body.limit_bytes { + let mut stats = state.user_stats.write().unwrap(); + stats.insert(key.clone(), Arc::new(UserStats::new(Some(limit)))); + } + + tracing::info!("API: created user key {}", &key[..8.min(key.len())]); + (StatusCode::OK, ApiResponse::success(key)) +} + +async fn handle_delete_user( + State(state): State, + headers: axum::http::HeaderMap, + Path(key): Path, +) -> impl IntoResponse { + if !check_token(&state, &headers) { + return api_unauthorized::(); + } + + let removed = { + let mut keys = state.access_keys.write().unwrap(); + keys.remove(&key).is_some() + }; + + if removed { + let mut stats = state.user_stats.write().unwrap(); + stats.remove(&key); + tracing::info!("API: deleted user key {}", &key[..8.min(key.len())]); + } + + (StatusCode::OK, ApiResponse::success(removed)) +} + +async fn handle_set_limit( + State(state): State, + headers: axum::http::HeaderMap, + Path(key): Path, + Json(body): Json, +) -> impl IntoResponse { + if !check_token(&state, &headers) { + return api_unauthorized::(); + } + + let keys = state.access_keys.read().unwrap(); + if !keys.contains_key(&key) { + return api_error("user not found"); + } + drop(keys); + + let mut stats = state.user_stats.write().unwrap(); + let entry = stats.entry(key.clone()) + .or_insert_with(|| Arc::new(UserStats::new(body.limit_bytes))); + + *entry = Arc::new(UserStats { + bytes_up: AtomicU64::new(entry.bytes_up.load(Ordering::Relaxed)), + bytes_down: AtomicU64::new(entry.bytes_down.load(Ordering::Relaxed)), + connections: AtomicU64::new(entry.connections.load(Ordering::Relaxed)), + limit_bytes: body.limit_bytes, + created_at: entry.created_at, + }); + + (StatusCode::OK, ApiResponse::success(true)) +} + +async fn handle_reset_stats( + State(state): State, + headers: axum::http::HeaderMap, + Path(key): Path, +) -> impl IntoResponse { + if !check_token(&state, &headers) { + return api_unauthorized::(); + } + + let mut stats = state.user_stats.write().unwrap(); + if let Some(us) = stats.get(&key) { + let limit = us.limit_bytes; + stats.insert(key.clone(), Arc::new(UserStats::new(limit))); + (StatusCode::OK, ApiResponse::success(true)) + } else { + api_error("user not found") + } +} diff --git a/ostp-server/src/dispatcher.rs b/ostp-server/src/dispatcher.rs index 16cd866..63d7e99 100644 --- a/ostp-server/src/dispatcher.rs +++ b/ostp-server/src/dispatcher.rs @@ -4,9 +4,10 @@ use ostp_core::{OstpEvent, ProtocolAction, ProtocolConfig, ProtocolMachine}; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::{Arc, RwLock}; +use std::sync::atomic::{AtomicU64, Ordering}; /// Maximum number of concurrent authenticated sessions. -/// Excess handshake attempts are silently dropped — no response, no state allocated. +/// Excess handshake attempts are silently dropped -- no response, no state allocated. const MAX_SESSIONS: usize = 1024; pub enum DispatchOutcome { @@ -18,11 +19,54 @@ pub enum DispatchOutcome { }, } +/// Per-user traffic statistics. +pub struct UserStats { + pub bytes_up: AtomicU64, + pub bytes_down: AtomicU64, + pub connections: AtomicU64, + pub limit_bytes: Option, + pub created_at: std::time::SystemTime, +} + +impl UserStats { + pub fn new(limit: Option) -> Self { + Self { + bytes_up: AtomicU64::new(0), + bytes_down: AtomicU64::new(0), + connections: AtomicU64::new(0), + limit_bytes: limit, + created_at: std::time::SystemTime::now(), + } + } + + pub fn is_over_limit(&self) -> bool { + if let Some(limit) = self.limit_bytes { + let total = self.bytes_up.load(Ordering::Relaxed) + + self.bytes_down.load(Ordering::Relaxed); + total >= limit + } else { + false + } + } +} + +/// Snapshot of user stats for API responses. +#[derive(Debug, Clone, serde::Serialize)] +pub struct UserStatsSnapshot { + pub access_key: String, + pub bytes_up: u64, + pub bytes_down: u64, + pub connections: u64, + pub limit_bytes: Option, + pub online: bool, +} + pub struct PeerState { pub machine: ProtocolMachine, pub last_addr: SocketAddr, pub obfuscation_key: [u8; 8], pub last_seen: std::time::Instant, + pub access_key: String, } pub struct Dispatcher { @@ -30,11 +74,13 @@ pub struct Dispatcher { addr_to_session: HashMap, machine_config: ProtocolConfig, access_keys: Arc>>, + user_stats: Arc>>>, replay_cache: std::collections::HashMap, u64>, roaming_tokens: f64, last_token_regen: std::time::Instant, } +#[allow(dead_code)] impl Dispatcher { pub fn new(machine_config: ProtocolConfig, access_keys: Arc>>) -> Self { Self { @@ -42,12 +88,67 @@ impl Dispatcher { addr_to_session: HashMap::new(), machine_config, access_keys, + user_stats: Arc::new(RwLock::new(HashMap::new())), replay_cache: std::collections::HashMap::new(), roaming_tokens: 50.0, last_token_regen: std::time::Instant::now(), } } + /// Returns a shared reference to user stats for the Management API. + pub fn user_stats_ref(&self) -> Arc>>> { + self.user_stats.clone() + } + + /// Snapshot all user stats for API responses. + pub fn snapshot_all_users(&self) -> Vec { + let stats = self.user_stats.read().unwrap(); + let online_keys: std::collections::HashSet = self.peer_machines.values() + .map(|ps| ps.access_key.clone()) + .collect(); + stats.iter().map(|(key, us)| UserStatsSnapshot { + access_key: key.clone(), + bytes_up: us.bytes_up.load(Ordering::Relaxed), + bytes_down: us.bytes_down.load(Ordering::Relaxed), + connections: us.connections.load(Ordering::Relaxed), + limit_bytes: us.limit_bytes, + online: online_keys.contains(key), + }).collect() + } + + /// Get or create stats entry for a user key. + fn get_or_create_user_stats(&self, key: &str) -> Arc { + let stats = self.user_stats.read().unwrap(); + if let Some(existing) = stats.get(key) { + return existing.clone(); + } + drop(stats); + let mut stats = self.user_stats.write().unwrap(); + stats.entry(key.to_string()) + .or_insert_with(|| Arc::new(UserStats::new(None))) + .clone() + } + + /// Set traffic limit for a user. + pub fn set_user_limit(&self, key: &str, limit: Option) { + let mut stats = self.user_stats.write().unwrap(); + let entry = stats.entry(key.to_string()) + .or_insert_with(|| Arc::new(UserStats::new(limit))); + // Replace the entry with new limit (stats reset) + *entry = Arc::new(UserStats { + bytes_up: AtomicU64::new(entry.bytes_up.load(Ordering::Relaxed)), + bytes_down: AtomicU64::new(entry.bytes_down.load(Ordering::Relaxed)), + connections: AtomicU64::new(entry.connections.load(Ordering::Relaxed)), + limit_bytes: limit, + created_at: entry.created_at, + }); + } + + /// Active session count. + pub fn active_sessions(&self) -> usize { + self.peer_machines.len() + } + pub fn on_datagram(&mut self, peer: SocketAddr, packet: Bytes) -> Result { if packet.len() < 4 { return Ok(DispatchOutcome::Unauthorized); @@ -100,17 +201,21 @@ impl Dispatcher { if let Some(session_id) = session_id_opt { if let Some(peer_state) = self.peer_machines.get_mut(&session_id) { if peer_state.last_addr != peer { - eprintln!("[ostp] Client roamed: session {} from {} to {}", session_id, peer_state.last_addr, peer); + tracing::info!("Client roamed: session {} from {} to {}", session_id, peer_state.last_addr, peer); self.addr_to_session.remove(&peer_state.last_addr); } peer_state.last_addr = peer; peer_state.last_seen = std::time::Instant::now(); self.addr_to_session.insert(peer, session_id); + // Track inbound bytes per user + let key = peer_state.access_key.clone(); + track_user_bytes_up(&self.user_stats, &key, packet.len() as u64); + let action = match peer_state.machine.on_event(OstpEvent::Inbound(packet)) { Ok(a) => a, Err(e) => { - eprintln!("[ostp] Protocol error for session {}: {}", session_id, e); + tracing::warn!("Protocol error for session {}: {}", session_id, e); return Ok(DispatchOutcome::Unauthorized); } }; @@ -175,7 +280,7 @@ impl Dispatcher { let mut machine = match ProtocolMachine::new(cfg) { Ok(m) => m, Err(e) => { - eprintln!("[ostp] Failed to create protocol machine for key trial: {}", e); + tracing::warn!("Failed to create protocol machine for key trial: {}", e); continue; } }; @@ -212,17 +317,17 @@ impl Dispatcher { let drift = (now as i64 - ts as i64).abs(); if drift > 300 { - eprintln!("[ostp] Handshake rejected: timestamp drift {}s exceeds 300s limit (peer={})", drift, peer); + tracing::warn!("Handshake rejected: timestamp drift {}s exceeds 300s limit (peer={})", drift, peer); continue; } if !self.replay_cache.contains_key(&payload.to_vec()) { if self.replay_cache.len() >= 100_000 { - eprintln!("[ostp] Replay cache full (100000 entries), rejecting handshake from {}", peer); + tracing::warn!("Replay cache full (100000 entries), rejecting handshake from {}", peer); return Ok(DispatchOutcome::Unauthorized); } if self.peer_machines.len() >= MAX_SESSIONS { - eprintln!("[ostp] Max sessions reached ({}), rejecting handshake from {}", MAX_SESSIONS, peer); + tracing::warn!("Max sessions reached ({}), rejecting handshake from {}", MAX_SESSIONS, peer); return Ok(DispatchOutcome::Unauthorized); } @@ -230,16 +335,26 @@ impl Dispatcher { machine.set_session_keys(candidate_session_id, secrets.obfuscation_key); + // Track per-user connection count + let user_stats = self.get_or_create_user_stats(&candidate_key); + user_stats.connections.fetch_add(1, Ordering::Relaxed); + + // Check traffic limit before accepting + if user_stats.is_over_limit() { + tracing::warn!("User {} exceeded traffic limit, rejecting handshake from {}", candidate_key, peer); + return Ok(DispatchOutcome::Unauthorized); + } + self.peer_machines.insert(candidate_session_id, PeerState { machine, last_addr: peer, obfuscation_key: secrets.obfuscation_key, last_seen: std::time::Instant::now(), + access_key: candidate_key.clone(), }); self.addr_to_session.insert(peer, candidate_session_id); - eprintln!( - "[ostp] New session authenticated: sid={} peer={} (active_sessions={}, replay_cache={})", + tracing::info!("New session authenticated: sid={} peer={} (active_sessions={}, replay_cache={})", candidate_session_id, peer, self.peer_machines.len(), self.replay_cache.len() ); @@ -265,8 +380,13 @@ impl Dispatcher { }; let addr = peer_state.last_addr; + let key = peer_state.access_key.clone(); match peer_state.machine.on_event(OstpEvent::Outbound(stream_id, payload))? { - ProtocolAction::SendDatagram(frame) => Ok(Some((frame, addr))), + ProtocolAction::SendDatagram(frame) => { + // Track outbound bytes per user + track_user_bytes_down(&self.user_stats, &key, frame.len() as u64); + Ok(Some((frame, addr))) + } _ => Ok(None), } } @@ -293,7 +413,7 @@ impl Dispatcher { // Clear expired sessions from internal state for sid in &expired { - eprintln!("[ostp] Session {} expired (inactive >5min), releasing", sid); + tracing::info!("Session {} expired (inactive >5min), releasing", sid); self.drop_session(*sid); } @@ -302,7 +422,7 @@ impl Dispatcher { let action = match peer_state.machine.on_event(OstpEvent::Tick) { Ok(a) => a, Err(e) => { - eprintln!("[ostp] Tick error for session: {}", e); + tracing::warn!("Tick error for session: {}", e); continue; } }; @@ -332,3 +452,40 @@ impl Dispatcher { } } } + +// Free functions to avoid borrow-checker conflicts when tracking stats +// while holding a mutable reference to peer_machines. + +fn get_or_create_stats( + user_stats: &Arc>>>, + key: &str, +) -> Arc { + { + let stats = user_stats.read().unwrap(); + if let Some(existing) = stats.get(key) { + return existing.clone(); + } + } + let mut stats = user_stats.write().unwrap(); + stats.entry(key.to_string()) + .or_insert_with(|| Arc::new(UserStats::new(None))) + .clone() +} + +fn track_user_bytes_up( + user_stats: &Arc>>>, + key: &str, + bytes: u64, +) { + let stats = get_or_create_stats(user_stats, key); + stats.bytes_up.fetch_add(bytes, Ordering::Relaxed); +} + +fn track_user_bytes_down( + user_stats: &Arc>>>, + key: &str, + bytes: u64, +) { + let stats = get_or_create_stats(user_stats, key); + stats.bytes_down.fetch_add(bytes, Ordering::Relaxed); +} diff --git a/ostp-server/src/fallback.rs b/ostp-server/src/fallback.rs new file mode 100644 index 0000000..61abaa2 --- /dev/null +++ b/ostp-server/src/fallback.rs @@ -0,0 +1,87 @@ +//! Fallback TCP server for anti-DPI camouflage. +//! +//! When a connection arrives that doesn't match the OSTP protocol +//! (e.g., a DPI probe, web spider, or direct HTTP request), +//! it gets transparently proxied to a fallback web server (e.g., nginx). +//! +//! This makes the OSTP server indistinguishable from a regular web server +//! during active probing. + +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; + +/// Fallback server configuration. +#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)] +pub struct FallbackConfig { + /// Enable fallback TCP listener + pub enabled: bool, + /// TCP listen address (e.g., "0.0.0.0:443" or "0.0.0.0:80") + pub listen: String, + /// Target to proxy to (e.g., "127.0.0.1:8080" for local nginx) + pub target: String, +} + +/// Start the fallback TCP proxy server. +pub async fn start_fallback_server(config: FallbackConfig) { + let listener = match TcpListener::bind(&config.listen).await { + Ok(l) => l, + Err(e) => { + tracing::error!("fallback server failed to bind on {}: {}", config.listen, e); + return; + } + }; + + tracing::info!("fallback server listening on {} -> {}", config.listen, config.target); + + loop { + match listener.accept().await { + Ok((client, peer_addr)) => { + let target = config.target.clone(); + tokio::spawn(async move { + if let Err(e) = proxy_connection(client, &target).await { + tracing::debug!(peer = %peer_addr, "fallback proxy error: {}", e); + } + }); + } + Err(e) => { + tracing::warn!("fallback accept error: {}", e); + } + } + } +} + +async fn proxy_connection(mut client: TcpStream, target: &str) -> anyhow::Result<()> { + let mut upstream = TcpStream::connect(target).await?; + + let (mut client_read, mut client_write) = client.split(); + let (mut upstream_read, mut upstream_write) = upstream.split(); + + let client_to_upstream = async { + let mut buf = vec![0u8; 8192]; + loop { + let n = client_read.read(&mut buf).await?; + if n == 0 { break; } + upstream_write.write_all(&buf[..n]).await?; + } + upstream_write.shutdown().await?; + Ok::<_, anyhow::Error>(()) + }; + + let upstream_to_client = async { + let mut buf = vec![0u8; 8192]; + loop { + let n = upstream_read.read(&mut buf).await?; + if n == 0 { break; } + client_write.write_all(&buf[..n]).await?; + } + client_write.shutdown().await?; + Ok::<_, anyhow::Error>(()) + }; + + tokio::select! { + r = client_to_upstream => { r?; } + r = upstream_to_client => { r?; } + } + + Ok(()) +} diff --git a/ostp-server/src/lib.rs b/ostp-server/src/lib.rs index cdf99ad..42c5567 100644 --- a/ostp-server/src/lib.rs +++ b/ostp-server/src/lib.rs @@ -1,6 +1,3 @@ -mod dispatcher; -mod signal; - use anyhow::Result; use bytes::Bytes; use std::collections::HashMap; @@ -8,13 +5,24 @@ use std::net::IpAddr; use dispatcher::{DispatchOutcome, Dispatcher}; use ostp_core::relay::RelayMessage; -use ostp_core::{NoiseRole, PaddingStrategy, ProtocolConfig}; use signal::wait_for_shutdown_signal; -use tokio::io::AsyncReadExt; -use tokio::net::{TcpStream, UdpSocket}; +use tokio::net::UdpSocket; use tokio::sync::mpsc; use tokio::time::{interval, Duration, Instant}; +mod dispatcher; +pub mod outbound; +pub mod api; +pub mod fallback; +mod relay; +mod signal; + +pub use outbound::{OutboundAction, OutboundConfig, OutboundRule}; +pub use api::ApiConfig; +pub use fallback::FallbackConfig; + +// ── Internal event types ───────────────────────────────────────────────────── + #[derive(Debug, Clone)] #[allow(dead_code)] enum UiCommand { @@ -24,7 +32,7 @@ enum UiCommand { #[allow(dead_code)] #[derive(Debug, Clone)] -enum UiEvent { +pub(crate) enum UiEvent { #[allow(dead_code)] PeerSeen { peer: IpAddr }, #[allow(dead_code)] Rx { peer: IpAddr, bytes: usize }, @@ -36,33 +44,19 @@ enum UiEvent { KeyCount(usize), } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum OutboundAction { - Proxy, - Direct, +pub(crate) struct RemoteState { + pub data_tx: mpsc::UnboundedSender, + pub cancel_tx: mpsc::Sender<()>, } -#[derive(Debug, Clone)] -pub struct OutboundRule { - pub domain_suffix: Vec, - pub ip_cidr: Vec, - pub action: OutboundAction, -} - -#[derive(Debug, Clone)] -pub struct OutboundConfig { - pub enabled: bool, - pub protocol: String, - pub address: String, - pub port: u16, - pub rules: Vec, - pub default_action: OutboundAction, -} +// ── Public API ─────────────────────────────────────────────────────────────── pub async fn run_server( - bind_addr: String, + bind_addrs: Vec, access_keys: Vec, outbound: Option, + api_config: Option, + fallback_config: Option, debug: bool, ) -> Result<()> { let mut keys_map = HashMap::new(); @@ -106,7 +100,7 @@ pub async fn run_server( } let mut keys_lock = shared_keys_clone.write().unwrap(); *keys_lock = new_keys; - eprintln!("[ostp] Hot-reloaded {} access keys from config.json", keys_lock.len()); + tracing::info!("Hot-reloaded {} access keys from config.json", keys_lock.len()); } } } @@ -116,17 +110,24 @@ pub async fn run_server( } }); - let addr = bind_addr.parse::().map_err(|e| anyhow::anyhow!("invalid bind addr: {}", e))?; - let domain = if addr.is_ipv6() { socket2::Domain::IPV6 } else { socket2::Domain::IPV4 }; - let sock = socket2::Socket::new(domain, socket2::Type::DGRAM, Some(socket2::Protocol::UDP))?; - let _ = sock.set_recv_buffer_size(33554432); // 32MB - let _ = sock.set_send_buffer_size(33554432); // 32MB - let actual_recv = sock.recv_buffer_size().unwrap_or(0); - let actual_send = sock.send_buffer_size().unwrap_or(0); - eprintln!("[ostp] UDP socket buffers: recv={}KB send={}KB", actual_recv / 1024, actual_send / 1024); - sock.bind(&addr.into())?; - sock.set_nonblocking(true)?; - let socket = UdpSocket::from_std(sock.into())?; + let mut sockets = Vec::new(); + for bind_addr in &bind_addrs { + let addr = bind_addr.parse::() + .map_err(|e| anyhow::anyhow!("invalid bind addr '{}': {}", bind_addr, e))?; + let domain = if addr.is_ipv6() { socket2::Domain::IPV6 } else { socket2::Domain::IPV4 }; + let sock = socket2::Socket::new(domain, socket2::Type::DGRAM, Some(socket2::Protocol::UDP))?; + let _ = sock.set_recv_buffer_size(33554432); + let _ = sock.set_send_buffer_size(33554432); + sock.bind(&addr.into())?; + sock.set_nonblocking(true)?; + let udp_sock = UdpSocket::from_std(sock.into())?; + tracing::info!("UDP socket bound to {}", bind_addr); + sockets.push(std::sync::Arc::new(udp_sock)); + } + if sockets.is_empty() { anyhow::bail!("no bind addresses specified"); } + let primary_socket = sockets[0].clone(); + + use ostp_core::{NoiseRole, PaddingStrategy, ProtocolConfig}; let protocol_config = ProtocolConfig { role: NoiseRole::Responder, psk: [0u8; 32], @@ -141,18 +142,36 @@ pub async fn run_server( rto_ms: 100, max_retries: 8, max_sent_history: 32768, - // Defaults — overridden per-session by dispatcher using derive_all_secrets() + // Defaults -- overridden per-session by dispatcher using derive_all_secrets() handshake_pad_min: 32, handshake_pad_max: 128, }; let dispatcher = Dispatcher::new(protocol_config, shared_keys.clone()); + // Spawn Management API if configured + if let Some(api_cfg) = api_config { + if api_cfg.enabled { + let api_keys = shared_keys.clone(); + let api_stats = dispatcher.user_stats_ref(); + tokio::spawn(async move { + api::start_api_server(api_cfg, api_keys, api_stats).await; + }); + } + } + + // Spawn Fallback TCP proxy if configured + if let Some(fb_cfg) = fallback_config { + if fb_cfg.enabled { + tokio::spawn(async move { + fallback::start_fallback_server(fb_cfg).await; + }); + } + } + let (_ui_cmd_tx, ui_cmd_rx) = mpsc::unbounded_channel::(); let (ui_event_tx, mut ui_event_rx) = mpsc::unbounded_channel::(); - let max_datagram_size = 65535; - // Headless event logger tokio::spawn(async move { while let Some(ev) = ui_event_rx.recv().await { @@ -167,15 +186,15 @@ pub async fn run_server( || msg.starts_with("Relay CLOSE") || msg.starts_with("Relay error"); if debug || is_essential { - eprintln!("[ostp] {msg}"); + tracing::info!("{msg}"); } } UiEvent::KeyCreated { key } => { - eprintln!("[ostp] Access key created: {key}"); + tracing::info!("Access key created: {key}"); } UiEvent::UnauthorizedProbe { peer, bytes } => { if debug { - eprintln!("[ostp] Unauthorized probe from {peer} ({bytes} bytes)"); + tracing::debug!("Unauthorized probe from {peer} ({bytes} bytes)"); } } UiEvent::PeerSeen { .. } => {} @@ -185,31 +204,28 @@ pub async fn run_server( }); let key_count = shared_keys.read().unwrap().len(); - eprintln!("[ostp] Listening on {bind_addr} ({key_count} access keys loaded)"); - eprintln!("[ostp] ARQ config: max_reorder=16384, reorder_buf=8192, sent_history=32768, rto=100ms"); + tracing::info!(listeners = bind_addrs.len(), keys = key_count, "server started"); + tracing::info!("ARQ config: max_reorder=16384, reorder_buf=8192, sent_history=32768, rto=100ms"); tokio::select! { - res = run_server_loop(socket, dispatcher, max_datagram_size, ui_cmd_rx, ui_event_tx, shared_keys, outbound, debug) => { + res = run_server_loop(primary_socket, sockets, dispatcher, ui_cmd_rx, ui_event_tx, shared_keys, outbound, debug) => { if let Err(e) = res { - eprintln!("[ostp] Server error: {e}"); + tracing::error!("Server error: {e}"); } } _ = wait_for_shutdown_signal() => { - eprintln!("[ostp] Shutdown signal received"); + tracing::info!("Shutdown signal received"); } } Ok(()) } -struct RemoteState { - data_tx: mpsc::UnboundedSender, - cancel_tx: mpsc::Sender<()>, -} +// ── Server main loop ───────────────────────────────────────────────────────── async fn run_server_loop( - socket: UdpSocket, + primary_socket: std::sync::Arc, + sockets: Vec>, mut dispatcher: Dispatcher, - _max_datagram_size: usize, mut ui_cmd_rx: mpsc::UnboundedReceiver, ui_event_tx: mpsc::UnboundedSender, shared_keys: std::sync::Arc>>, @@ -217,28 +233,31 @@ async fn run_server_loop( debug: bool, ) -> Result<()> { let mut remotes: HashMap<(u32, u16), RemoteState> = HashMap::new(); - // Unbounded channel: bounded(10000) caused TCP-reader tasks to fail under Speedtest load - // when 50+ streams competed for slots. Backpressure is managed at the relay layer instead. let (stream_tx, mut stream_rx) = mpsc::unbounded_channel::<(u32, u16, Vec)>(); let (connect_tx, mut connect_rx) = mpsc::unbounded_channel::<(u32, u16, String, Result<(tokio::net::tcp::OwnedWriteHalf, mpsc::Sender<()>), String>)>(); - let socket = std::sync::Arc::new(socket); + let socket = primary_socket; + // Spawn a recv task for each socket, all feeding into the same channel let (udp_tx, mut udp_rx) = mpsc::channel(10000); - let socket_clone = socket.clone(); - tokio::spawn(async move { - let mut buf = vec![0_u8; 65535]; - loop { - match socket_clone.recv_from(&mut buf).await { - Ok((size, peer)) => { - let packet = Bytes::copy_from_slice(&buf[..size]); - if udp_tx.send((packet, peer)).await.is_err() { - break; + for sock in &sockets { + let sock_clone = sock.clone(); + let tx = udp_tx.clone(); + tokio::spawn(async move { + let mut buf = vec![0_u8; 65535]; + loop { + match sock_clone.recv_from(&mut buf).await { + Ok((size, peer)) => { + let packet = Bytes::copy_from_slice(&buf[..size]); + if tx.send((packet, peer)).await.is_err() { + break; + } } + Err(_) => break, } - Err(_) => break, } - } - }); + }); + } + drop(udp_tx); // Drop the original sender so the channel closes when all tasks end if debug { let _ = ui_event_tx.send(UiEvent::Log("Server loop started".to_string())); @@ -302,7 +321,7 @@ async fn run_server_loop( "Deliver app payload sid={session_id} stream={stream_id} bytes={}", payload.len() ))); - handle_relay_message( + relay::handle_relay_message( peer_addr, session_id, stream_id, @@ -326,12 +345,12 @@ async fn run_server_loop( } Some((session_id, stream_id, data)) = stream_rx.recv() => { if data.is_empty() { - let _ = send_relay_to_stream(session_id, stream_id, RelayMessage::Close, &mut dispatcher, &socket, &ui_event_tx).await; + let _ = relay::send_relay_to_stream(session_id, stream_id, RelayMessage::Close, &mut dispatcher, &socket, &ui_event_tx).await; if let Some(state) = remotes.remove(&(session_id, stream_id)) { let _ = state.cancel_tx.try_send(()); } } else { - let _ = send_relay_to_stream(session_id, stream_id, RelayMessage::Data(data), &mut dispatcher, &socket, &ui_event_tx).await; + let _ = relay::send_relay_to_stream(session_id, stream_id, RelayMessage::Data(data), &mut dispatcher, &socket, &ui_event_tx).await; } } Some((session_id, stream_id, target, res)) = connect_rx.recv() => { @@ -347,12 +366,12 @@ async fn run_server_loop( } }); remotes.insert((session_id, stream_id), RemoteState { data_tx, cancel_tx }); - let _ = send_relay_to_stream(session_id, stream_id, RelayMessage::ConnectOk, &mut dispatcher, &socket, &ui_event_tx).await; + let _ = relay::send_relay_to_stream(session_id, stream_id, RelayMessage::ConnectOk, &mut dispatcher, &socket, &ui_event_tx).await; let _ = ui_event_tx.send(UiEvent::Log(format!("Relay CONNECT ok for [{session_id}:{stream_id}] -> {target}"))); } Err(err) => { let _ = ui_event_tx.send(UiEvent::Log(format!("Relay CONNECT failed for [{session_id}:{stream_id}] -> {target}: {err}"))); - let _ = send_relay_to_stream(session_id, stream_id, RelayMessage::Error(format!("connect failed: {err}")), &mut dispatcher, &socket, &ui_event_tx).await; + let _ = relay::send_relay_to_stream(session_id, stream_id, RelayMessage::Error(format!("connect failed: {err}")), &mut dispatcher, &socket, &ui_event_tx).await; } } } @@ -389,330 +408,3 @@ async fn run_server_loop( Ok(()) } - -async fn handle_relay_message( - _peer_addr: std::net::SocketAddr, - session_id: u32, - stream_id: u16, - payload: Bytes, - dispatcher: &mut Dispatcher, - socket: &UdpSocket, - remotes: &mut HashMap<(u32, u16), RemoteState>, - ui_event_tx: &mpsc::UnboundedSender, - stream_tx: mpsc::UnboundedSender<(u32, u16, Vec)>, - connect_tx: mpsc::UnboundedSender<(u32, u16, String, Result<(tokio::net::tcp::OwnedWriteHalf, mpsc::Sender<()>), String>)>, - outbound: Option, - debug: bool, -) -> Result<()> { - match RelayMessage::decode(&payload)? { - RelayMessage::Connect(target) => { - let _ = ui_event_tx.send(UiEvent::Log(format!("Relay CONNECT start for [{session_id}:{stream_id}] -> {target}"))); - let target_clone = target.clone(); - let connect_tx_clone = connect_tx.clone(); - let stream_tx_clone = stream_tx.clone(); - let outbound_clone = outbound.clone(); - tokio::spawn(async move { - let stream_res = connect_target(&target_clone, outbound_clone.as_ref(), debug).await; - match stream_res { - Ok(stream) => { - let (mut reader, writer) = stream.into_split(); - let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1); - tokio::spawn(async move { - let mut buf = [0_u8; 4096]; - loop { - tokio::select! { - _ = cancel_rx.recv() => break, - read_res = reader.read(&mut buf) => { - match read_res { - Ok(0) | Err(_) => { - let _ = stream_tx_clone.send((session_id, stream_id, Vec::new())); - break; - } - Ok(n) => { - if stream_tx_clone.send((session_id, stream_id, buf[..n].to_vec())).is_err() { - break; - } - } - } - } - } - } - }); - let _ = connect_tx_clone.send((session_id, stream_id, target_clone, Ok((writer, cancel_tx)))); - } - Err(e) => { - let _ = connect_tx_clone.send((session_id, stream_id, target_clone, Err(e.to_string()))); - } - } - }); - } - RelayMessage::Data(data) => { - if let Some(remote) = remotes.get_mut(&(session_id, stream_id)) { - let _ = remote.data_tx.send(bytes::Bytes::from(data)); - } else { - let _ = ui_event_tx.send(UiEvent::Log(format!("Relay DATA for unknown stream [{session_id}:{stream_id}] ({})", data.len()))); - } - } - RelayMessage::KeepAlive => {} - RelayMessage::Close => { - if let Some(state) = remotes.remove(&(session_id, stream_id)) { - let _ = state.cancel_tx.try_send(()); - let _ = ui_event_tx.send(UiEvent::Log(format!("Relay CLOSE [{session_id}:{stream_id}]"))); - } - } - RelayMessage::ConnectOk => {} - RelayMessage::Error(msg) => { - let _ = ui_event_tx.send(UiEvent::Log(format!("Relay error from [{session_id}:{stream_id}]: {msg}"))); - } - RelayMessage::Ping(ts) => { - send_relay_to_stream(session_id, stream_id, RelayMessage::Pong(ts), dispatcher, socket, ui_event_tx).await?; - } - RelayMessage::Pong(_) => {} - } - Ok(()) -} - -async fn send_relay_to_stream( - session_id: u32, - stream_id: u16, - msg: RelayMessage, - dispatcher: &mut Dispatcher, - socket: &UdpSocket, - ui_event_tx: &mpsc::UnboundedSender, -) -> Result<()> { - let payload = Bytes::from(msg.encode()); - if let Some((frame, peer_addr)) = dispatcher.outbound_to_session(session_id, stream_id, payload)? { - let response_len = frame.len(); - let _ = socket.send_to(&frame, peer_addr).await?; - let _ = ui_event_tx.send(UiEvent::Tx { - peer: peer_addr.ip(), - bytes: response_len, - }); - } - Ok(()) -} - -async fn connect_target( - target: &str, - outbound: Option<&OutboundConfig>, - debug: bool, -) -> Result { - let connect_timeout = Duration::from_secs(10); - if let Some(outbound) = outbound { - if outbound.enabled { - let action = select_outbound_action(target, outbound, debug).await; - if action == OutboundAction::Proxy { - let proxy_addr = format!("{}:{}", outbound.address, outbound.port); - return match outbound.protocol.as_str() { - "socks5" => connect_via_socks5(&proxy_addr, target).await, - "http" => connect_via_http(&proxy_addr, target).await, - _ => tokio::time::timeout(connect_timeout, TcpStream::connect(target)) - .await - .map_err(|_| anyhow::anyhow!("connect timeout ({}s): {}", connect_timeout.as_secs(), target))? - .map_err(Into::into), - }; - } - } - } - - tokio::time::timeout(connect_timeout, TcpStream::connect(target)) - .await - .map_err(|_| anyhow::anyhow!("connect timeout ({}s): {}", connect_timeout.as_secs(), target))? - .map_err(Into::into) -} - -async fn select_outbound_action( - target: &str, - outbound: &OutboundConfig, - debug: bool, -) -> OutboundAction { - let (host, port) = match split_host_port(target) { - Some(v) => v, - None => return outbound.default_action, - }; - - let mut matched = None; - for rule in &outbound.rules { - if rule.domain_suffix.is_empty() && rule.ip_cidr.is_empty() { - continue; - } - if match_domain_rule(&host, &rule.domain_suffix) { - matched = Some(rule.action); - break; - } - if match_ip_rule(&host, port, &rule.ip_cidr).await { - matched = Some(rule.action); - break; - } - } - - let action = matched.unwrap_or(outbound.default_action); - if debug { - eprintln!("[ostp] Outbound routing: target={target} action={action:?}"); - } - action -} - -fn match_domain_rule(host: &str, suffixes: &[String]) -> bool { - if suffixes.is_empty() { - return false; - } - let host = host.trim_end_matches('.').to_lowercase(); - suffixes.iter().any(|suffix| { - let suffix = suffix.trim().trim_start_matches('.').to_lowercase(); - !suffix.is_empty() && (host == suffix || host.ends_with(&format!(".{suffix}"))) - }) -} - -async fn match_ip_rule(host: &str, port: u16, cidrs: &[String]) -> bool { - if cidrs.is_empty() { - return false; - } - let parsed: Vec = cidrs.iter().filter_map(|c| parse_cidr(c)).collect(); - if parsed.is_empty() { - return false; - } - if let Ok(ip) = host.parse::() { - return parsed.iter().any(|cidr| cidr.contains(&ip)); - } - - match tokio::net::lookup_host((host, port)).await { - Ok(addrs) => addrs.into_iter().any(|addr| parsed.iter().any(|cidr| cidr.contains(&addr.ip()))), - Err(_) => false, - } -} - -async fn connect_via_socks5(proxy_addr: &str, target: &str) -> Result { - use tokio::io::{AsyncReadExt, AsyncWriteExt}; - - let mut stream = TcpStream::connect(proxy_addr).await?; - stream.write_all(&[0x05, 0x01, 0x00]).await?; - let mut reply = [0u8; 2]; - stream.read_exact(&mut reply).await?; - if reply != [0x05, 0x00] { - anyhow::bail!("SOCKS5 auth not accepted"); - } - - let (host, port) = split_host_port(target).ok_or_else(|| anyhow::anyhow!("invalid target"))?; - let mut req = Vec::new(); - req.extend_from_slice(&[0x05, 0x01, 0x00]); - if let Ok(ip) = host.parse::() { - match ip { - std::net::IpAddr::V4(v4) => { - req.push(0x01); - req.extend_from_slice(&v4.octets()); - } - std::net::IpAddr::V6(v6) => { - req.push(0x04); - req.extend_from_slice(&v6.octets()); - } - } - } else { - req.push(0x03); - req.push(host.len() as u8); - req.extend_from_slice(host.as_bytes()); - } - req.extend_from_slice(&port.to_be_bytes()); - stream.write_all(&req).await?; - - let mut header = [0u8; 4]; - stream.read_exact(&mut header).await?; - if header[1] != 0x00 { - anyhow::bail!("SOCKS5 connect failed: 0x{:02x}", header[1]); - } - - let addr_len = match header[3] { - 0x01 => 4, - 0x04 => 16, - 0x03 => { - let mut len = [0u8; 1]; - stream.read_exact(&mut len).await?; - len[0] as usize - } - _ => 0, - }; - if addr_len > 0 { - let mut skip = vec![0u8; addr_len + 2]; - stream.read_exact(&mut skip).await?; - } - - Ok(stream) -} - -async fn connect_via_http(proxy_addr: &str, target: &str) -> Result { - use tokio::io::{AsyncReadExt, AsyncWriteExt}; - - let mut stream = TcpStream::connect(proxy_addr).await?; - let request = format!("CONNECT {target} HTTP/1.1\r\nHost: {target}\r\n\r\n"); - stream.write_all(request.as_bytes()).await?; - - let mut buf = vec![0u8; 1024]; - let n = stream.read(&mut buf).await?; - let response = String::from_utf8_lossy(&buf[..n]); - if !response.starts_with("HTTP/1.1 200") && !response.starts_with("HTTP/1.0 200") { - anyhow::bail!("HTTP CONNECT failed: {response}"); - } - Ok(stream) -} - -enum Cidr { - V4(u32, u8), - V6(u128, u8), -} - -impl Cidr { - fn contains(&self, ip: &std::net::IpAddr) -> bool { - match (self, ip) { - (Cidr::V4(net, bits), std::net::IpAddr::V4(addr)) => { - let mask = if *bits == 0 { 0 } else { u32::MAX << (32 - bits) }; - let ip = u32::from_be_bytes(addr.octets()); - (ip & mask) == (*net & mask) - } - (Cidr::V6(net, bits), std::net::IpAddr::V6(addr)) => { - let mask = if *bits == 0 { 0 } else { u128::MAX << (128 - bits) }; - let ip = u128::from_be_bytes(addr.octets()); - (ip & mask) == (*net & mask) - } - _ => false, - } - } -} - -fn parse_cidr(value: &str) -> Option { - let value = value.trim(); - if value.is_empty() { - return None; - } - if let Some((addr_str, bits_str)) = value.split_once('/') { - let bits: u8 = bits_str.parse().ok()?; - if let Ok(addr) = addr_str.parse::() { - return match addr { - std::net::IpAddr::V4(v4) => Some(Cidr::V4(u32::from_be_bytes(v4.octets()), bits.min(32))), - std::net::IpAddr::V6(v6) => Some(Cidr::V6(u128::from_be_bytes(v6.octets()), bits.min(128))), - }; - } - } - if let Ok(addr) = value.parse::() { - return match addr { - std::net::IpAddr::V4(v4) => Some(Cidr::V4(u32::from_be_bytes(v4.octets()), 32)), - std::net::IpAddr::V6(v6) => Some(Cidr::V6(u128::from_be_bytes(v6.octets()), 128)), - }; - } - None -} - -fn split_host_port(target: &str) -> Option<(String, u16)> { - if let Some((host, port)) = target.rsplit_once(':') { - if host.starts_with('[') && host.ends_with(']') { - let host = host.trim_start_matches('[').trim_end_matches(']').to_string(); - let port = port.parse().ok()?; - return Some((host, port)); - } - if host.contains(':') { - return None; - } - let port = port.parse().ok()?; - return Some((host.to_string(), port)); - } - None -} diff --git a/ostp-server/src/outbound.rs b/ostp-server/src/outbound.rs new file mode 100644 index 0000000..2b6980a --- /dev/null +++ b/ostp-server/src/outbound.rs @@ -0,0 +1,323 @@ +use anyhow::Result; +use tokio::net::TcpStream; +use tokio::time::Duration; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum OutboundAction { + Proxy, + Direct, +} + +#[derive(Debug, Clone)] +pub struct OutboundRule { + pub domain_suffix: Vec, + pub ip_cidr: Vec, + pub action: OutboundAction, +} + +#[derive(Debug, Clone)] +pub struct OutboundConfig { + pub enabled: bool, + pub protocol: String, + pub address: String, + pub port: u16, + pub rules: Vec, + pub default_action: OutboundAction, +} + +// ── Target connection with outbound routing ────────────────────────────────── + +pub async fn connect_target( + target: &str, + outbound: Option<&OutboundConfig>, + debug: bool, +) -> Result { + let connect_timeout = Duration::from_secs(10); + if let Some(outbound) = outbound { + if outbound.enabled { + let action = select_outbound_action(target, outbound, debug).await; + if action == OutboundAction::Proxy { + let proxy_addr = format!("{}:{}", outbound.address, outbound.port); + return match outbound.protocol.as_str() { + "socks5" => connect_via_socks5(&proxy_addr, target).await, + "http" => connect_via_http(&proxy_addr, target).await, + _ => tokio::time::timeout(connect_timeout, TcpStream::connect(target)) + .await + .map_err(|_| anyhow::anyhow!("connect timeout ({}s): {}", connect_timeout.as_secs(), target))? + .map_err(Into::into), + }; + } + } + } + + tokio::time::timeout(connect_timeout, TcpStream::connect(target)) + .await + .map_err(|_| anyhow::anyhow!("connect timeout ({}s): {}", connect_timeout.as_secs(), target))? + .map_err(Into::into) +} + +// ── Rule matching ──────────────────────────────────────────────────────────── + +async fn select_outbound_action( + target: &str, + outbound: &OutboundConfig, + debug: bool, +) -> OutboundAction { + let (host, port) = match split_host_port(target) { + Some(v) => v, + None => return outbound.default_action, + }; + + let mut matched = None; + for rule in &outbound.rules { + if rule.domain_suffix.is_empty() && rule.ip_cidr.is_empty() { + continue; + } + if match_domain_rule(&host, &rule.domain_suffix) { + matched = Some(rule.action); + break; + } + if match_ip_rule(&host, port, &rule.ip_cidr).await { + matched = Some(rule.action); + break; + } + } + + let action = matched.unwrap_or(outbound.default_action); + if debug { + tracing::debug!("Outbound routing: target={target} action={action:?}"); + } + action +} + +fn match_domain_rule(host: &str, suffixes: &[String]) -> bool { + if suffixes.is_empty() { + return false; + } + let host = host.trim_end_matches('.').to_lowercase(); + suffixes.iter().any(|suffix| { + let suffix = suffix.trim().trim_start_matches('.').to_lowercase(); + !suffix.is_empty() && (host == suffix || host.ends_with(&format!(".{suffix}"))) + }) +} + +async fn match_ip_rule(host: &str, port: u16, cidrs: &[String]) -> bool { + if cidrs.is_empty() { + return false; + } + let parsed: Vec = cidrs.iter().filter_map(|c| parse_cidr(c)).collect(); + if parsed.is_empty() { + return false; + } + if let Ok(ip) = host.parse::() { + return parsed.iter().any(|cidr| cidr.contains(&ip)); + } + + match tokio::net::lookup_host((host, port)).await { + Ok(addrs) => addrs.into_iter().any(|addr| parsed.iter().any(|cidr| cidr.contains(&addr.ip()))), + Err(_) => false, + } +} + +// ── SOCKS5 / HTTP CONNECT upstream proxy ───────────────────────────────────── + +async fn connect_via_socks5(proxy_addr: &str, target: &str) -> Result { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + let mut stream = TcpStream::connect(proxy_addr).await?; + stream.write_all(&[0x05, 0x01, 0x00]).await?; + let mut reply = [0u8; 2]; + stream.read_exact(&mut reply).await?; + if reply != [0x05, 0x00] { + anyhow::bail!("SOCKS5 auth not accepted"); + } + + let (host, port) = split_host_port(target).ok_or_else(|| anyhow::anyhow!("invalid target"))?; + let mut req = Vec::new(); + req.extend_from_slice(&[0x05, 0x01, 0x00]); + if let Ok(ip) = host.parse::() { + match ip { + std::net::IpAddr::V4(v4) => { + req.push(0x01); + req.extend_from_slice(&v4.octets()); + } + std::net::IpAddr::V6(v6) => { + req.push(0x04); + req.extend_from_slice(&v6.octets()); + } + } + } else { + req.push(0x03); + req.push(host.len() as u8); + req.extend_from_slice(host.as_bytes()); + } + req.extend_from_slice(&port.to_be_bytes()); + stream.write_all(&req).await?; + + let mut header = [0u8; 4]; + stream.read_exact(&mut header).await?; + if header[1] != 0x00 { + anyhow::bail!("SOCKS5 connect failed: 0x{:02x}", header[1]); + } + + let addr_len = match header[3] { + 0x01 => 4, + 0x04 => 16, + 0x03 => { + let mut len = [0u8; 1]; + stream.read_exact(&mut len).await?; + len[0] as usize + } + _ => 0, + }; + if addr_len > 0 { + let mut skip = vec![0u8; addr_len + 2]; + stream.read_exact(&mut skip).await?; + } + + Ok(stream) +} + +async fn connect_via_http(proxy_addr: &str, target: &str) -> Result { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + let mut stream = TcpStream::connect(proxy_addr).await?; + let request = format!("CONNECT {target} HTTP/1.1\r\nHost: {target}\r\n\r\n"); + stream.write_all(request.as_bytes()).await?; + + let mut buf = vec![0u8; 1024]; + let n = stream.read(&mut buf).await?; + let response = String::from_utf8_lossy(&buf[..n]); + if !response.starts_with("HTTP/1.1 200") && !response.starts_with("HTTP/1.0 200") { + anyhow::bail!("HTTP CONNECT failed: {response}"); + } + Ok(stream) +} + +// ── CIDR utilities ─────────────────────────────────────────────────────────── + +enum Cidr { + V4(u32, u8), + V6(u128, u8), +} + +impl Cidr { + fn contains(&self, ip: &std::net::IpAddr) -> bool { + match (self, ip) { + (Cidr::V4(net, bits), std::net::IpAddr::V4(addr)) => { + let mask = if *bits == 0 { 0 } else { u32::MAX << (32 - bits) }; + let ip = u32::from_be_bytes(addr.octets()); + (ip & mask) == (*net & mask) + } + (Cidr::V6(net, bits), std::net::IpAddr::V6(addr)) => { + let mask = if *bits == 0 { 0 } else { u128::MAX << (128 - bits) }; + let ip = u128::from_be_bytes(addr.octets()); + (ip & mask) == (*net & mask) + } + _ => false, + } + } +} + +fn parse_cidr(value: &str) -> Option { + let value = value.trim(); + if value.is_empty() { + return None; + } + if let Some((addr_str, bits_str)) = value.split_once('/') { + let bits: u8 = bits_str.parse().ok()?; + if let Ok(addr) = addr_str.parse::() { + return match addr { + std::net::IpAddr::V4(v4) => Some(Cidr::V4(u32::from_be_bytes(v4.octets()), bits.min(32))), + std::net::IpAddr::V6(v6) => Some(Cidr::V6(u128::from_be_bytes(v6.octets()), bits.min(128))), + }; + } + } + if let Ok(addr) = value.parse::() { + return match addr { + std::net::IpAddr::V4(v4) => Some(Cidr::V4(u32::from_be_bytes(v4.octets()), 32)), + std::net::IpAddr::V6(v6) => Some(Cidr::V6(u128::from_be_bytes(v6.octets()), 128)), + }; + } + None +} + +pub fn split_host_port(target: &str) -> Option<(String, u16)> { + if let Some((host, port)) = target.rsplit_once(':') { + if host.starts_with('[') && host.ends_with(']') { + let host = host.trim_start_matches('[').trim_end_matches(']').to_string(); + let port = port.parse().ok()?; + return Some((host, port)); + } + if host.contains(':') { + return None; + } + let port = port.parse().ok()?; + return Some((host.to_string(), port)); + } + None +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_split_host_port() { + assert_eq!(split_host_port("example.com:443"), Some(("example.com".to_string(), 443))); + assert_eq!(split_host_port("127.0.0.1:80"), Some(("127.0.0.1".to_string(), 80))); + assert_eq!(split_host_port("[::1]:8080"), Some(("::1".to_string(), 8080))); + assert_eq!(split_host_port("noport"), None); + assert_eq!(split_host_port("::1:8080"), None); // ambiguous IPv6 without brackets + } + + #[test] + fn test_parse_cidr_v4() { + let cidr = parse_cidr("10.0.0.0/8").unwrap(); + assert!(cidr.contains(&"10.1.2.3".parse().unwrap())); + assert!(cidr.contains(&"10.255.255.255".parse().unwrap())); + assert!(!cidr.contains(&"11.0.0.1".parse().unwrap())); + } + + #[test] + fn test_parse_cidr_v4_exact() { + let cidr = parse_cidr("192.168.1.1").unwrap(); + assert!(cidr.contains(&"192.168.1.1".parse().unwrap())); + assert!(!cidr.contains(&"192.168.1.2".parse().unwrap())); + } + + #[test] + fn test_parse_cidr_v6() { + let cidr = parse_cidr("::1/128").unwrap(); + assert!(cidr.contains(&"::1".parse().unwrap())); + assert!(!cidr.contains(&"::2".parse().unwrap())); + } + + #[test] + fn test_parse_cidr_invalid() { + assert!(parse_cidr("").is_none()); + assert!(parse_cidr("not-an-ip/24").is_none()); + } + + #[test] + fn test_match_domain_rule() { + assert!(match_domain_rule("example.com", &[".example.com".to_string()])); + assert!(match_domain_rule("sub.example.com", &[".example.com".to_string()])); + assert!(!match_domain_rule("notexample.com", &[".example.com".to_string()])); + assert!(match_domain_rule("test.onion", &[".onion".to_string()])); + assert!(!match_domain_rule("onion.com", &[".onion".to_string()])); + } + + #[test] + fn test_match_domain_rule_exact() { + // Without dot prefix, the rule matches both exact and subdomains + // because the implementation treats "example.com" as a suffix match + assert!(match_domain_rule("example.com", &["example.com".to_string()])); + assert!(match_domain_rule("sub.example.com", &["example.com".to_string()])); + } + + #[test] + fn test_match_domain_rule_empty() { + assert!(!match_domain_rule("example.com", &[])); + } +} diff --git a/ostp-server/src/relay.rs b/ostp-server/src/relay.rs new file mode 100644 index 0000000..166d53c --- /dev/null +++ b/ostp-server/src/relay.rs @@ -0,0 +1,114 @@ +use anyhow::Result; +use bytes::Bytes; +use std::collections::HashMap; + +use ostp_core::relay::RelayMessage; +use tokio::io::AsyncReadExt; +use tokio::net::UdpSocket; +use tokio::sync::mpsc; + +use crate::dispatcher::Dispatcher; +use crate::outbound::{self, OutboundConfig}; +use crate::{RemoteState, UiEvent}; + +pub async fn handle_relay_message( + _peer_addr: std::net::SocketAddr, + session_id: u32, + stream_id: u16, + payload: Bytes, + dispatcher: &mut Dispatcher, + socket: &UdpSocket, + remotes: &mut HashMap<(u32, u16), RemoteState>, + ui_event_tx: &mpsc::UnboundedSender, + stream_tx: mpsc::UnboundedSender<(u32, u16, Vec)>, + connect_tx: mpsc::UnboundedSender<(u32, u16, String, Result<(tokio::net::tcp::OwnedWriteHalf, mpsc::Sender<()>), String>)>, + outbound_cfg: Option, + debug: bool, +) -> Result<()> { + match RelayMessage::decode(&payload)? { + RelayMessage::Connect(target) => { + let _ = ui_event_tx.send(UiEvent::Log(format!("Relay CONNECT start for [{session_id}:{stream_id}] -> {target}"))); + let target_clone = target.clone(); + let connect_tx_clone = connect_tx.clone(); + let stream_tx_clone = stream_tx.clone(); + let outbound_clone = outbound_cfg.clone(); + tokio::spawn(async move { + let stream_res = outbound::connect_target(&target_clone, outbound_clone.as_ref(), debug).await; + match stream_res { + Ok(stream) => { + let (mut reader, writer) = stream.into_split(); + let (cancel_tx, mut cancel_rx) = mpsc::channel::<()>(1); + tokio::spawn(async move { + let mut buf = [0_u8; 4096]; + loop { + tokio::select! { + _ = cancel_rx.recv() => break, + read_res = reader.read(&mut buf) => { + match read_res { + Ok(0) | Err(_) => { + let _ = stream_tx_clone.send((session_id, stream_id, Vec::new())); + break; + } + Ok(n) => { + if stream_tx_clone.send((session_id, stream_id, buf[..n].to_vec())).is_err() { + break; + } + } + } + } + } + } + }); + let _ = connect_tx_clone.send((session_id, stream_id, target_clone, Ok((writer, cancel_tx)))); + } + Err(e) => { + let _ = connect_tx_clone.send((session_id, stream_id, target_clone, Err(e.to_string()))); + } + } + }); + } + RelayMessage::Data(data) => { + if let Some(remote) = remotes.get_mut(&(session_id, stream_id)) { + let _ = remote.data_tx.send(bytes::Bytes::from(data)); + } else { + let _ = ui_event_tx.send(UiEvent::Log(format!("Relay DATA for unknown stream [{session_id}:{stream_id}] ({})", data.len()))); + } + } + RelayMessage::KeepAlive => {} + RelayMessage::Close => { + if let Some(state) = remotes.remove(&(session_id, stream_id)) { + let _ = state.cancel_tx.try_send(()); + let _ = ui_event_tx.send(UiEvent::Log(format!("Relay CLOSE [{session_id}:{stream_id}]"))); + } + } + RelayMessage::ConnectOk => {} + RelayMessage::Error(msg) => { + let _ = ui_event_tx.send(UiEvent::Log(format!("Relay error from [{session_id}:{stream_id}]: {msg}"))); + } + RelayMessage::Ping(ts) => { + send_relay_to_stream(session_id, stream_id, RelayMessage::Pong(ts), dispatcher, socket, ui_event_tx).await?; + } + RelayMessage::Pong(_) => {} + } + Ok(()) +} + +pub async fn send_relay_to_stream( + session_id: u32, + stream_id: u16, + msg: RelayMessage, + dispatcher: &mut Dispatcher, + socket: &UdpSocket, + ui_event_tx: &mpsc::UnboundedSender, +) -> Result<()> { + let payload = Bytes::from(msg.encode()); + if let Some((frame, peer_addr)) = dispatcher.outbound_to_session(session_id, stream_id, payload)? { + let response_len = frame.len(); + let _ = socket.send_to(&frame, peer_addr).await?; + let _ = ui_event_tx.send(UiEvent::Tx { + peer: peer_addr.ip(), + bytes: response_len, + }); + } + Ok(()) +} diff --git a/ostp/Cargo.toml b/ostp/Cargo.toml index 5a69f79..bc5ff9e 100644 --- a/ostp/Cargo.toml +++ b/ostp/Cargo.toml @@ -16,3 +16,5 @@ json_comments = "0.2" base64 = "0.22" rand.workspace = true url = "2.5" +tracing.workspace = true +tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/ostp/src/main.rs b/ostp/src/main.rs index 74a239a..9dfcafa 100644 --- a/ostp/src/main.rs +++ b/ostp/src/main.rs @@ -31,6 +31,10 @@ struct Args { #[arg(long)] links: bool, + /// Validate configuration file and exit + #[arg(long)] + check: bool, + /// Optional client connection share link (ostp://ACCESS_KEY@HOST:PORT) to run instantly url: Option, } @@ -133,11 +137,51 @@ impl UnifiedConfig { #[derive(Debug, Deserialize, Serialize)] struct ServerConfig { - listen: String, + listen: ListenConfig, access_keys: Vec, turn_server: Option, debug: Option, outbound: Option, + api: Option, + fallback: Option, +} + +/// Supports both single string "0.0.0.0:50000" and array ["0.0.0.0:50000", "[::]:50000"] +#[derive(Debug, Deserialize, Serialize, Clone)] +#[serde(untagged)] +enum ListenConfig { + Single(String), + Multiple(Vec), +} + +impl ListenConfig { + fn addresses(&self) -> Vec { + match self { + ListenConfig::Single(s) => vec![s.clone()], + ListenConfig::Multiple(v) => v.clone(), + } + } + + fn primary(&self) -> String { + match self { + ListenConfig::Single(s) => s.clone(), + ListenConfig::Multiple(v) => v.first().cloned().unwrap_or_default(), + } + } +} + +#[derive(Debug, Deserialize, Serialize)] +struct ApiConfig { + enabled: Option, + bind: Option, + token: Option, +} + +#[derive(Debug, Deserialize, Serialize)] +struct FallbackCfg { + enabled: Option, + listen: Option, + target: Option, } #[derive(Debug, Deserialize, Serialize)] @@ -201,6 +245,17 @@ struct MuxConfig { #[tokio::main] async fn main() -> Result<()> { + // Initialize structured logging via tracing + // Default: info level; override with RUST_LOG env var (e.g. RUST_LOG=ostp_server=debug) + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")) + ) + .with_target(false) + .compact() + .init(); + let res = run_app().await; if let Err(e) = res { eprintln!(); @@ -308,6 +363,52 @@ async fn run_app() -> Result<()> { return run_client_directly(client_cfg).await; } + // Handle --check: validate config and exit + if args.check { + if !args.config.exists() { + anyhow::bail!("Configuration file {:?} not found.", args.config); + } + let content = fs::read_to_string(&args.config)?; + let mut stripped = json_comments::StripComments::new(content.as_bytes()); + match serde_json::from_reader::<_, UnifiedConfig>(&mut stripped) { + Ok(config) => { + config.validate()?; + match &config.mode { + AppMode::Server(s) => { + println!("[ostp] Config OK: server mode"); + println!(" Listen: {:?}", s.listen.primary()); + println!(" Access keys: {}", s.access_keys.len()); + if let Some(api) = &s.api { + println!(" API: {} (bind: {})", + if api.enabled.unwrap_or(false) { "enabled" } else { "disabled" }, + api.bind.as_deref().unwrap_or("127.0.0.1:9090")); + } + if let Some(outbound) = &s.outbound { + println!(" Outbound proxy: {} ({})", + if outbound.enabled { "enabled" } else { "disabled" }, + outbound.protocol); + } + if let Some(fb) = &s.fallback { + println!(" Fallback: {} ({} -> {})", + if fb.enabled.unwrap_or(false) { "enabled" } else { "disabled" }, + fb.listen.as_deref().unwrap_or("0.0.0.0:443"), + fb.target.as_deref().unwrap_or("127.0.0.1:8080")); + } + } + AppMode::Client(c) => { + println!("[ostp] Config OK: client mode"); + println!(" Server: {}", c.server); + println!(" Key: {}...", &c.access_key[..8.min(c.access_key.len())]); + } + } + } + Err(e) => { + anyhow::bail!("Config parse error: {}", e); + } + } + return Ok(()); + } + // Handle explicit configuration initialization if let Some(ref mode_str) = args.init { let is_server = mode_str == "server"; @@ -341,6 +442,22 @@ async fn run_app() -> Result<()> { }} ] }}, + + // Management REST API for third-party panels. + "api": {{ + "enabled": false, + "bind": "127.0.0.1:9090", + // Set a strong token for authentication. Leave empty to disable auth. + "token": "" + }}, + + // Fallback TCP proxy: unrecognized connections are proxied to a web server (anti-DPI). + "fallback": {{ + "enabled": false, + "listen": "0.0.0.0:443", + // Target web server (e.g., local nginx or caddy) + "target": "127.0.0.1:8080" + }}, "debug": false }}"#, key) } else { @@ -429,7 +546,7 @@ async fn run_app() -> Result<()> { if args.links { match config.mode { AppMode::Server(server_cfg) => { - let listen = server_cfg.listen.clone(); + let listen = server_cfg.listen.primary(); let parts: Vec<&str> = listen.split(':').collect(); let port = parts.get(1).unwrap_or(&"50000"); let host = if parts[0] == "0.0.0.0" { @@ -452,11 +569,11 @@ async fn run_app() -> Result<()> { match config.mode { AppMode::Server(server_cfg) => { - println!("[ostp] Starting server on {}", server_cfg.listen); + let listen_addrs = server_cfg.listen.addresses(); + println!("[ostp] Starting server on {:?}", listen_addrs); if let Some(turn) = server_cfg.turn_server { println!("[ostp] TURN relay enabled: {}", turn); } - // Temporarily pass control to the isolated server implementation let debug = server_cfg.debug.unwrap_or(false); let outbound = server_cfg.outbound.map(|o| ostp_server::OutboundConfig { enabled: o.enabled, @@ -474,7 +591,18 @@ async fn run_app() -> Result<()> { .collect(), default_action: parse_outbound_action(o.default_action), }); - ostp_server::run_server(server_cfg.listen, server_cfg.access_keys, outbound, debug).await?; + let api_config = server_cfg.api.map(|a| ostp_server::ApiConfig { + enabled: a.enabled.unwrap_or(false), + bind: a.bind.unwrap_or_else(|| "127.0.0.1:9090".to_string()), + token: a.token.unwrap_or_default(), + }); + let fallback_config = server_cfg.fallback.map(|f| ostp_server::FallbackConfig { + enabled: f.enabled.unwrap_or(false), + listen: f.listen.unwrap_or_else(|| "0.0.0.0:443".to_string()), + target: f.target.unwrap_or_else(|| "127.0.0.1:8080".to_string()), + }); + // Pass all listen addresses for multi-listener support + ostp_server::run_server(listen_addrs, server_cfg.access_keys, outbound, api_config, fallback_config, debug).await?; } AppMode::Client(client_cfg) => { run_client_directly(client_cfg).await?;