Resolve direct domains with direct DNS servers

This commit is contained in:
eric
2026-03-01 09:57:37 +08:00
parent d82a54e34e
commit a911e542f6
7 changed files with 256 additions and 114 deletions

View File

@@ -129,8 +129,8 @@ fn log_request(sess: &Session, outbound_tag: &str, handshake_time: Option<u128>)
}
pub struct Dispatcher {
outbound_manager: Arc<RwLock<OutboundManager>>,
router: Arc<RwLock<Router>>,
pub(crate) outbound_manager: Arc<RwLock<OutboundManager>>,
pub(crate) router: Arc<RwLock<Router>>,
dns_client: SyncDnsClient,
stat_manager: SyncStatManager,
dns_sniffer: DnsSniffer,
@@ -448,4 +448,12 @@ impl Dispatcher {
}
}
}
pub async fn is_direct_outbound(&self, tag: &str) -> bool {
if let Some(h) = self.outbound_manager.read().await.get(tag) {
h.is_direct()
} else {
false
}
}
}

View File

@@ -7,6 +7,7 @@ use std::sync::{Arc, Weak};
use std::time::{Duration, Instant};
use anyhow::{anyhow, Result};
use async_recursion::async_recursion;
use futures::future::select_ok;
use hickory_proto::{
op::{
@@ -31,20 +32,28 @@ struct CacheEntry {
#[derive(Debug)]
enum Resolver {
Server(SocketAddr),
System,
Server(SocketAddr, bool),
System(bool),
}
impl fmt::Display for Resolver {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"{}",
match self {
Self::Server(addr) => addr.to_string(),
Self::System => "system".to_string(),
match self {
Self::Server(addr, direct) => {
if *direct {
write!(f, "direct:{}", addr)
} else {
write!(f, "{}", addr)
}
}
)
Self::System(direct) => {
if *direct {
write!(f, "direct:system")
} else {
write!(f, "system")
}
}
}
}
}
@@ -60,15 +69,23 @@ impl DnsClient {
fn load_servers(dns: &crate::config::Dns) -> Result<Vec<Resolver>> {
let mut servers = Vec::new();
for server in dns.servers.iter() {
if server.to_lowercase() == "system" {
servers.push(Resolver::System);
let (server, is_direct) = if server.to_lowercase().starts_with("direct:") {
(&server[7..], true)
} else {
servers.push(Resolver::Server(SocketAddr::new(
server.parse::<IpAddr>()?,
53,
)));
(server.as_str(), false)
};
if server.to_lowercase() == "system" {
servers.push(Resolver::System(is_direct));
} else {
servers.push(Resolver::Server(
SocketAddr::new(server.parse::<IpAddr>()?, 53),
is_direct,
));
}
}
for server in &servers {
debug!("loaded dns server: {}", server);
}
if servers.is_empty() {
return Err(anyhow!("no dns servers"));
}
@@ -196,67 +213,31 @@ impl DnsClient {
}
}
async fn resolve_with_server(
async fn query_with_socket(
&self,
is_direct: bool,
socket: Box<dyn OutboundDatagram>,
request: Vec<u8>,
span: tracing::Span,
host: &str,
server: &SocketAddr,
resolver: &Resolver,
) -> Result<CacheEntry> {
let (socket, span) = if is_direct {
debug!("direct lookup");
let socket = self.new_udp_socket(server).await?;
(
Box::new(StdOutboundDatagram::new(socket)) as Box<dyn OutboundDatagram>,
tracing::Span::current(),
)
} else {
debug!("dispatched lookup");
if let Some(dispatcher_weak) = self.dispatcher.as_ref() {
// The source address will be used to determine which address the
// underlying socket will bind.
let source = match server {
SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
SocketAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
};
let sess = Session {
network: Network::Udp,
source,
destination: SocksAddr::from(server),
inbound_tag: "dnsclient".to_string(),
..Default::default()
};
let span = sess.span();
if let Some(dispatcher) = dispatcher_weak.upgrade() {
(
dispatcher
.dispatch_datagram(sess)
.instrument(span.clone())
.await?,
span,
)
} else {
return Err(anyhow!("dispatcher is deallocated"));
}
} else {
return Err(anyhow!("could not find a dispatcher"));
}
let resolver_addr = match resolver {
Resolver::Server(addr, _) => SocksAddr::from(*addr),
_ => SocksAddr::any_ipv4(),
};
async move {
let (mut r, mut s) = socket.split();
let server = SocksAddr::from(server);
for i in 0..*option::MAX_DNS_RETRIES {
debug!(
"looking up host={} server={} ({}/{})",
host,
server,
resolver,
i + 1,
*option::MAX_DNS_RETRIES
);
let start = tokio::time::Instant::now();
if let Err(err) = s.send_to(&request, &server).await {
if let Err(err) = s.send_to(&request, &resolver_addr).await {
debug!("send DNS query failed: {}", err);
continue;
}
@@ -270,11 +251,11 @@ impl DnsClient {
{
Ok(Ok((n, _))) => n,
Ok(Err(e)) => {
debug!("recv DNS response from {} failed: {}", server, e);
debug!("recv DNS response from {} failed: {}", resolver, e);
continue;
}
Err(e) => {
debug!("recv DNS response from {} failed: {}", server, e);
debug!("recv DNS response from {} failed: {}", resolver, e);
continue;
}
};
@@ -282,7 +263,7 @@ impl DnsClient {
let resp = match Message::from_vec(&buf[..n]) {
Ok(resp) => resp,
Err(err) => {
debug!("parse DNS message from {} failed: {}", server, err);
debug!("parse DNS message from {} failed: {}", resolver, err);
break;
}
};
@@ -290,7 +271,7 @@ impl DnsClient {
if resp.response_code() != ResponseCode::NoError {
debug!(
"error DNS response from {} for {}: {}",
server,
resolver,
host,
resp.response_code()
);
@@ -299,7 +280,6 @@ impl DnsClient {
let mut ips = Vec::new();
for ans in resp.answers() {
// TODO checks?
if let Some(data) = ans.data() {
match data {
RData::A(ip) => {
@@ -314,10 +294,7 @@ impl DnsClient {
}
if ips.is_empty() {
// response with 0 records
//
// TODO Not sure how to due with this.
debug!("no records in DNS response from {} for {}", server, host);
debug!("no records in DNS response from {} for {}", resolver, host);
break;
}
@@ -325,7 +302,7 @@ impl DnsClient {
let ttl = resp.answers().iter().next().unwrap().ttl();
debug!(
"received from server={} ttl={} elapsed={}ms ips={:?}",
server,
resolver,
ttl,
elapsed.as_millis(),
&ips,
@@ -346,40 +323,87 @@ impl DnsClient {
.await
}
async fn resolve_with_system_resolver(&self, host: &str, ty: RecordType) -> Result<CacheEntry> {
debug!("resolving {} using system resolver", host);
use std::net::ToSocketAddrs;
let addr = format!("{}:0", host);
let start = std::time::Instant::now();
let ips = tokio::task::spawn_blocking(move || addr.to_socket_addrs())
async fn resolve_with_server(
&self,
is_direct: bool,
request: Vec<u8>,
host: &str,
resolver: &Resolver,
) -> Result<CacheEntry> {
let (socket, span) = match resolver {
Resolver::Server(server, _) if is_direct => {
debug!("direct lookup");
let socket = self.new_udp_socket(server).await?;
(
Box::new(StdOutboundDatagram::new(socket)) as Box<dyn OutboundDatagram>,
tracing::Span::current(),
)
}
Resolver::Server(server, _) => {
debug!("dispatched lookup");
if let Some(dispatcher_weak) = self.dispatcher.as_ref() {
// The source address will be used to determine which address the
// underlying socket will bind.
let source = match server {
SocketAddr::V4(_) => SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
SocketAddr::V6(_) => SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0),
};
let sess = Session {
network: Network::Udp,
source,
destination: SocksAddr::from(server),
inbound_tag: "dnsclient".to_string(),
..Default::default()
};
let span = sess.span();
if let Some(dispatcher) = dispatcher_weak.upgrade() {
(
dispatcher
.dispatch_datagram(sess)
.instrument(span.clone())
.await?,
span,
)
} else {
return Err(anyhow!("dispatcher is gone"));
}
} else {
return Err(anyhow!("no dispatcher"));
}
}
Resolver::System(_) => {
debug!("resolving {} using system resolver", host);
use std::net::ToSocketAddrs;
let addr = format!("{}:0", host);
let start = std::time::Instant::now();
let ips = tokio::task::spawn_blocking(move || {
addr.to_socket_addrs()
.map(|iter| iter.map(|x| x.ip()).collect::<Vec<_>>())
})
.await
.map_err(|e| anyhow!("spawn blocking failed: {}", e))?
.map_err(|e| anyhow!("system resolver failed: {}", e))?;
debug!(
"resolved ips={:?} for domain={} from system resolver in {} ms",
&ips,
host,
start.elapsed().as_millis(),
);
if ips.is_empty() {
return Err(anyhow!("no records from system resolver"));
}
return Ok(CacheEntry {
ips,
deadline: Instant::now() + Duration::from_secs(60),
});
}
};
self.query_with_socket(socket, request, span, host, resolver)
.await
.map_err(|e| anyhow!("spawn blocking failed: {}", e))?
.map_err(|e| anyhow!("system resolver failed: {}", e))?
.map(|x| x.ip())
.filter(|ip| match ty {
RecordType::A => ip.is_ipv4(),
RecordType::AAAA => ip.is_ipv6(),
_ => true,
})
.collect::<Vec<_>>();
debug!(
"resolved ips={:?} for domain={} from system resolver in {} ms",
&ips,
host,
start.elapsed().as_millis(),
);
trace!("ips for {}:\n{:?}", host, &ips);
if ips.is_empty() {
return Err(anyhow!("no records"));
}
// System resolver result should be considered valid for some time,
// but we don't have TTL. Using 60s as a fallback.
let deadline = std::time::Instant::now() + Duration::from_secs(60);
Ok(CacheEntry { ips, deadline })
}
async fn query_task(
@@ -387,18 +411,34 @@ impl DnsClient {
is_direct: bool,
request: Vec<u8>,
host: &str,
server: &Resolver,
resolver: &Resolver,
ty: RecordType,
) -> Result<CacheEntry> {
match server {
Resolver::System => self.resolve_with_system_resolver(host, ty).await,
Resolver::Server(addr) => {
self.resolve_with_server(is_direct, request, host, addr)
.await
let res = match timeout(
Duration::from_secs(*option::DNS_TIMEOUT),
self.resolve_with_server(is_direct, request, host, resolver),
)
.await
{
Ok(res) => res,
Err(_) => Err(anyhow!("query {} {} timeout", host, ty)),
};
match res {
Ok(entry) => {
trace!("query {} {} success with server {}", host, ty, resolver);
Ok(entry)
}
Err(e) => {
debug!(
"query {} {} failed with server {}: {}",
host, ty, resolver, e
);
Err(e)
}
}
}
#[async_recursion]
async fn query_record_type(
&self,
is_direct: bool,
@@ -411,11 +451,75 @@ impl DnsClient {
Ok(b) => b,
Err(e) => return Err(anyhow!("encode message to buffer failed: {}", e)),
};
let mut is_direct_outbound = false;
if let Some(dispatcher_weak) = self.dispatcher.as_ref() {
if let Some(dispatcher) = dispatcher_weak.upgrade() {
let dest = match SocksAddr::try_from((host.to_owned(), 0)) {
Ok(d) => d,
Err(e) => return Err(anyhow!("invalid host {}: {}", host, e)),
};
let sess = Session {
destination: dest,
skip_resolve: true,
..Default::default()
};
if let Ok(Some(tag)) = dispatcher.router.read().await.pick_route(&sess).await {
is_direct_outbound = dispatcher.is_direct_outbound(tag).await;
}
}
}
let mut tasks = Vec::new();
for server in &self.servers {
let mut servers = Vec::new();
if is_direct_outbound {
for server in &self.servers {
match server {
Resolver::Server(_, true) | Resolver::System(true) => {
servers.push(server);
}
_ => (),
}
}
if servers.is_empty() {
debug!("no direct dns servers for direct outbound, fallback to normal servers");
for server in &self.servers {
match server {
Resolver::Server(_, false) | Resolver::System(false) => {
servers.push(server);
}
_ => (),
}
}
}
} else {
for server in &self.servers {
match server {
Resolver::Server(_, false) | Resolver::System(false) => {
servers.push(server);
}
_ => (),
}
}
}
if servers.is_empty() {
// If still empty, use all servers as a last resort
for server in &self.servers {
servers.push(server);
}
}
for server in servers {
let t = self.query_task(is_direct, msg_buf.clone(), host, server, ty);
tasks.push(Box::pin(t));
}
if tasks.is_empty() {
return Err(anyhow!("no dns servers available for query"));
}
let (entry, _) = select_ok(tasks.into_iter()).await?;
Ok(entry)
}
@@ -538,6 +642,7 @@ impl DnsClient {
self._lookup(host, true).await
}
#[async_recursion]
pub async fn _lookup(&self, host: &String, is_direct: bool) -> Result<Vec<IpAddr>> {
self._lookup_inner(host, is_direct).await
}

View File

@@ -120,6 +120,7 @@ impl OutboundManager {
.tag(tag.clone())
.stream_handler(Arc::new(direct::StreamHandler))
.datagram_handler(Arc::new(direct::DatagramHandler))
.is_direct(true)
.build(),
#[cfg(feature = "outbound-drop")]
"drop" => HandlerBuilder::default()

View File

@@ -3,6 +3,7 @@ use std::sync::Arc;
use anyhow::anyhow;
use anyhow::Result;
use async_recursion::async_recursion;
use cidr::IpCidr;
use futures::TryFutureExt;
use maxminddb::geoip2::Country;
@@ -563,6 +564,7 @@ impl Router {
Ok(())
}
#[async_recursion]
pub async fn pick_route<'a>(&'a self, sess: &'a Session) -> Result<Option<&'a String>> {
let effective_dest = &sess.destination;
for rule in &self.rules {
@@ -570,7 +572,7 @@ impl Router {
return Ok(Some(&rule.target));
}
}
if effective_dest.is_domain() && self.domain_resolve {
if effective_dest.is_domain() && self.domain_resolve && !sess.skip_resolve {
debug!("resolve routing domain={:?}", effective_dest.domain());
let ips = {
self.dns_client

View File

@@ -555,6 +555,9 @@ pub trait BaseHandler: Tag + Send + Sync + Unpin {}
pub trait OutboundHandler: BaseHandler {
fn stream(&self) -> io::Result<&AnyOutboundStreamHandler>;
fn datagram(&self) -> io::Result<&AnyOutboundDatagramHandler>;
fn is_direct(&self) -> bool {
false
}
}
pub type AnyOutboundHandler = Arc<dyn OutboundHandler>;

View File

@@ -9,6 +9,7 @@ pub struct Handler {
tag: String,
stream_handler: Option<AnyOutboundStreamHandler>,
datagram_handler: Option<AnyOutboundDatagramHandler>,
is_direct: bool,
}
impl Handler {
@@ -16,11 +17,13 @@ impl Handler {
tag: String,
stream_handler: Option<AnyOutboundStreamHandler>,
datagram_handler: Option<AnyOutboundDatagramHandler>,
is_direct: bool,
) -> Arc<Self> {
Arc::new(Handler {
tag,
stream_handler,
datagram_handler,
is_direct,
})
}
}
@@ -39,6 +42,10 @@ impl OutboundHandler for Handler {
.as_ref()
.ok_or_else(|| io::Error::other("no udp handler"))
}
fn is_direct(&self) -> bool {
self.is_direct
}
}
impl Tag for Handler {
@@ -51,6 +58,7 @@ pub struct HandlerBuilder {
tag: String,
stream_handler: Option<AnyOutboundStreamHandler>,
datagram_handler: Option<AnyOutboundDatagramHandler>,
is_direct: bool,
}
impl HandlerBuilder {
@@ -59,6 +67,7 @@ impl HandlerBuilder {
tag: "".to_string(),
stream_handler: None,
datagram_handler: None,
is_direct: false,
}
}
@@ -77,8 +86,18 @@ impl HandlerBuilder {
self
}
pub fn is_direct(mut self, v: bool) -> Self {
self.is_direct = v;
self
}
pub fn build(self) -> AnyOutboundHandler {
Handler::new(self.tag, self.stream_handler, self.datagram_handler)
Handler::new(
self.tag,
self.stream_handler,
self.datagram_handler,
self.is_direct,
)
}
}

View File

@@ -110,6 +110,8 @@ pub struct Session {
pub dns_sniffed_domain: Option<String>,
/// Shared state to coordinate XTLS vision read raw mode.
pub vision_read_raw: std::sync::Arc<std::sync::atomic::AtomicBool>,
/// Skip domain resolution during routing.
pub skip_resolve: bool,
}
impl Clone for Session {
@@ -130,6 +132,7 @@ impl Clone for Session {
http_sniffed_domain: self.http_sniffed_domain.clone(),
dns_sniffed_domain: self.dns_sniffed_domain.clone(),
vision_read_raw: self.vision_read_raw.clone(),
skip_resolve: self.skip_resolve,
}
}
}
@@ -152,6 +155,7 @@ impl Default for Session {
http_sniffed_domain: None,
dns_sniffed_domain: None,
vision_read_raw: std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)),
skip_resolve: false,
}
}
}