mirror of https://github.com/ospab/ostp.git
197 lines
6.6 KiB
Rust
197 lines
6.6 KiB
Rust
use std::collections::{HashMap, HashSet, VecDeque};
|
|
use std::sync::Arc;
|
|
use tokio::sync::{RwLock, Mutex};
|
|
use simple_dns::{Packet, rdata::RData, ResourceRecord, CLASS, TYPE, QTYPE};
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct DnsConfig {
|
|
pub enabled: bool,
|
|
pub doh_upstream: String,
|
|
pub adblock_urls: Vec<String>,
|
|
pub custom_domains: HashMap<String, String>,
|
|
}
|
|
|
|
impl Default for DnsConfig {
|
|
fn default() -> Self {
|
|
Self {
|
|
enabled: false,
|
|
doh_upstream: "https://cloudflare-dns.com/dns-query".to_string(),
|
|
adblock_urls: vec![],
|
|
custom_domains: HashMap::new(),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize)]
|
|
pub struct DnsQueryLog {
|
|
pub timestamp: u64,
|
|
pub domain: String,
|
|
pub client_ip: String,
|
|
pub blocked: bool,
|
|
}
|
|
|
|
pub struct DnsServer {
|
|
pub config: RwLock<DnsConfig>,
|
|
adblock_trie: RwLock<HashSet<String>>, // Simplified to HashSet for now, or maybe a suffix tree
|
|
query_log: Mutex<VecDeque<DnsQueryLog>>,
|
|
reqwest_client: reqwest::Client,
|
|
}
|
|
|
|
impl DnsServer {
|
|
pub fn new(config: DnsConfig) -> Arc<Self> {
|
|
let server = Arc::new(Self {
|
|
config: RwLock::new(config.clone()),
|
|
adblock_trie: RwLock::new(HashSet::new()),
|
|
query_log: Mutex::new(VecDeque::with_capacity(1000)),
|
|
reqwest_client: reqwest::Client::builder()
|
|
.build()
|
|
.unwrap_or_default(),
|
|
});
|
|
|
|
// Spawn a background task to download blocklists
|
|
if config.enabled && !config.adblock_urls.is_empty() {
|
|
let server_clone = server.clone();
|
|
tokio::spawn(async move {
|
|
server_clone.update_blocklists().await;
|
|
});
|
|
}
|
|
|
|
server
|
|
}
|
|
|
|
pub async fn update_blocklists(&self) {
|
|
let urls = {
|
|
let cfg = self.config.read().await;
|
|
cfg.adblock_urls.clone()
|
|
};
|
|
|
|
let mut new_blocked = HashSet::new();
|
|
|
|
for url in urls {
|
|
if let Ok(resp) = self.reqwest_client.get(&url).send().await {
|
|
if let Ok(text) = resp.text().await {
|
|
for line in text.lines() {
|
|
let line = line.trim();
|
|
if line.is_empty() || line.starts_with('#') {
|
|
continue;
|
|
}
|
|
// Support standard hosts format: "0.0.0.0 ads.google.com" or just "ads.google.com"
|
|
let parts: Vec<&str> = line.split_whitespace().collect();
|
|
let domain = if parts.len() >= 2 && (parts[0] == "0.0.0.0" || parts[0] == "127.0.0.1") {
|
|
parts[1]
|
|
} else {
|
|
parts[0]
|
|
};
|
|
new_blocked.insert(domain.to_lowercase());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
tracing::info!("Loaded {} domains into AdBlock engine", new_blocked.len());
|
|
*self.adblock_trie.write().await = new_blocked;
|
|
}
|
|
|
|
pub async fn resolve(&self, payload: &[u8], client_ip: std::net::IpAddr) -> Option<Vec<u8>> {
|
|
let cfg = self.config.read().await;
|
|
if !cfg.enabled {
|
|
return None; // If DNS is disabled, fallback to standard UDP proxying
|
|
}
|
|
|
|
// Parse DNS packet
|
|
let packet = match Packet::parse(payload) {
|
|
Ok(p) => p,
|
|
Err(_) => return None,
|
|
};
|
|
|
|
if packet.questions.is_empty() {
|
|
return None;
|
|
}
|
|
|
|
let question = &packet.questions[0];
|
|
let qname = question.qname.to_string().to_lowercase();
|
|
|
|
// Check Custom Domains
|
|
if let Some(ip_str) = cfg.custom_domains.get(&qname) {
|
|
if let Ok(ip) = ip_str.parse::<std::net::Ipv4Addr>() {
|
|
if question.qtype == QTYPE::TYPE(TYPE::A) {
|
|
let mut response = Packet::new_reply(packet.id());
|
|
response.questions.push(question.clone());
|
|
response.answers.push(ResourceRecord::new(
|
|
question.qname.clone(),
|
|
CLASS::IN,
|
|
60,
|
|
RData::A(ip.into()),
|
|
));
|
|
self.log_query(qname, client_ip.to_string(), false).await;
|
|
return response.build_bytes_vec().ok();
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check AdBlock (Suffix matching not implemented in this simple hashset, for full pi-hole we need suffix match)
|
|
// Let's do a simple suffix check
|
|
let blocked = {
|
|
let blocked_domains = self.adblock_trie.read().await;
|
|
let mut parts: Vec<&str> = qname.split('.').collect();
|
|
let mut is_blocked = false;
|
|
while !parts.is_empty() {
|
|
let suffix = parts.join(".");
|
|
if blocked_domains.contains(&suffix) {
|
|
is_blocked = true;
|
|
break;
|
|
}
|
|
parts.remove(0);
|
|
}
|
|
is_blocked
|
|
};
|
|
|
|
if blocked {
|
|
let mut response = Packet::new_reply(packet.id());
|
|
response.questions.push(question.clone());
|
|
self.log_query(qname, client_ip.to_string(), true).await;
|
|
return response.build_bytes_vec().ok();
|
|
}
|
|
|
|
// Forward to DoH
|
|
let doh_url = cfg.doh_upstream.clone();
|
|
drop(cfg); // Release config lock before making network request
|
|
|
|
if let Ok(resp) = self.reqwest_client.post(&doh_url)
|
|
.header("Content-Type", "application/dns-message")
|
|
.header("Accept", "application/dns-message")
|
|
.body(payload.to_vec())
|
|
.send()
|
|
.await
|
|
{
|
|
if resp.status().is_success() {
|
|
if let Ok(bytes) = resp.bytes().await {
|
|
self.log_query(qname, client_ip.to_string(), false).await;
|
|
return Some(bytes.to_vec());
|
|
}
|
|
}
|
|
}
|
|
|
|
None
|
|
}
|
|
|
|
async fn log_query(&self, domain: String, client_ip: String, blocked: bool) {
|
|
let mut log = self.query_log.lock().await;
|
|
if log.len() >= 1000 {
|
|
log.pop_front();
|
|
}
|
|
log.push_back(DnsQueryLog {
|
|
timestamp: std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs(),
|
|
domain,
|
|
client_ip,
|
|
blocked,
|
|
});
|
|
}
|
|
|
|
pub async fn get_queries(&self) -> Vec<DnsQueryLog> {
|
|
let log = self.query_log.lock().await;
|
|
log.iter().cloned().collect()
|
|
}
|
|
}
|