Refactor DNS client

This commit is contained in:
eric
2026-03-15 17:53:14 +08:00
parent 26d17f9da7
commit 291af8b9f0
8 changed files with 438 additions and 431 deletions

View File

@@ -41,224 +41,7 @@ use {
};
use crate::{app::dispatcher::Dispatcher, option, proxy::*, session::*};
#[derive(Clone, Debug)]
struct CacheEntry {
pub ips: Vec<IpAddr>,
pub deadline: Instant,
}
#[derive(Clone, Debug)]
pub struct EchCacheEntry {
pub ech_config_list: String,
pub deadline: Instant,
}
#[derive(Clone, Debug)]
struct DohResolver {
domain: String,
bootstrap_ip: Option<IpAddr>,
is_direct: bool,
}
#[derive(Clone, Debug)]
enum Resolver {
Server(SocketAddr, bool),
DoH(DohResolver),
System(bool),
}
#[derive(Clone, Debug, Default)]
struct ServerRuntimeStats {
avg_latency_ms: f64,
samples: u64,
successes: u64,
failures: u64,
timeouts: u64,
consecutive_slow: u32,
consecutive_failures: u32,
}
#[derive(Clone, Debug, Default)]
struct ServerSelectorState {
primary_server: Option<String>,
stats: HashMap<String, ServerRuntimeStats>,
last_reselect_at: Option<Instant>,
}
impl fmt::Display for Resolver {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Server(addr, direct) => {
if *direct {
write!(f, "direct:{}", addr)
} else {
write!(f, "{}", addr)
}
}
Self::DoH(doh) => {
if doh.is_direct {
write!(f, "direct:doh:{}", doh.domain)?;
} else {
write!(f, "doh:{}", doh.domain)?;
}
if let Some(ip) = doh.bootstrap_ip {
write!(f, "@{}", ip)?;
}
Ok(())
}
Self::System(direct) => {
if *direct {
write!(f, "direct:system")
} else {
write!(f, "system")
}
}
}
}
}
impl ServerSelectorState {
fn score_of(&self, server: &str) -> f64 {
if let Some(stat) = self.stats.get(server) {
let baseline = if stat.samples == 0 {
(*option::DNS_SERVER_SLOW_RESPONSE_MS as f64) / 2.0
} else {
stat.avg_latency_ms
};
baseline
+ (stat.failures as f64 * 600.0)
+ (stat.timeouts as f64 * 900.0)
+ (stat.consecutive_failures as f64 * 1200.0)
+ (stat.consecutive_slow as f64 * 300.0)
} else {
(*option::DNS_SERVER_SLOW_RESPONSE_MS as f64) / 2.0
}
}
fn is_degraded(&self, server: &str) -> bool {
let switch_threshold = (*option::DNS_SERVER_SWITCH_THRESHOLD).max(1);
if let Some(stat) = self.stats.get(server) {
(stat.consecutive_failures as usize) >= switch_threshold
|| (stat.consecutive_slow as usize) >= switch_threshold
} else {
false
}
}
fn ensure_candidates(&mut self, servers: &[&Resolver]) {
for server in servers {
self.stats.entry(server.to_string()).or_default();
}
}
fn select_primary_index(&mut self, servers: &[&Resolver]) -> usize {
if servers.len() <= 1 {
if let Some(server) = servers.first() {
self.primary_server = Some(server.to_string());
}
return 0;
}
self.ensure_candidates(servers);
let now = Instant::now();
let reselect_interval =
Duration::from_secs((*option::DNS_SERVER_RESELECT_INTERVAL_SECS).max(1));
let should_reselect = self
.last_reselect_at
.map(|last| now.saturating_duration_since(last) >= reselect_interval)
.unwrap_or(true);
let current_idx = self.primary_server.as_ref().and_then(|primary| {
servers
.iter()
.position(|server| server.to_string() == *primary)
});
if let Some(idx) = current_idx {
let current_key = servers[idx].to_string();
if !should_reselect && !self.is_degraded(&current_key) {
return idx;
}
}
let mut best_idx = 0usize;
let mut best_score = f64::MAX;
for (idx, server) in servers.iter().enumerate() {
let score = self.score_of(&server.to_string());
if score < best_score {
best_score = score;
best_idx = idx;
}
}
self.primary_server = Some(servers[best_idx].to_string());
self.last_reselect_at = Some(now);
best_idx
}
fn fallback_indices(&self, servers: &[&Resolver], preferred_idx: usize) -> Vec<usize> {
let mut candidates: Vec<usize> = (0..servers.len())
.filter(|idx| *idx != preferred_idx)
.collect();
candidates.sort_by(|a, b| {
let sa = self.score_of(&servers[*a].to_string());
let sb = self.score_of(&servers[*b].to_string());
sa.partial_cmp(&sb).unwrap_or(std::cmp::Ordering::Equal)
});
candidates
}
fn mark_success(&mut self, server: &str, elapsed: Duration) {
let stat = self.stats.entry(server.to_owned()).or_default();
let elapsed_ms = elapsed.as_millis() as f64;
stat.successes = stat.successes.saturating_add(1);
stat.samples = stat.samples.saturating_add(1);
if stat.samples == 1 {
stat.avg_latency_ms = elapsed_ms;
} else {
stat.avg_latency_ms = stat.avg_latency_ms * 0.8 + elapsed_ms * 0.2;
}
let slow_threshold = (*option::DNS_SERVER_SLOW_RESPONSE_MS).max(1) as f64;
if elapsed_ms >= slow_threshold {
stat.consecutive_slow = stat.consecutive_slow.saturating_add(1);
} else {
stat.consecutive_slow = 0;
}
stat.consecutive_failures = 0;
if self.primary_server.is_none() {
self.primary_server = Some(server.to_owned());
}
}
fn mark_failure(&mut self, server: &str, is_timeout: bool) {
let stat = self.stats.entry(server.to_owned()).or_default();
stat.failures = stat.failures.saturating_add(1);
if is_timeout {
stat.timeouts = stat.timeouts.saturating_add(1);
}
stat.consecutive_failures = stat.consecutive_failures.saturating_add(1);
let switch_threshold = (*option::DNS_SERVER_SWITCH_THRESHOLD).max(1);
if self.primary_server.as_deref() == Some(server)
&& (stat.consecutive_failures as usize) >= switch_threshold
{
self.primary_server = None;
}
}
fn set_primary(&mut self, server: &str) {
self.primary_server = Some(server.to_owned());
self.last_reselect_at = Some(Instant::now());
}
}
pub struct DnsClient {
dispatcher: Option<Weak<Dispatcher>>,
servers: Vec<Resolver>,
hosts: HashMap<String, Vec<IpAddr>>,
ipv4_cache: Arc<TokioMutex<LruCache<String, CacheEntry>>>,
ipv6_cache: Arc<TokioMutex<LruCache<String, CacheEntry>>>,
ech_cache: Arc<TokioMutex<LruCache<String, EchCacheEntry>>>,
ech_query_locks: Arc<TokioMutex<HashMap<String, Arc<TokioMutex<()>>>>>,
selector_state: Arc<Mutex<ServerSelectorState>>,
}
include!("client/types.rs");
impl DnsClient {
fn load_servers(dns: &crate::config::Dns) -> Result<Vec<Resolver>> {
@@ -1855,211 +1638,4 @@ impl DnsClient {
}
impl UdpConnector for DnsClient {}
#[cfg(test)]
mod tests {
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::time::Duration;
use super::{DnsClient, Resolver, ServerSelectorState};
fn new_client(servers: Vec<&str>) -> DnsClient {
let mut dns = crate::config::Dns::new();
dns.servers = servers.into_iter().map(|s| s.to_string()).collect();
DnsClient::new(&protobuf::MessageField::some(dns)).unwrap()
}
fn collect_server_strings(client: &DnsClient, is_direct_outbound: bool) -> Vec<String> {
client
.collect_servers(is_direct_outbound)
.into_iter()
.map(|server| server.to_string())
.collect()
}
#[test]
fn load_servers_supports_legacy_and_doh_with_ip() {
let mut dns = crate::config::Dns::new();
dns.servers = vec![
"1.1.1.1".to_string(),
"direct:system".to_string(),
"doh:example.com@9.9.9.9".to_string(),
"direct:doh:example.com@8.8.8.8".to_string(),
"doh:example.net".to_string(),
];
let servers = DnsClient::load_servers(&dns).unwrap();
match &servers[0] {
Resolver::Server(addr, false) => assert_eq!(
*addr,
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 53)
),
_ => panic!("unexpected resolver"),
}
match &servers[1] {
Resolver::System(true) => {}
_ => panic!("unexpected resolver"),
}
match &servers[2] {
Resolver::DoH(doh) => {
assert_eq!(doh.domain, "example.com");
assert_eq!(
doh.bootstrap_ip,
Some(IpAddr::V4(Ipv4Addr::new(9, 9, 9, 9)))
);
assert!(!doh.is_direct);
}
_ => panic!("unexpected resolver"),
}
match &servers[3] {
Resolver::DoH(doh) => {
assert_eq!(doh.domain, "example.com");
assert_eq!(
doh.bootstrap_ip,
Some(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)))
);
assert!(doh.is_direct);
}
_ => panic!("unexpected resolver"),
}
match &servers[4] {
Resolver::DoH(doh) => {
assert_eq!(doh.domain, "example.net");
assert_eq!(doh.bootstrap_ip, None);
assert!(!doh.is_direct);
}
_ => panic!("unexpected resolver"),
}
}
#[test]
fn load_servers_ignores_invalid_doh_value_if_any_valid_server_exists() {
let mut dns = crate::config::Dns::new();
dns.servers = vec![
"doh:@1.1.1.1".to_string(),
"direct:doh:example.com@not-an-ip".to_string(),
"doh:example.com#8.8.8.8".to_string(),
"1.1.1.1".to_string(),
];
let servers = DnsClient::load_servers(&dns).unwrap();
assert_eq!(servers.len(), 1);
match &servers[0] {
Resolver::Server(addr, false) => assert_eq!(
*addr,
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 53)
),
_ => panic!("unexpected resolver"),
}
}
#[test]
fn load_servers_rejects_when_all_servers_invalid() {
let mut dns = crate::config::Dns::new();
dns.servers = vec![
"doh:@1.1.1.1".to_string(),
"direct:doh:example.com@not-an-ip".to_string(),
"doh:example.com#8.8.8.8".to_string(),
];
let err = DnsClient::load_servers(&dns).unwrap_err();
assert!(err.to_string().contains("no dns servers"));
}
#[test]
fn collect_servers_includes_direct_doh_for_direct_outbound() {
let client = new_client(vec![
"1.1.1.1",
"doh:normal.example",
"direct:doh:direct.example@8.8.8.8",
]);
let selected = collect_server_strings(&client, true);
assert_eq!(selected, vec!["direct:doh:direct.example@8.8.8.8"]);
}
#[test]
fn collect_servers_fallback_to_normal_keeps_non_direct_doh() {
let client = new_client(vec!["doh:normal.example", "1.1.1.1", "system"]);
let selected = collect_server_strings(&client, true);
assert_eq!(
selected,
vec![
"doh:normal.example".to_string(),
"1.1.1.1:53".to_string(),
"system".to_string()
]
);
}
#[test]
fn parse_doh_http_body_supports_content_length() {
let body = b"\x01\x02\x03\x04";
let response = format!(
"HTTP/1.1 200 OK\r\nContent-Type: application/dns-message\r\nContent-Length: {}\r\n\r\n",
body.len()
);
let mut raw = response.into_bytes();
raw.extend_from_slice(body);
let parsed = DnsClient::parse_doh_http_body(&raw).unwrap();
assert_eq!(parsed, body);
}
#[test]
fn parse_doh_http_body_supports_chunked() {
let response =
b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n4\r\nABCD\r\n2\r\nEF\r\n0\r\n\r\n";
let parsed = DnsClient::parse_doh_http_body(response).unwrap();
assert_eq!(parsed, b"ABCDEF");
}
#[test]
fn parse_doh_http_body_rejects_non_200() {
let response = b"HTTP/1.1 503 Service Unavailable\r\nContent-Length: 3\r\n\r\nbad".to_vec();
let err = DnsClient::parse_doh_http_body(&response).unwrap_err();
assert!(err
.to_string()
.contains("doh server returned http status 503"));
}
#[test]
fn selector_primary_switches_after_consecutive_failures() {
let s1 = DnsClient::parse_server("1.1.1.1").unwrap();
let s2 = DnsClient::parse_server("8.8.8.8").unwrap();
let servers = vec![&s1, &s2];
let mut selector = ServerSelectorState::default();
let initial = selector.select_primary_index(&servers);
assert_eq!(initial, 0);
let key = s1.to_string();
let threshold = (*crate::option::DNS_SERVER_SWITCH_THRESHOLD).max(1);
for _ in 0..threshold {
selector.mark_failure(&key, true);
}
let selected = selector.select_primary_index(&servers);
assert_eq!(selected, 1);
}
#[test]
fn selector_prefers_lower_latency_in_fallback_order() {
let s1 = DnsClient::parse_server("1.1.1.1").unwrap();
let s2 = DnsClient::parse_server("8.8.8.8").unwrap();
let s3 = DnsClient::parse_server("9.9.9.9").unwrap();
let servers = vec![&s1, &s2, &s3];
let mut selector = ServerSelectorState::default();
selector.mark_success(&s2.to_string(), Duration::from_millis(30));
selector.mark_success(&s3.to_string(), Duration::from_millis(450));
let order = selector.fallback_indices(&servers, 0);
assert_eq!(order, vec![1, 2]);
}
#[test]
fn selector_marks_slow_server_as_degraded() {
let server = DnsClient::parse_server("1.1.1.1").unwrap();
let mut selector = ServerSelectorState::default();
let key = server.to_string();
let threshold = (*crate::option::DNS_SERVER_SWITCH_THRESHOLD).max(1);
let slow_elapsed = Duration::from_millis(*crate::option::DNS_SERVER_SLOW_RESPONSE_MS + 50);
for _ in 0..threshold {
selector.mark_success(&key, slow_elapsed);
}
assert!(selector.is_degraded(&key));
}
}
include!("client/tests.rs");

View File

@@ -0,0 +1,207 @@
#[cfg(test)]
mod tests {
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::time::Duration;
use super::{DnsClient, Resolver, ServerSelectorState};
fn new_client(servers: Vec<&str>) -> DnsClient {
let mut dns = crate::config::Dns::new();
dns.servers = servers.into_iter().map(|s| s.to_string()).collect();
DnsClient::new(&protobuf::MessageField::some(dns)).unwrap()
}
fn collect_server_strings(client: &DnsClient, is_direct_outbound: bool) -> Vec<String> {
client
.collect_servers(is_direct_outbound)
.into_iter()
.map(|server| server.to_string())
.collect()
}
#[test]
fn load_servers_supports_legacy_and_doh_with_ip() {
let mut dns = crate::config::Dns::new();
dns.servers = vec![
"1.1.1.1".to_string(),
"direct:system".to_string(),
"doh:example.com@9.9.9.9".to_string(),
"direct:doh:example.com@8.8.8.8".to_string(),
"doh:example.net".to_string(),
];
let servers = DnsClient::load_servers(&dns).unwrap();
match &servers[0] {
Resolver::Server(addr, false) => assert_eq!(
*addr,
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 53)
),
_ => panic!("unexpected resolver"),
}
match &servers[1] {
Resolver::System(true) => {}
_ => panic!("unexpected resolver"),
}
match &servers[2] {
Resolver::DoH(doh) => {
assert_eq!(doh.domain, "example.com");
assert_eq!(
doh.bootstrap_ip,
Some(IpAddr::V4(Ipv4Addr::new(9, 9, 9, 9)))
);
assert!(!doh.is_direct);
}
_ => panic!("unexpected resolver"),
}
match &servers[3] {
Resolver::DoH(doh) => {
assert_eq!(doh.domain, "example.com");
assert_eq!(
doh.bootstrap_ip,
Some(IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)))
);
assert!(doh.is_direct);
}
_ => panic!("unexpected resolver"),
}
match &servers[4] {
Resolver::DoH(doh) => {
assert_eq!(doh.domain, "example.net");
assert_eq!(doh.bootstrap_ip, None);
assert!(!doh.is_direct);
}
_ => panic!("unexpected resolver"),
}
}
#[test]
fn load_servers_ignores_invalid_doh_value_if_any_valid_server_exists() {
let mut dns = crate::config::Dns::new();
dns.servers = vec![
"doh:@1.1.1.1".to_string(),
"direct:doh:example.com@not-an-ip".to_string(),
"doh:example.com#8.8.8.8".to_string(),
"1.1.1.1".to_string(),
];
let servers = DnsClient::load_servers(&dns).unwrap();
assert_eq!(servers.len(), 1);
match &servers[0] {
Resolver::Server(addr, false) => assert_eq!(
*addr,
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 53)
),
_ => panic!("unexpected resolver"),
}
}
#[test]
fn load_servers_rejects_when_all_servers_invalid() {
let mut dns = crate::config::Dns::new();
dns.servers = vec![
"doh:@1.1.1.1".to_string(),
"direct:doh:example.com@not-an-ip".to_string(),
"doh:example.com#8.8.8.8".to_string(),
];
let err = DnsClient::load_servers(&dns).unwrap_err();
assert!(err.to_string().contains("no dns servers"));
}
#[test]
fn collect_servers_includes_direct_doh_for_direct_outbound() {
let client = new_client(vec![
"1.1.1.1",
"doh:normal.example",
"direct:doh:direct.example@8.8.8.8",
]);
let selected = collect_server_strings(&client, true);
assert_eq!(selected, vec!["direct:doh:direct.example@8.8.8.8"]);
}
#[test]
fn collect_servers_fallback_to_normal_keeps_non_direct_doh() {
let client = new_client(vec!["doh:normal.example", "1.1.1.1", "system"]);
let selected = collect_server_strings(&client, true);
assert_eq!(
selected,
vec![
"doh:normal.example".to_string(),
"1.1.1.1:53".to_string(),
"system".to_string()
]
);
}
#[test]
fn parse_doh_http_body_supports_content_length() {
let body = b"\x01\x02\x03\x04";
let response = format!(
"HTTP/1.1 200 OK\r\nContent-Type: application/dns-message\r\nContent-Length: {}\r\n\r\n",
body.len()
);
let mut raw = response.into_bytes();
raw.extend_from_slice(body);
let parsed = DnsClient::parse_doh_http_body(&raw).unwrap();
assert_eq!(parsed, body);
}
#[test]
fn parse_doh_http_body_supports_chunked() {
let response =
b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n4\r\nABCD\r\n2\r\nEF\r\n0\r\n\r\n";
let parsed = DnsClient::parse_doh_http_body(response).unwrap();
assert_eq!(parsed, b"ABCDEF");
}
#[test]
fn parse_doh_http_body_rejects_non_200() {
let response = b"HTTP/1.1 503 Service Unavailable\r\nContent-Length: 3\r\n\r\nbad".to_vec();
let err = DnsClient::parse_doh_http_body(&response).unwrap_err();
assert!(err
.to_string()
.contains("doh server returned http status 503"));
}
#[test]
fn selector_primary_switches_after_consecutive_failures() {
let s1 = DnsClient::parse_server("1.1.1.1").unwrap();
let s2 = DnsClient::parse_server("8.8.8.8").unwrap();
let servers = vec![&s1, &s2];
let mut selector = ServerSelectorState::default();
let initial = selector.select_primary_index(&servers);
assert_eq!(initial, 0);
let key = s1.to_string();
let threshold = (*crate::option::DNS_SERVER_SWITCH_THRESHOLD).max(1);
for _ in 0..threshold {
selector.mark_failure(&key, true);
}
let selected = selector.select_primary_index(&servers);
assert_eq!(selected, 1);
}
#[test]
fn selector_prefers_lower_latency_in_fallback_order() {
let s1 = DnsClient::parse_server("1.1.1.1").unwrap();
let s2 = DnsClient::parse_server("8.8.8.8").unwrap();
let s3 = DnsClient::parse_server("9.9.9.9").unwrap();
let servers = vec![&s1, &s2, &s3];
let mut selector = ServerSelectorState::default();
selector.mark_success(&s2.to_string(), Duration::from_millis(30));
selector.mark_success(&s3.to_string(), Duration::from_millis(450));
let order = selector.fallback_indices(&servers, 0);
assert_eq!(order, vec![1, 2]);
}
#[test]
fn selector_marks_slow_server_as_degraded() {
let server = DnsClient::parse_server("1.1.1.1").unwrap();
let mut selector = ServerSelectorState::default();
let key = server.to_string();
let threshold = (*crate::option::DNS_SERVER_SWITCH_THRESHOLD).max(1);
let slow_elapsed = Duration::from_millis(*crate::option::DNS_SERVER_SLOW_RESPONSE_MS + 50);
for _ in 0..threshold {
selector.mark_success(&key, slow_elapsed);
}
assert!(selector.is_degraded(&key));
}
}

View File

@@ -0,0 +1,217 @@
#[derive(Clone, Debug)]
struct CacheEntry {
pub ips: Vec<IpAddr>,
pub deadline: Instant,
}
#[derive(Clone, Debug)]
pub struct EchCacheEntry {
pub ech_config_list: String,
pub deadline: Instant,
}
#[derive(Clone, Debug)]
struct DohResolver {
domain: String,
bootstrap_ip: Option<IpAddr>,
is_direct: bool,
}
#[derive(Clone, Debug)]
enum Resolver {
Server(SocketAddr, bool),
DoH(DohResolver),
System(bool),
}
#[derive(Clone, Debug, Default)]
struct ServerRuntimeStats {
avg_latency_ms: f64,
samples: u64,
successes: u64,
failures: u64,
timeouts: u64,
consecutive_slow: u32,
consecutive_failures: u32,
}
#[derive(Clone, Debug, Default)]
struct ServerSelectorState {
primary_server: Option<String>,
stats: HashMap<String, ServerRuntimeStats>,
last_reselect_at: Option<Instant>,
}
impl fmt::Display for Resolver {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Server(addr, direct) => {
if *direct {
write!(f, "direct:{}", addr)
} else {
write!(f, "{}", addr)
}
}
Self::DoH(doh) => {
if doh.is_direct {
write!(f, "direct:doh:{}", doh.domain)?;
} else {
write!(f, "doh:{}", doh.domain)?;
}
if let Some(ip) = doh.bootstrap_ip {
write!(f, "@{}", ip)?;
}
Ok(())
}
Self::System(direct) => {
if *direct {
write!(f, "direct:system")
} else {
write!(f, "system")
}
}
}
}
}
impl ServerSelectorState {
fn score_of(&self, server: &str) -> f64 {
if let Some(stat) = self.stats.get(server) {
let baseline = if stat.samples == 0 {
(*option::DNS_SERVER_SLOW_RESPONSE_MS as f64) / 2.0
} else {
stat.avg_latency_ms
};
baseline
+ (stat.failures as f64 * 600.0)
+ (stat.timeouts as f64 * 900.0)
+ (stat.consecutive_failures as f64 * 1200.0)
+ (stat.consecutive_slow as f64 * 300.0)
} else {
(*option::DNS_SERVER_SLOW_RESPONSE_MS as f64) / 2.0
}
}
fn is_degraded(&self, server: &str) -> bool {
let switch_threshold = (*option::DNS_SERVER_SWITCH_THRESHOLD).max(1);
if let Some(stat) = self.stats.get(server) {
(stat.consecutive_failures as usize) >= switch_threshold
|| (stat.consecutive_slow as usize) >= switch_threshold
} else {
false
}
}
fn ensure_candidates(&mut self, servers: &[&Resolver]) {
for server in servers {
self.stats.entry(server.to_string()).or_default();
}
}
fn select_primary_index(&mut self, servers: &[&Resolver]) -> usize {
if servers.len() <= 1 {
if let Some(server) = servers.first() {
self.primary_server = Some(server.to_string());
}
return 0;
}
self.ensure_candidates(servers);
let now = Instant::now();
let reselect_interval =
Duration::from_secs((*option::DNS_SERVER_RESELECT_INTERVAL_SECS).max(1));
let should_reselect = self
.last_reselect_at
.map(|last| now.saturating_duration_since(last) >= reselect_interval)
.unwrap_or(true);
let current_idx = self.primary_server.as_ref().and_then(|primary| {
servers
.iter()
.position(|server| server.to_string() == *primary)
});
if let Some(idx) = current_idx {
let current_key = servers[idx].to_string();
if !should_reselect && !self.is_degraded(&current_key) {
return idx;
}
}
let mut best_idx = 0usize;
let mut best_score = f64::MAX;
for (idx, server) in servers.iter().enumerate() {
let score = self.score_of(&server.to_string());
if score < best_score {
best_score = score;
best_idx = idx;
}
}
self.primary_server = Some(servers[best_idx].to_string());
self.last_reselect_at = Some(now);
best_idx
}
fn fallback_indices(&self, servers: &[&Resolver], preferred_idx: usize) -> Vec<usize> {
let mut candidates: Vec<usize> = (0..servers.len())
.filter(|idx| *idx != preferred_idx)
.collect();
candidates.sort_by(|a, b| {
let sa = self.score_of(&servers[*a].to_string());
let sb = self.score_of(&servers[*b].to_string());
sa.partial_cmp(&sb).unwrap_or(std::cmp::Ordering::Equal)
});
candidates
}
fn mark_success(&mut self, server: &str, elapsed: Duration) {
let stat = self.stats.entry(server.to_owned()).or_default();
let elapsed_ms = elapsed.as_millis() as f64;
stat.successes = stat.successes.saturating_add(1);
stat.samples = stat.samples.saturating_add(1);
if stat.samples == 1 {
stat.avg_latency_ms = elapsed_ms;
} else {
stat.avg_latency_ms = stat.avg_latency_ms * 0.8 + elapsed_ms * 0.2;
}
let slow_threshold = (*option::DNS_SERVER_SLOW_RESPONSE_MS).max(1) as f64;
if elapsed_ms >= slow_threshold {
stat.consecutive_slow = stat.consecutive_slow.saturating_add(1);
} else {
stat.consecutive_slow = 0;
}
stat.consecutive_failures = 0;
if self.primary_server.is_none() {
self.primary_server = Some(server.to_owned());
}
}
fn mark_failure(&mut self, server: &str, is_timeout: bool) {
let stat = self.stats.entry(server.to_owned()).or_default();
stat.failures = stat.failures.saturating_add(1);
if is_timeout {
stat.timeouts = stat.timeouts.saturating_add(1);
}
stat.consecutive_failures = stat.consecutive_failures.saturating_add(1);
let switch_threshold = (*option::DNS_SERVER_SWITCH_THRESHOLD).max(1);
if self.primary_server.as_deref() == Some(server)
&& (stat.consecutive_failures as usize) >= switch_threshold
{
self.primary_server = None;
}
}
fn set_primary(&mut self, server: &str) {
self.primary_server = Some(server.to_owned());
self.last_reselect_at = Some(Instant::now());
}
}
pub struct DnsClient {
dispatcher: Option<Weak<Dispatcher>>,
servers: Vec<Resolver>,
hosts: HashMap<String, Vec<IpAddr>>,
ipv4_cache: Arc<TokioMutex<LruCache<String, CacheEntry>>>,
ipv6_cache: Arc<TokioMutex<LruCache<String, CacheEntry>>>,
ech_cache: Arc<TokioMutex<LruCache<String, EchCacheEntry>>>,
ech_query_locks: Arc<TokioMutex<HashMap<String, Arc<TokioMutex<()>>>>>,
selector_state: Arc<Mutex<ServerSelectorState>>,
}

3
leaf/src/app/dns/mod.rs Normal file
View File

@@ -0,0 +1,3 @@
mod client;
pub use client::*;

View File

@@ -3,7 +3,7 @@ use std::sync::Arc;
use tokio::sync::RwLock;
pub mod dispatcher;
pub mod dns_client;
pub mod dns;
pub mod healthcheck;
pub mod inbound;
pub mod logger;
@@ -17,6 +17,10 @@ pub mod api;
pub mod fake_dns;
pub type SyncDnsClient = Arc<RwLock<dns_client::DnsClient>>;
pub mod dns_client {
pub use super::dns::*;
}
pub type SyncDnsClient = Arc<RwLock<dns::DnsClient>>;
pub type SyncStatManager = Arc<RwLock<stat_manager::StatManager>>;

View File

@@ -18,7 +18,7 @@ use notify::{
};
use app::{
dispatcher::Dispatcher, dns_client::DnsClient, inbound::manager::InboundManager,
dispatcher::Dispatcher, dns::DnsClient, inbound::manager::InboundManager,
nat_manager::NatManager, outbound::manager::OutboundManager, router::Router,
};

View File

@@ -653,7 +653,7 @@ mod tests {
use protobuf::MessageField;
use tokio::sync::RwLock;
use crate::app::{dns_client::DnsClient, SyncDnsClient};
use crate::app::{dns::DnsClient, SyncDnsClient};
#[cfg(feature = "rustls-tls-aws-lc")]
use crate::session::Session;

View File

@@ -10,7 +10,7 @@ use tokio::sync::RwLock;
use tokio::time::timeout;
use crate::{
app::{dns_client::DnsClient, outbound::manager::OutboundManager, SyncDnsClient},
app::{dns::DnsClient, outbound::manager::OutboundManager, SyncDnsClient},
config::Config,
proxy::*,
session::*,