Dispatch internal DNS queries

This commit is contained in:
eric
2022-09-17 13:06:52 +08:00
parent 792c59c2db
commit 56b3464a07
8 changed files with 133 additions and 33 deletions

View File

@@ -135,6 +135,7 @@ socket2 = "0.4"
directories = "4.0"
async-ffi = "0.2"
libloading = "0.7"
async-recursion = "1.0"
# config-json
serde_json = { version = "1.0", features = ["raw_value"], optional = true }

View File

@@ -3,6 +3,7 @@ use std::io::{self, ErrorKind};
use std::sync::Arc;
use std::time::Duration;
use async_recursion::async_recursion;
use log::*;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::sync::RwLock;
@@ -261,6 +262,7 @@ impl Dispatcher {
}
}
#[async_recursion]
pub async fn dispatch_datagram(
&self,
mut sess: Session,

View File

@@ -1,7 +1,7 @@
use std::collections::HashMap;
use std::net::{IpAddr, SocketAddr};
use std::str::FromStr;
use std::sync::Arc;
use std::sync::{Arc, Weak};
use std::time::{Duration, Instant};
use anyhow::{anyhow, Result};
@@ -18,7 +18,7 @@ use trust_dns_proto::{
rr::{record_data::RData, record_type::RecordType, Name},
};
use crate::{option, proxy::UdpConnector};
use crate::{app::dispatcher::Dispatcher, option, proxy::*, session::*};
#[derive(Clone, Debug)]
struct CacheEntry {
@@ -28,6 +28,7 @@ struct CacheEntry {
}
pub struct DnsClient {
dispatcher: Option<Weak<Dispatcher>>,
servers: Vec<SocketAddr>,
hosts: HashMap<String, Vec<IpAddr>>,
ipv4_cache: Arc<TokioMutex<LruCache<String, CacheEntry>>>,
@@ -79,7 +80,8 @@ impl DnsClient {
*option::DNS_CACHE_SIZE,
)));
Ok(DnsClient {
Ok(Self {
dispatcher: None,
servers,
hosts,
ipv4_cache,
@@ -87,6 +89,10 @@ impl DnsClient {
})
}
pub fn replace_dispatcher(&mut self, dispatcher: Weak<Dispatcher>) {
self.dispatcher.replace(dispatcher);
}
pub fn reload(&mut self, dns: &protobuf::SingularPtrField<crate::config::Dns>) -> Result<()> {
let dns = if let Some(dns) = dns.as_ref() {
dns
@@ -164,21 +170,42 @@ impl DnsClient {
async fn query_task(
&self,
is_direct: bool,
request: Vec<u8>,
host: &str,
server: &SocketAddr,
) -> Result<CacheEntry> {
let socket = self.new_udp_socket(server).await?;
let socket = if is_direct {
let socket = self.new_udp_socket(server).await?;
Box::new(StdOutboundDatagram::new(socket))
} else {
if let Some(dispatcher_weak) = self.dispatcher.as_ref() {
let sess = Session {
network: Network::Udp,
destination: SocksAddr::from(server),
..Default::default()
};
if let Some(dispatcher) = dispatcher_weak.upgrade() {
dispatcher.dispatch_datagram(sess).await?
} else {
return Err(anyhow!("dispatcher is deallocated"));
}
} else {
return Err(anyhow!("could not find a dispatcher"));
}
};
let (mut r, mut s) = socket.split();
let server = SocksAddr::from(server);
let mut last_err = None;
for _i in 0..*option::MAX_DNS_RETRIES {
debug!("looking up host {} on {}", host, server);
let start = tokio::time::Instant::now();
match socket.send_to(&request, server).await {
match s.send_to(&request, &server).await {
Ok(_) => {
let mut buf = vec![0u8; 512];
match timeout(
Duration::from_secs(*option::DNS_TIMEOUT),
socket.recv_from(&mut buf),
r.recv_from(&mut buf),
)
.await
{
@@ -362,6 +389,14 @@ impl DnsClient {
}
pub async fn lookup(&self, host: &String) -> Result<Vec<IpAddr>> {
self._lookup(host, false).await
}
pub async fn direct_lookup(&self, host: &String) -> Result<Vec<IpAddr>> {
self._lookup(host, true).await
}
pub async fn _lookup(&self, host: &String, is_direct: bool) -> Result<Vec<IpAddr>> {
if let Ok(ip) = host.parse::<IpAddr>() {
return Ok(vec![ip]);
}
@@ -413,7 +448,7 @@ impl DnsClient {
};
let mut tasks = Vec::new();
for server in &self.servers {
let t = self.query_task(msg_buf.clone(), host, server);
let t = self.query_task(is_direct, msg_buf.clone(), host, server);
tasks.push(Box::pin(t));
}
let query_task = select_ok(tasks.into_iter());
@@ -426,7 +461,7 @@ impl DnsClient {
};
let mut tasks = Vec::new();
for server in &self.servers {
let t = self.query_task(msg_buf.clone(), host, server);
let t = self.query_task(is_direct, msg_buf.clone(), host, server);
tasks.push(Box::pin(t));
}
let query_task = select_ok(tasks.into_iter());
@@ -440,7 +475,7 @@ impl DnsClient {
};
let mut tasks = Vec::new();
for server in &self.servers {
let t = self.query_task(msg_buf.clone(), host, server);
let t = self.query_task(is_direct, msg_buf.clone(), host, server);
tasks.push(Box::pin(t));
}
let query_task = select_ok(tasks.into_iter());
@@ -453,7 +488,7 @@ impl DnsClient {
};
let mut tasks = Vec::new();
for server in &self.servers {
let t = self.query_task(msg_buf.clone(), host, server);
let t = self.query_task(is_direct, msg_buf.clone(), host, server);
tasks.push(Box::pin(t));
}
let query_task = select_ok(tasks.into_iter());
@@ -467,7 +502,7 @@ impl DnsClient {
};
let mut tasks = Vec::new();
for server in &self.servers {
let t = self.query_task(msg_buf.clone(), host, server);
let t = self.query_task(is_direct, msg_buf.clone(), host, server);
tasks.push(Box::pin(t));
}
let query_task = select_ok(tasks.into_iter());

View File

@@ -23,7 +23,7 @@ impl Resolver {
dns_client
.read()
.await
.lookup(address)
.direct_lookup(address)
.map_err(|e| anyhow!("lookup {} failed: {}", address, e))
.await?
};

View File

@@ -415,6 +415,16 @@ pub fn start(rt_id: RuntimeId, opts: StartOptions) -> Result<(), Error> {
#[cfg(feature = "stat")]
stat_manager.clone(),
));
let dispatcher_weak = Arc::downgrade(&dispatcher);
let dns_client_cloned = dns_client.clone();
rt.block_on(async move {
dns_client_cloned
.write()
.await
.replace_dispatcher(dispatcher_weak);
});
let nat_manager = Arc::new(NatManager::new(dispatcher.clone()));
let inbound_manager =
InboundManager::new(&config.inbounds, dispatcher, nat_manager).map_err(Error::Config)?;
@@ -548,13 +558,15 @@ pub fn start(rt_id: RuntimeId, opts: StartOptions) -> Result<(), Error> {
#[cfg(all(feature = "inbound-tun", any(target_os = "macos", target_os = "linux")))]
sys::post_tun_completion_setup(&net_info);
rt.shutdown_background();
drop(inbound_manager);
RUNTIME_MANAGER
.lock()
.map_err(|_| Error::RuntimeManager)?
.remove(&rt_id);
rt.shutdown_background();
log::trace!("removed runtime {}", &rt_id);
Ok(())

View File

@@ -15,6 +15,59 @@ use crate::{
use super::*;
/// An outbound datagram wraps a normal UDP socket and used as a normal UDP socket.
pub struct StdOutboundDatagram {
inner: UdpSocket,
}
impl StdOutboundDatagram {
pub fn new(inner: UdpSocket) -> Self {
Self { inner }
}
}
impl OutboundDatagram for StdOutboundDatagram {
fn split(
self: Box<Self>,
) -> (
Box<dyn OutboundDatagramRecvHalf>,
Box<dyn OutboundDatagramSendHalf>,
) {
let r = Arc::new(self.inner);
let s = r.clone();
(
Box::new(StdOutboundDatagramRecvHalf(r)),
Box::new(StdOutboundDatagramSendHalf(s)),
)
}
}
pub struct StdOutboundDatagramRecvHalf(Arc<UdpSocket>);
#[async_trait]
impl OutboundDatagramRecvHalf for StdOutboundDatagramRecvHalf {
async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocksAddr)> {
match self.0.recv_from(buf).await {
Ok((n, a)) => Ok((n, SocksAddr::Ip(unmapped_ipv4(a)))),
Err(e) => Err(e),
}
}
}
pub struct StdOutboundDatagramSendHalf(Arc<UdpSocket>);
#[async_trait]
impl OutboundDatagramSendHalf for StdOutboundDatagramSendHalf {
async fn send_to(&mut self, buf: &[u8], target: &SocksAddr) -> io::Result<usize> {
// The type does not accept domain name.
self.0.send_to(buf, target.must_ip()).await
}
async fn close(&mut self) -> io::Result<()> {
Ok(())
}
}
/// An outbound datagram simply wraps a UDP socket.
pub struct SimpleOutboundDatagram {
inner: UdpSocket,

View File

@@ -84,6 +84,7 @@ pub mod ws;
pub use datagram::{
SimpleInboundDatagram, SimpleInboundDatagramRecvHalf, SimpleInboundDatagramSendHalf,
SimpleOutboundDatagram, SimpleOutboundDatagramRecvHalf, SimpleOutboundDatagramSendHalf,
StdOutboundDatagram,
};
#[derive(Error, Debug)]
@@ -211,24 +212,20 @@ async fn bind_socket<T: BindSocket>(socket: &T, indicator: &SocketAddr) -> io::R
}
let ret = match indicator {
SocketAddr::V4(..) => {
libc::setsockopt(
socket.as_raw_fd(),
libc::IPPROTO_IP,
libc::IP_BOUND_IF,
&ifidx as *const _ as *const libc::c_void,
std::mem::size_of::<libc::c_uint>() as libc::socklen_t,
)
}
SocketAddr::V6(..) => {
libc::setsockopt(
socket.as_raw_fd(),
libc::IPPROTO_IPV6,
libc::IPV6_BOUND_IF,
&ifidx as *const _ as *const libc::c_void,
std::mem::size_of::<libc::c_uint>() as libc::socklen_t,
)
}
SocketAddr::V4(..) => libc::setsockopt(
socket.as_raw_fd(),
libc::IPPROTO_IP,
libc::IP_BOUND_IF,
&ifidx as *const _ as *const libc::c_void,
std::mem::size_of::<libc::c_uint>() as libc::socklen_t,
),
SocketAddr::V6(..) => libc::setsockopt(
socket.as_raw_fd(),
libc::IPPROTO_IPV6,
libc::IPV6_BOUND_IF,
&ifidx as *const _ as *const libc::c_void,
std::mem::size_of::<libc::c_uint>() as libc::socklen_t,
),
};
if ret == -1 {
last_err = Some(io::Error::last_os_error());

View File

@@ -154,9 +154,9 @@ impl SocksAddr {
Self::Ip("[::]:0".parse().unwrap())
}
pub fn must_ip(self) -> SocketAddr {
pub fn must_ip(&self) -> &SocketAddr {
match self {
SocksAddr::Ip(a) => a,
SocksAddr::Ip(ref a) => a,
_ => {
panic!("assert SocksAddr as SocketAddr failed");
}