Refactor DNS client
This commit is contained in:
@@ -41,224 +41,7 @@ use {
|
||||
};
|
||||
|
||||
use crate::{app::dispatcher::Dispatcher, option, proxy::*, session::*};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct CacheEntry {
|
||||
pub ips: Vec<IpAddr>,
|
||||
pub deadline: Instant,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct EchCacheEntry {
|
||||
pub ech_config_list: String,
|
||||
pub deadline: Instant,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct DohResolver {
|
||||
domain: String,
|
||||
bootstrap_ip: Option<IpAddr>,
|
||||
is_direct: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
enum Resolver {
|
||||
Server(SocketAddr, bool),
|
||||
DoH(DohResolver),
|
||||
System(bool),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
struct ServerRuntimeStats {
|
||||
avg_latency_ms: f64,
|
||||
samples: u64,
|
||||
successes: u64,
|
||||
failures: u64,
|
||||
timeouts: u64,
|
||||
consecutive_slow: u32,
|
||||
consecutive_failures: u32,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
struct ServerSelectorState {
|
||||
primary_server: Option<String>,
|
||||
stats: HashMap<String, ServerRuntimeStats>,
|
||||
last_reselect_at: Option<Instant>,
|
||||
}
|
||||
|
||||
impl fmt::Display for Resolver {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Server(addr, direct) => {
|
||||
if *direct {
|
||||
write!(f, "direct:{}", addr)
|
||||
} else {
|
||||
write!(f, "{}", addr)
|
||||
}
|
||||
}
|
||||
Self::DoH(doh) => {
|
||||
if doh.is_direct {
|
||||
write!(f, "direct:doh:{}", doh.domain)?;
|
||||
} else {
|
||||
write!(f, "doh:{}", doh.domain)?;
|
||||
}
|
||||
if let Some(ip) = doh.bootstrap_ip {
|
||||
write!(f, "@{}", ip)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
Self::System(direct) => {
|
||||
if *direct {
|
||||
write!(f, "direct:system")
|
||||
} else {
|
||||
write!(f, "system")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerSelectorState {
|
||||
fn score_of(&self, server: &str) -> f64 {
|
||||
if let Some(stat) = self.stats.get(server) {
|
||||
let baseline = if stat.samples == 0 {
|
||||
(*option::DNS_SERVER_SLOW_RESPONSE_MS as f64) / 2.0
|
||||
} else {
|
||||
stat.avg_latency_ms
|
||||
};
|
||||
baseline
|
||||
+ (stat.failures as f64 * 600.0)
|
||||
+ (stat.timeouts as f64 * 900.0)
|
||||
+ (stat.consecutive_failures as f64 * 1200.0)
|
||||
+ (stat.consecutive_slow as f64 * 300.0)
|
||||
} else {
|
||||
(*option::DNS_SERVER_SLOW_RESPONSE_MS as f64) / 2.0
|
||||
}
|
||||
}
|
||||
|
||||
fn is_degraded(&self, server: &str) -> bool {
|
||||
let switch_threshold = (*option::DNS_SERVER_SWITCH_THRESHOLD).max(1);
|
||||
if let Some(stat) = self.stats.get(server) {
|
||||
(stat.consecutive_failures as usize) >= switch_threshold
|
||||
|| (stat.consecutive_slow as usize) >= switch_threshold
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn ensure_candidates(&mut self, servers: &[&Resolver]) {
|
||||
for server in servers {
|
||||
self.stats.entry(server.to_string()).or_default();
|
||||
}
|
||||
}
|
||||
|
||||
fn select_primary_index(&mut self, servers: &[&Resolver]) -> usize {
|
||||
if servers.len() <= 1 {
|
||||
if let Some(server) = servers.first() {
|
||||
self.primary_server = Some(server.to_string());
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
self.ensure_candidates(servers);
|
||||
let now = Instant::now();
|
||||
let reselect_interval =
|
||||
Duration::from_secs((*option::DNS_SERVER_RESELECT_INTERVAL_SECS).max(1));
|
||||
let should_reselect = self
|
||||
.last_reselect_at
|
||||
.map(|last| now.saturating_duration_since(last) >= reselect_interval)
|
||||
.unwrap_or(true);
|
||||
|
||||
let current_idx = self.primary_server.as_ref().and_then(|primary| {
|
||||
servers
|
||||
.iter()
|
||||
.position(|server| server.to_string() == *primary)
|
||||
});
|
||||
if let Some(idx) = current_idx {
|
||||
let current_key = servers[idx].to_string();
|
||||
if !should_reselect && !self.is_degraded(¤t_key) {
|
||||
return idx;
|
||||
}
|
||||
}
|
||||
|
||||
let mut best_idx = 0usize;
|
||||
let mut best_score = f64::MAX;
|
||||
for (idx, server) in servers.iter().enumerate() {
|
||||
let score = self.score_of(&server.to_string());
|
||||
if score < best_score {
|
||||
best_score = score;
|
||||
best_idx = idx;
|
||||
}
|
||||
}
|
||||
self.primary_server = Some(servers[best_idx].to_string());
|
||||
self.last_reselect_at = Some(now);
|
||||
best_idx
|
||||
}
|
||||
|
||||
fn fallback_indices(&self, servers: &[&Resolver], preferred_idx: usize) -> Vec<usize> {
|
||||
let mut candidates: Vec<usize> = (0..servers.len())
|
||||
.filter(|idx| *idx != preferred_idx)
|
||||
.collect();
|
||||
candidates.sort_by(|a, b| {
|
||||
let sa = self.score_of(&servers[*a].to_string());
|
||||
let sb = self.score_of(&servers[*b].to_string());
|
||||
sa.partial_cmp(&sb).unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
candidates
|
||||
}
|
||||
|
||||
fn mark_success(&mut self, server: &str, elapsed: Duration) {
|
||||
let stat = self.stats.entry(server.to_owned()).or_default();
|
||||
let elapsed_ms = elapsed.as_millis() as f64;
|
||||
stat.successes = stat.successes.saturating_add(1);
|
||||
stat.samples = stat.samples.saturating_add(1);
|
||||
if stat.samples == 1 {
|
||||
stat.avg_latency_ms = elapsed_ms;
|
||||
} else {
|
||||
stat.avg_latency_ms = stat.avg_latency_ms * 0.8 + elapsed_ms * 0.2;
|
||||
}
|
||||
let slow_threshold = (*option::DNS_SERVER_SLOW_RESPONSE_MS).max(1) as f64;
|
||||
if elapsed_ms >= slow_threshold {
|
||||
stat.consecutive_slow = stat.consecutive_slow.saturating_add(1);
|
||||
} else {
|
||||
stat.consecutive_slow = 0;
|
||||
}
|
||||
stat.consecutive_failures = 0;
|
||||
if self.primary_server.is_none() {
|
||||
self.primary_server = Some(server.to_owned());
|
||||
}
|
||||
}
|
||||
|
||||
fn mark_failure(&mut self, server: &str, is_timeout: bool) {
|
||||
let stat = self.stats.entry(server.to_owned()).or_default();
|
||||
stat.failures = stat.failures.saturating_add(1);
|
||||
if is_timeout {
|
||||
stat.timeouts = stat.timeouts.saturating_add(1);
|
||||
}
|
||||
stat.consecutive_failures = stat.consecutive_failures.saturating_add(1);
|
||||
let switch_threshold = (*option::DNS_SERVER_SWITCH_THRESHOLD).max(1);
|
||||
if self.primary_server.as_deref() == Some(server)
|
||||
&& (stat.consecutive_failures as usize) >= switch_threshold
|
||||
{
|
||||
self.primary_server = None;
|
||||
}
|
||||
}
|
||||
|
||||
fn set_primary(&mut self, server: &str) {
|
||||
self.primary_server = Some(server.to_owned());
|
||||
self.last_reselect_at = Some(Instant::now());
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DnsClient {
|
||||
dispatcher: Option<Weak<Dispatcher>>,
|
||||
servers: Vec<Resolver>,
|
||||
hosts: HashMap<String, Vec<IpAddr>>,
|
||||
ipv4_cache: Arc<TokioMutex<LruCache<String, CacheEntry>>>,
|
||||
ipv6_cache: Arc<TokioMutex<LruCache<String, CacheEntry>>>,
|
||||
ech_cache: Arc<TokioMutex<LruCache<String, EchCacheEntry>>>,
|
||||
ech_query_locks: Arc<TokioMutex<HashMap<String, Arc<TokioMutex<()>>>>>,
|
||||
selector_state: Arc<Mutex<ServerSelectorState>>,
|
||||
}
|
||||
include!("client/types.rs");
|
||||
|
||||
impl DnsClient {
|
||||
fn load_servers(dns: &crate::config::Dns) -> Result<Vec<Resolver>> {
|
||||
@@ -1855,211 +1638,4 @@ impl DnsClient {
|
||||
}
|
||||
|
||||
impl UdpConnector for DnsClient {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::time::Duration;
|
||||
|
||||
use super::{DnsClient, Resolver, ServerSelectorState};
|
||||
|
||||
fn new_client(servers: Vec<&str>) -> DnsClient {
|
||||
let mut dns = crate::config::Dns::new();
|
||||
dns.servers = servers.into_iter().map(|s| s.to_string()).collect();
|
||||
DnsClient::new(&protobuf::MessageField::some(dns)).unwrap()
|
||||
}
|
||||
|
||||
fn collect_server_strings(client: &DnsClient, is_direct_outbound: bool) -> Vec<String> {
|
||||
client
|
||||
.collect_servers(is_direct_outbound)
|
||||
.into_iter()
|
||||
.map(|server| server.to_string())
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_servers_supports_legacy_and_doh_with_ip() {
|
||||
let mut dns = crate::config::Dns::new();
|
||||
dns.servers = vec![
|
||||
"1.1.1.1".to_string(),
|
||||
"direct:system".to_string(),
|
||||
"doh:example.com@9.9.9.9".to_string(),
|
||||
"direct:doh:example.com@8.8.8.8".to_string(),
|
||||
"doh:example.net".to_string(),
|
||||
];
|
||||
let servers = DnsClient::load_servers(&dns).unwrap();
|
||||
|
||||
match &servers[0] {
|
||||
Resolver::Server(addr, false) => assert_eq!(
|
||||
*addr,
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 53)
|
||||
),
|
||||
_ => panic!("unexpected resolver"),
|
||||
}
|
||||
match &servers[1] {
|
||||
Resolver::System(true) => {}
|
||||
_ => panic!("unexpected resolver"),
|
||||
}
|
||||
match &servers[2] {
|
||||
Resolver::DoH(doh) => {
|
||||
assert_eq!(doh.domain, "example.com");
|
||||
assert_eq!(
|
||||
doh.bootstrap_ip,
|
||||
Some(IpAddr::V4(Ipv4Addr::new(9, 9, 9, 9)))
|
||||
);
|
||||
assert!(!doh.is_direct);
|
||||
}
|
||||
_ => panic!("unexpected resolver"),
|
||||
}
|
||||
match &servers[3] {
|
||||
Resolver::DoH(doh) => {
|
||||
assert_eq!(doh.domain, "example.com");
|
||||
assert_eq!(
|
||||
doh.bootstrap_ip,
|
||||
Some(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)))
|
||||
);
|
||||
assert!(doh.is_direct);
|
||||
}
|
||||
_ => panic!("unexpected resolver"),
|
||||
}
|
||||
match &servers[4] {
|
||||
Resolver::DoH(doh) => {
|
||||
assert_eq!(doh.domain, "example.net");
|
||||
assert_eq!(doh.bootstrap_ip, None);
|
||||
assert!(!doh.is_direct);
|
||||
}
|
||||
_ => panic!("unexpected resolver"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_servers_ignores_invalid_doh_value_if_any_valid_server_exists() {
|
||||
let mut dns = crate::config::Dns::new();
|
||||
dns.servers = vec![
|
||||
"doh:@1.1.1.1".to_string(),
|
||||
"direct:doh:example.com@not-an-ip".to_string(),
|
||||
"doh:example.com#8.8.8.8".to_string(),
|
||||
"1.1.1.1".to_string(),
|
||||
];
|
||||
let servers = DnsClient::load_servers(&dns).unwrap();
|
||||
assert_eq!(servers.len(), 1);
|
||||
match &servers[0] {
|
||||
Resolver::Server(addr, false) => assert_eq!(
|
||||
*addr,
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 53)
|
||||
),
|
||||
_ => panic!("unexpected resolver"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_servers_rejects_when_all_servers_invalid() {
|
||||
let mut dns = crate::config::Dns::new();
|
||||
dns.servers = vec![
|
||||
"doh:@1.1.1.1".to_string(),
|
||||
"direct:doh:example.com@not-an-ip".to_string(),
|
||||
"doh:example.com#8.8.8.8".to_string(),
|
||||
];
|
||||
let err = DnsClient::load_servers(&dns).unwrap_err();
|
||||
assert!(err.to_string().contains("no dns servers"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn collect_servers_includes_direct_doh_for_direct_outbound() {
|
||||
let client = new_client(vec![
|
||||
"1.1.1.1",
|
||||
"doh:normal.example",
|
||||
"direct:doh:direct.example@8.8.8.8",
|
||||
]);
|
||||
let selected = collect_server_strings(&client, true);
|
||||
assert_eq!(selected, vec!["direct:doh:direct.example@8.8.8.8"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn collect_servers_fallback_to_normal_keeps_non_direct_doh() {
|
||||
let client = new_client(vec!["doh:normal.example", "1.1.1.1", "system"]);
|
||||
let selected = collect_server_strings(&client, true);
|
||||
assert_eq!(
|
||||
selected,
|
||||
vec![
|
||||
"doh:normal.example".to_string(),
|
||||
"1.1.1.1:53".to_string(),
|
||||
"system".to_string()
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_doh_http_body_supports_content_length() {
|
||||
let body = b"\x01\x02\x03\x04";
|
||||
let response = format!(
|
||||
"HTTP/1.1 200 OK\r\nContent-Type: application/dns-message\r\nContent-Length: {}\r\n\r\n",
|
||||
body.len()
|
||||
);
|
||||
let mut raw = response.into_bytes();
|
||||
raw.extend_from_slice(body);
|
||||
|
||||
let parsed = DnsClient::parse_doh_http_body(&raw).unwrap();
|
||||
assert_eq!(parsed, body);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_doh_http_body_supports_chunked() {
|
||||
let response =
|
||||
b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n4\r\nABCD\r\n2\r\nEF\r\n0\r\n\r\n";
|
||||
let parsed = DnsClient::parse_doh_http_body(response).unwrap();
|
||||
assert_eq!(parsed, b"ABCDEF");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_doh_http_body_rejects_non_200() {
|
||||
let response = b"HTTP/1.1 503 Service Unavailable\r\nContent-Length: 3\r\n\r\nbad".to_vec();
|
||||
let err = DnsClient::parse_doh_http_body(&response).unwrap_err();
|
||||
assert!(err
|
||||
.to_string()
|
||||
.contains("doh server returned http status 503"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn selector_primary_switches_after_consecutive_failures() {
|
||||
let s1 = DnsClient::parse_server("1.1.1.1").unwrap();
|
||||
let s2 = DnsClient::parse_server("8.8.8.8").unwrap();
|
||||
let servers = vec![&s1, &s2];
|
||||
let mut selector = ServerSelectorState::default();
|
||||
let initial = selector.select_primary_index(&servers);
|
||||
assert_eq!(initial, 0);
|
||||
let key = s1.to_string();
|
||||
let threshold = (*crate::option::DNS_SERVER_SWITCH_THRESHOLD).max(1);
|
||||
for _ in 0..threshold {
|
||||
selector.mark_failure(&key, true);
|
||||
}
|
||||
let selected = selector.select_primary_index(&servers);
|
||||
assert_eq!(selected, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn selector_prefers_lower_latency_in_fallback_order() {
|
||||
let s1 = DnsClient::parse_server("1.1.1.1").unwrap();
|
||||
let s2 = DnsClient::parse_server("8.8.8.8").unwrap();
|
||||
let s3 = DnsClient::parse_server("9.9.9.9").unwrap();
|
||||
let servers = vec![&s1, &s2, &s3];
|
||||
let mut selector = ServerSelectorState::default();
|
||||
selector.mark_success(&s2.to_string(), Duration::from_millis(30));
|
||||
selector.mark_success(&s3.to_string(), Duration::from_millis(450));
|
||||
let order = selector.fallback_indices(&servers, 0);
|
||||
assert_eq!(order, vec![1, 2]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn selector_marks_slow_server_as_degraded() {
|
||||
let server = DnsClient::parse_server("1.1.1.1").unwrap();
|
||||
let mut selector = ServerSelectorState::default();
|
||||
let key = server.to_string();
|
||||
let threshold = (*crate::option::DNS_SERVER_SWITCH_THRESHOLD).max(1);
|
||||
let slow_elapsed = Duration::from_millis(*crate::option::DNS_SERVER_SLOW_RESPONSE_MS + 50);
|
||||
for _ in 0..threshold {
|
||||
selector.mark_success(&key, slow_elapsed);
|
||||
}
|
||||
assert!(selector.is_degraded(&key));
|
||||
}
|
||||
}
|
||||
include!("client/tests.rs");
|
||||
207
leaf/src/app/dns/client/tests.rs
Normal file
207
leaf/src/app/dns/client/tests.rs
Normal file
@@ -0,0 +1,207 @@
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||
use std::time::Duration;
|
||||
|
||||
use super::{DnsClient, Resolver, ServerSelectorState};
|
||||
|
||||
fn new_client(servers: Vec<&str>) -> DnsClient {
|
||||
let mut dns = crate::config::Dns::new();
|
||||
dns.servers = servers.into_iter().map(|s| s.to_string()).collect();
|
||||
DnsClient::new(&protobuf::MessageField::some(dns)).unwrap()
|
||||
}
|
||||
|
||||
fn collect_server_strings(client: &DnsClient, is_direct_outbound: bool) -> Vec<String> {
|
||||
client
|
||||
.collect_servers(is_direct_outbound)
|
||||
.into_iter()
|
||||
.map(|server| server.to_string())
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_servers_supports_legacy_and_doh_with_ip() {
|
||||
let mut dns = crate::config::Dns::new();
|
||||
dns.servers = vec![
|
||||
"1.1.1.1".to_string(),
|
||||
"direct:system".to_string(),
|
||||
"doh:example.com@9.9.9.9".to_string(),
|
||||
"direct:doh:example.com@8.8.8.8".to_string(),
|
||||
"doh:example.net".to_string(),
|
||||
];
|
||||
let servers = DnsClient::load_servers(&dns).unwrap();
|
||||
|
||||
match &servers[0] {
|
||||
Resolver::Server(addr, false) => assert_eq!(
|
||||
*addr,
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 53)
|
||||
),
|
||||
_ => panic!("unexpected resolver"),
|
||||
}
|
||||
match &servers[1] {
|
||||
Resolver::System(true) => {}
|
||||
_ => panic!("unexpected resolver"),
|
||||
}
|
||||
match &servers[2] {
|
||||
Resolver::DoH(doh) => {
|
||||
assert_eq!(doh.domain, "example.com");
|
||||
assert_eq!(
|
||||
doh.bootstrap_ip,
|
||||
Some(IpAddr::V4(Ipv4Addr::new(9, 9, 9, 9)))
|
||||
);
|
||||
assert!(!doh.is_direct);
|
||||
}
|
||||
_ => panic!("unexpected resolver"),
|
||||
}
|
||||
match &servers[3] {
|
||||
Resolver::DoH(doh) => {
|
||||
assert_eq!(doh.domain, "example.com");
|
||||
assert_eq!(
|
||||
doh.bootstrap_ip,
|
||||
Some(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)))
|
||||
);
|
||||
assert!(doh.is_direct);
|
||||
}
|
||||
_ => panic!("unexpected resolver"),
|
||||
}
|
||||
match &servers[4] {
|
||||
Resolver::DoH(doh) => {
|
||||
assert_eq!(doh.domain, "example.net");
|
||||
assert_eq!(doh.bootstrap_ip, None);
|
||||
assert!(!doh.is_direct);
|
||||
}
|
||||
_ => panic!("unexpected resolver"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_servers_ignores_invalid_doh_value_if_any_valid_server_exists() {
|
||||
let mut dns = crate::config::Dns::new();
|
||||
dns.servers = vec![
|
||||
"doh:@1.1.1.1".to_string(),
|
||||
"direct:doh:example.com@not-an-ip".to_string(),
|
||||
"doh:example.com#8.8.8.8".to_string(),
|
||||
"1.1.1.1".to_string(),
|
||||
];
|
||||
let servers = DnsClient::load_servers(&dns).unwrap();
|
||||
assert_eq!(servers.len(), 1);
|
||||
match &servers[0] {
|
||||
Resolver::Server(addr, false) => assert_eq!(
|
||||
*addr,
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 53)
|
||||
),
|
||||
_ => panic!("unexpected resolver"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_servers_rejects_when_all_servers_invalid() {
|
||||
let mut dns = crate::config::Dns::new();
|
||||
dns.servers = vec![
|
||||
"doh:@1.1.1.1".to_string(),
|
||||
"direct:doh:example.com@not-an-ip".to_string(),
|
||||
"doh:example.com#8.8.8.8".to_string(),
|
||||
];
|
||||
let err = DnsClient::load_servers(&dns).unwrap_err();
|
||||
assert!(err.to_string().contains("no dns servers"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn collect_servers_includes_direct_doh_for_direct_outbound() {
|
||||
let client = new_client(vec![
|
||||
"1.1.1.1",
|
||||
"doh:normal.example",
|
||||
"direct:doh:direct.example@8.8.8.8",
|
||||
]);
|
||||
let selected = collect_server_strings(&client, true);
|
||||
assert_eq!(selected, vec!["direct:doh:direct.example@8.8.8.8"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn collect_servers_fallback_to_normal_keeps_non_direct_doh() {
|
||||
let client = new_client(vec!["doh:normal.example", "1.1.1.1", "system"]);
|
||||
let selected = collect_server_strings(&client, true);
|
||||
assert_eq!(
|
||||
selected,
|
||||
vec![
|
||||
"doh:normal.example".to_string(),
|
||||
"1.1.1.1:53".to_string(),
|
||||
"system".to_string()
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_doh_http_body_supports_content_length() {
|
||||
let body = b"\x01\x02\x03\x04";
|
||||
let response = format!(
|
||||
"HTTP/1.1 200 OK\r\nContent-Type: application/dns-message\r\nContent-Length: {}\r\n\r\n",
|
||||
body.len()
|
||||
);
|
||||
let mut raw = response.into_bytes();
|
||||
raw.extend_from_slice(body);
|
||||
|
||||
let parsed = DnsClient::parse_doh_http_body(&raw).unwrap();
|
||||
assert_eq!(parsed, body);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_doh_http_body_supports_chunked() {
|
||||
let response =
|
||||
b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n4\r\nABCD\r\n2\r\nEF\r\n0\r\n\r\n";
|
||||
let parsed = DnsClient::parse_doh_http_body(response).unwrap();
|
||||
assert_eq!(parsed, b"ABCDEF");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_doh_http_body_rejects_non_200() {
|
||||
let response = b"HTTP/1.1 503 Service Unavailable\r\nContent-Length: 3\r\n\r\nbad".to_vec();
|
||||
let err = DnsClient::parse_doh_http_body(&response).unwrap_err();
|
||||
assert!(err
|
||||
.to_string()
|
||||
.contains("doh server returned http status 503"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn selector_primary_switches_after_consecutive_failures() {
|
||||
let s1 = DnsClient::parse_server("1.1.1.1").unwrap();
|
||||
let s2 = DnsClient::parse_server("8.8.8.8").unwrap();
|
||||
let servers = vec![&s1, &s2];
|
||||
let mut selector = ServerSelectorState::default();
|
||||
let initial = selector.select_primary_index(&servers);
|
||||
assert_eq!(initial, 0);
|
||||
let key = s1.to_string();
|
||||
let threshold = (*crate::option::DNS_SERVER_SWITCH_THRESHOLD).max(1);
|
||||
for _ in 0..threshold {
|
||||
selector.mark_failure(&key, true);
|
||||
}
|
||||
let selected = selector.select_primary_index(&servers);
|
||||
assert_eq!(selected, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn selector_prefers_lower_latency_in_fallback_order() {
|
||||
let s1 = DnsClient::parse_server("1.1.1.1").unwrap();
|
||||
let s2 = DnsClient::parse_server("8.8.8.8").unwrap();
|
||||
let s3 = DnsClient::parse_server("9.9.9.9").unwrap();
|
||||
let servers = vec![&s1, &s2, &s3];
|
||||
let mut selector = ServerSelectorState::default();
|
||||
selector.mark_success(&s2.to_string(), Duration::from_millis(30));
|
||||
selector.mark_success(&s3.to_string(), Duration::from_millis(450));
|
||||
let order = selector.fallback_indices(&servers, 0);
|
||||
assert_eq!(order, vec![1, 2]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn selector_marks_slow_server_as_degraded() {
|
||||
let server = DnsClient::parse_server("1.1.1.1").unwrap();
|
||||
let mut selector = ServerSelectorState::default();
|
||||
let key = server.to_string();
|
||||
let threshold = (*crate::option::DNS_SERVER_SWITCH_THRESHOLD).max(1);
|
||||
let slow_elapsed = Duration::from_millis(*crate::option::DNS_SERVER_SLOW_RESPONSE_MS + 50);
|
||||
for _ in 0..threshold {
|
||||
selector.mark_success(&key, slow_elapsed);
|
||||
}
|
||||
assert!(selector.is_degraded(&key));
|
||||
}
|
||||
}
|
||||
217
leaf/src/app/dns/client/types.rs
Normal file
217
leaf/src/app/dns/client/types.rs
Normal file
@@ -0,0 +1,217 @@
|
||||
#[derive(Clone, Debug)]
|
||||
struct CacheEntry {
|
||||
pub ips: Vec<IpAddr>,
|
||||
pub deadline: Instant,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct EchCacheEntry {
|
||||
pub ech_config_list: String,
|
||||
pub deadline: Instant,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct DohResolver {
|
||||
domain: String,
|
||||
bootstrap_ip: Option<IpAddr>,
|
||||
is_direct: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
enum Resolver {
|
||||
Server(SocketAddr, bool),
|
||||
DoH(DohResolver),
|
||||
System(bool),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
struct ServerRuntimeStats {
|
||||
avg_latency_ms: f64,
|
||||
samples: u64,
|
||||
successes: u64,
|
||||
failures: u64,
|
||||
timeouts: u64,
|
||||
consecutive_slow: u32,
|
||||
consecutive_failures: u32,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
struct ServerSelectorState {
|
||||
primary_server: Option<String>,
|
||||
stats: HashMap<String, ServerRuntimeStats>,
|
||||
last_reselect_at: Option<Instant>,
|
||||
}
|
||||
|
||||
impl fmt::Display for Resolver {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Server(addr, direct) => {
|
||||
if *direct {
|
||||
write!(f, "direct:{}", addr)
|
||||
} else {
|
||||
write!(f, "{}", addr)
|
||||
}
|
||||
}
|
||||
Self::DoH(doh) => {
|
||||
if doh.is_direct {
|
||||
write!(f, "direct:doh:{}", doh.domain)?;
|
||||
} else {
|
||||
write!(f, "doh:{}", doh.domain)?;
|
||||
}
|
||||
if let Some(ip) = doh.bootstrap_ip {
|
||||
write!(f, "@{}", ip)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
Self::System(direct) => {
|
||||
if *direct {
|
||||
write!(f, "direct:system")
|
||||
} else {
|
||||
write!(f, "system")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ServerSelectorState {
|
||||
fn score_of(&self, server: &str) -> f64 {
|
||||
if let Some(stat) = self.stats.get(server) {
|
||||
let baseline = if stat.samples == 0 {
|
||||
(*option::DNS_SERVER_SLOW_RESPONSE_MS as f64) / 2.0
|
||||
} else {
|
||||
stat.avg_latency_ms
|
||||
};
|
||||
baseline
|
||||
+ (stat.failures as f64 * 600.0)
|
||||
+ (stat.timeouts as f64 * 900.0)
|
||||
+ (stat.consecutive_failures as f64 * 1200.0)
|
||||
+ (stat.consecutive_slow as f64 * 300.0)
|
||||
} else {
|
||||
(*option::DNS_SERVER_SLOW_RESPONSE_MS as f64) / 2.0
|
||||
}
|
||||
}
|
||||
|
||||
fn is_degraded(&self, server: &str) -> bool {
|
||||
let switch_threshold = (*option::DNS_SERVER_SWITCH_THRESHOLD).max(1);
|
||||
if let Some(stat) = self.stats.get(server) {
|
||||
(stat.consecutive_failures as usize) >= switch_threshold
|
||||
|| (stat.consecutive_slow as usize) >= switch_threshold
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn ensure_candidates(&mut self, servers: &[&Resolver]) {
|
||||
for server in servers {
|
||||
self.stats.entry(server.to_string()).or_default();
|
||||
}
|
||||
}
|
||||
|
||||
fn select_primary_index(&mut self, servers: &[&Resolver]) -> usize {
|
||||
if servers.len() <= 1 {
|
||||
if let Some(server) = servers.first() {
|
||||
self.primary_server = Some(server.to_string());
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
self.ensure_candidates(servers);
|
||||
let now = Instant::now();
|
||||
let reselect_interval =
|
||||
Duration::from_secs((*option::DNS_SERVER_RESELECT_INTERVAL_SECS).max(1));
|
||||
let should_reselect = self
|
||||
.last_reselect_at
|
||||
.map(|last| now.saturating_duration_since(last) >= reselect_interval)
|
||||
.unwrap_or(true);
|
||||
|
||||
let current_idx = self.primary_server.as_ref().and_then(|primary| {
|
||||
servers
|
||||
.iter()
|
||||
.position(|server| server.to_string() == *primary)
|
||||
});
|
||||
if let Some(idx) = current_idx {
|
||||
let current_key = servers[idx].to_string();
|
||||
if !should_reselect && !self.is_degraded(¤t_key) {
|
||||
return idx;
|
||||
}
|
||||
}
|
||||
|
||||
let mut best_idx = 0usize;
|
||||
let mut best_score = f64::MAX;
|
||||
for (idx, server) in servers.iter().enumerate() {
|
||||
let score = self.score_of(&server.to_string());
|
||||
if score < best_score {
|
||||
best_score = score;
|
||||
best_idx = idx;
|
||||
}
|
||||
}
|
||||
self.primary_server = Some(servers[best_idx].to_string());
|
||||
self.last_reselect_at = Some(now);
|
||||
best_idx
|
||||
}
|
||||
|
||||
fn fallback_indices(&self, servers: &[&Resolver], preferred_idx: usize) -> Vec<usize> {
|
||||
let mut candidates: Vec<usize> = (0..servers.len())
|
||||
.filter(|idx| *idx != preferred_idx)
|
||||
.collect();
|
||||
candidates.sort_by(|a, b| {
|
||||
let sa = self.score_of(&servers[*a].to_string());
|
||||
let sb = self.score_of(&servers[*b].to_string());
|
||||
sa.partial_cmp(&sb).unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
candidates
|
||||
}
|
||||
|
||||
fn mark_success(&mut self, server: &str, elapsed: Duration) {
|
||||
let stat = self.stats.entry(server.to_owned()).or_default();
|
||||
let elapsed_ms = elapsed.as_millis() as f64;
|
||||
stat.successes = stat.successes.saturating_add(1);
|
||||
stat.samples = stat.samples.saturating_add(1);
|
||||
if stat.samples == 1 {
|
||||
stat.avg_latency_ms = elapsed_ms;
|
||||
} else {
|
||||
stat.avg_latency_ms = stat.avg_latency_ms * 0.8 + elapsed_ms * 0.2;
|
||||
}
|
||||
let slow_threshold = (*option::DNS_SERVER_SLOW_RESPONSE_MS).max(1) as f64;
|
||||
if elapsed_ms >= slow_threshold {
|
||||
stat.consecutive_slow = stat.consecutive_slow.saturating_add(1);
|
||||
} else {
|
||||
stat.consecutive_slow = 0;
|
||||
}
|
||||
stat.consecutive_failures = 0;
|
||||
if self.primary_server.is_none() {
|
||||
self.primary_server = Some(server.to_owned());
|
||||
}
|
||||
}
|
||||
|
||||
fn mark_failure(&mut self, server: &str, is_timeout: bool) {
|
||||
let stat = self.stats.entry(server.to_owned()).or_default();
|
||||
stat.failures = stat.failures.saturating_add(1);
|
||||
if is_timeout {
|
||||
stat.timeouts = stat.timeouts.saturating_add(1);
|
||||
}
|
||||
stat.consecutive_failures = stat.consecutive_failures.saturating_add(1);
|
||||
let switch_threshold = (*option::DNS_SERVER_SWITCH_THRESHOLD).max(1);
|
||||
if self.primary_server.as_deref() == Some(server)
|
||||
&& (stat.consecutive_failures as usize) >= switch_threshold
|
||||
{
|
||||
self.primary_server = None;
|
||||
}
|
||||
}
|
||||
|
||||
fn set_primary(&mut self, server: &str) {
|
||||
self.primary_server = Some(server.to_owned());
|
||||
self.last_reselect_at = Some(Instant::now());
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DnsClient {
|
||||
dispatcher: Option<Weak<Dispatcher>>,
|
||||
servers: Vec<Resolver>,
|
||||
hosts: HashMap<String, Vec<IpAddr>>,
|
||||
ipv4_cache: Arc<TokioMutex<LruCache<String, CacheEntry>>>,
|
||||
ipv6_cache: Arc<TokioMutex<LruCache<String, CacheEntry>>>,
|
||||
ech_cache: Arc<TokioMutex<LruCache<String, EchCacheEntry>>>,
|
||||
ech_query_locks: Arc<TokioMutex<HashMap<String, Arc<TokioMutex<()>>>>>,
|
||||
selector_state: Arc<Mutex<ServerSelectorState>>,
|
||||
}
|
||||
3
leaf/src/app/dns/mod.rs
Normal file
3
leaf/src/app/dns/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
mod client;
|
||||
|
||||
pub use client::*;
|
||||
@@ -3,7 +3,7 @@ use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
pub mod dispatcher;
|
||||
pub mod dns_client;
|
||||
pub mod dns;
|
||||
pub mod healthcheck;
|
||||
pub mod inbound;
|
||||
pub mod logger;
|
||||
@@ -17,6 +17,10 @@ pub mod api;
|
||||
|
||||
pub mod fake_dns;
|
||||
|
||||
pub type SyncDnsClient = Arc<RwLock<dns_client::DnsClient>>;
|
||||
pub mod dns_client {
|
||||
pub use super::dns::*;
|
||||
}
|
||||
|
||||
pub type SyncDnsClient = Arc<RwLock<dns::DnsClient>>;
|
||||
|
||||
pub type SyncStatManager = Arc<RwLock<stat_manager::StatManager>>;
|
||||
|
||||
@@ -18,7 +18,7 @@ use notify::{
|
||||
};
|
||||
|
||||
use app::{
|
||||
dispatcher::Dispatcher, dns_client::DnsClient, inbound::manager::InboundManager,
|
||||
dispatcher::Dispatcher, dns::DnsClient, inbound::manager::InboundManager,
|
||||
nat_manager::NatManager, outbound::manager::OutboundManager, router::Router,
|
||||
};
|
||||
|
||||
|
||||
@@ -653,7 +653,7 @@ mod tests {
|
||||
use protobuf::MessageField;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::app::{dns_client::DnsClient, SyncDnsClient};
|
||||
use crate::app::{dns::DnsClient, SyncDnsClient};
|
||||
#[cfg(feature = "rustls-tls-aws-lc")]
|
||||
use crate::session::Session;
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ use tokio::sync::RwLock;
|
||||
use tokio::time::timeout;
|
||||
|
||||
use crate::{
|
||||
app::{dns_client::DnsClient, outbound::manager::OutboundManager, SyncDnsClient},
|
||||
app::{dns::DnsClient, outbound::manager::OutboundManager, SyncDnsClient},
|
||||
config::Config,
|
||||
proxy::*,
|
||||
session::*,
|
||||
|
||||
Reference in New Issue
Block a user