Dispatch internal DNS queries
This commit is contained in:
@@ -135,6 +135,7 @@ socket2 = "0.4"
|
||||
directories = "4.0"
|
||||
async-ffi = "0.2"
|
||||
libloading = "0.7"
|
||||
async-recursion = "1.0"
|
||||
|
||||
# config-json
|
||||
serde_json = { version = "1.0", features = ["raw_value"], optional = true }
|
||||
|
||||
@@ -3,6 +3,7 @@ use std::io::{self, ErrorKind};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use async_recursion::async_recursion;
|
||||
use log::*;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::sync::RwLock;
|
||||
@@ -261,6 +262,7 @@ impl Dispatcher {
|
||||
}
|
||||
}
|
||||
|
||||
#[async_recursion]
|
||||
pub async fn dispatch_datagram(
|
||||
&self,
|
||||
mut sess: Session,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::collections::HashMap;
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, Weak};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
@@ -18,7 +18,7 @@ use trust_dns_proto::{
|
||||
rr::{record_data::RData, record_type::RecordType, Name},
|
||||
};
|
||||
|
||||
use crate::{option, proxy::UdpConnector};
|
||||
use crate::{app::dispatcher::Dispatcher, option, proxy::*, session::*};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct CacheEntry {
|
||||
@@ -28,6 +28,7 @@ struct CacheEntry {
|
||||
}
|
||||
|
||||
pub struct DnsClient {
|
||||
dispatcher: Option<Weak<Dispatcher>>,
|
||||
servers: Vec<SocketAddr>,
|
||||
hosts: HashMap<String, Vec<IpAddr>>,
|
||||
ipv4_cache: Arc<TokioMutex<LruCache<String, CacheEntry>>>,
|
||||
@@ -79,7 +80,8 @@ impl DnsClient {
|
||||
*option::DNS_CACHE_SIZE,
|
||||
)));
|
||||
|
||||
Ok(DnsClient {
|
||||
Ok(Self {
|
||||
dispatcher: None,
|
||||
servers,
|
||||
hosts,
|
||||
ipv4_cache,
|
||||
@@ -87,6 +89,10 @@ impl DnsClient {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn replace_dispatcher(&mut self, dispatcher: Weak<Dispatcher>) {
|
||||
self.dispatcher.replace(dispatcher);
|
||||
}
|
||||
|
||||
pub fn reload(&mut self, dns: &protobuf::SingularPtrField<crate::config::Dns>) -> Result<()> {
|
||||
let dns = if let Some(dns) = dns.as_ref() {
|
||||
dns
|
||||
@@ -164,21 +170,42 @@ impl DnsClient {
|
||||
|
||||
async fn query_task(
|
||||
&self,
|
||||
is_direct: bool,
|
||||
request: Vec<u8>,
|
||||
host: &str,
|
||||
server: &SocketAddr,
|
||||
) -> Result<CacheEntry> {
|
||||
let socket = self.new_udp_socket(server).await?;
|
||||
let socket = if is_direct {
|
||||
let socket = self.new_udp_socket(server).await?;
|
||||
Box::new(StdOutboundDatagram::new(socket))
|
||||
} else {
|
||||
if let Some(dispatcher_weak) = self.dispatcher.as_ref() {
|
||||
let sess = Session {
|
||||
network: Network::Udp,
|
||||
destination: SocksAddr::from(server),
|
||||
..Default::default()
|
||||
};
|
||||
if let Some(dispatcher) = dispatcher_weak.upgrade() {
|
||||
dispatcher.dispatch_datagram(sess).await?
|
||||
} else {
|
||||
return Err(anyhow!("dispatcher is deallocated"));
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow!("could not find a dispatcher"));
|
||||
}
|
||||
};
|
||||
let (mut r, mut s) = socket.split();
|
||||
let server = SocksAddr::from(server);
|
||||
let mut last_err = None;
|
||||
for _i in 0..*option::MAX_DNS_RETRIES {
|
||||
debug!("looking up host {} on {}", host, server);
|
||||
let start = tokio::time::Instant::now();
|
||||
match socket.send_to(&request, server).await {
|
||||
match s.send_to(&request, &server).await {
|
||||
Ok(_) => {
|
||||
let mut buf = vec![0u8; 512];
|
||||
match timeout(
|
||||
Duration::from_secs(*option::DNS_TIMEOUT),
|
||||
socket.recv_from(&mut buf),
|
||||
r.recv_from(&mut buf),
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -362,6 +389,14 @@ impl DnsClient {
|
||||
}
|
||||
|
||||
pub async fn lookup(&self, host: &String) -> Result<Vec<IpAddr>> {
|
||||
self._lookup(host, false).await
|
||||
}
|
||||
|
||||
pub async fn direct_lookup(&self, host: &String) -> Result<Vec<IpAddr>> {
|
||||
self._lookup(host, true).await
|
||||
}
|
||||
|
||||
pub async fn _lookup(&self, host: &String, is_direct: bool) -> Result<Vec<IpAddr>> {
|
||||
if let Ok(ip) = host.parse::<IpAddr>() {
|
||||
return Ok(vec![ip]);
|
||||
}
|
||||
@@ -413,7 +448,7 @@ impl DnsClient {
|
||||
};
|
||||
let mut tasks = Vec::new();
|
||||
for server in &self.servers {
|
||||
let t = self.query_task(msg_buf.clone(), host, server);
|
||||
let t = self.query_task(is_direct, msg_buf.clone(), host, server);
|
||||
tasks.push(Box::pin(t));
|
||||
}
|
||||
let query_task = select_ok(tasks.into_iter());
|
||||
@@ -426,7 +461,7 @@ impl DnsClient {
|
||||
};
|
||||
let mut tasks = Vec::new();
|
||||
for server in &self.servers {
|
||||
let t = self.query_task(msg_buf.clone(), host, server);
|
||||
let t = self.query_task(is_direct, msg_buf.clone(), host, server);
|
||||
tasks.push(Box::pin(t));
|
||||
}
|
||||
let query_task = select_ok(tasks.into_iter());
|
||||
@@ -440,7 +475,7 @@ impl DnsClient {
|
||||
};
|
||||
let mut tasks = Vec::new();
|
||||
for server in &self.servers {
|
||||
let t = self.query_task(msg_buf.clone(), host, server);
|
||||
let t = self.query_task(is_direct, msg_buf.clone(), host, server);
|
||||
tasks.push(Box::pin(t));
|
||||
}
|
||||
let query_task = select_ok(tasks.into_iter());
|
||||
@@ -453,7 +488,7 @@ impl DnsClient {
|
||||
};
|
||||
let mut tasks = Vec::new();
|
||||
for server in &self.servers {
|
||||
let t = self.query_task(msg_buf.clone(), host, server);
|
||||
let t = self.query_task(is_direct, msg_buf.clone(), host, server);
|
||||
tasks.push(Box::pin(t));
|
||||
}
|
||||
let query_task = select_ok(tasks.into_iter());
|
||||
@@ -467,7 +502,7 @@ impl DnsClient {
|
||||
};
|
||||
let mut tasks = Vec::new();
|
||||
for server in &self.servers {
|
||||
let t = self.query_task(msg_buf.clone(), host, server);
|
||||
let t = self.query_task(is_direct, msg_buf.clone(), host, server);
|
||||
tasks.push(Box::pin(t));
|
||||
}
|
||||
let query_task = select_ok(tasks.into_iter());
|
||||
|
||||
@@ -23,7 +23,7 @@ impl Resolver {
|
||||
dns_client
|
||||
.read()
|
||||
.await
|
||||
.lookup(address)
|
||||
.direct_lookup(address)
|
||||
.map_err(|e| anyhow!("lookup {} failed: {}", address, e))
|
||||
.await?
|
||||
};
|
||||
|
||||
@@ -415,6 +415,16 @@ pub fn start(rt_id: RuntimeId, opts: StartOptions) -> Result<(), Error> {
|
||||
#[cfg(feature = "stat")]
|
||||
stat_manager.clone(),
|
||||
));
|
||||
|
||||
let dispatcher_weak = Arc::downgrade(&dispatcher);
|
||||
let dns_client_cloned = dns_client.clone();
|
||||
rt.block_on(async move {
|
||||
dns_client_cloned
|
||||
.write()
|
||||
.await
|
||||
.replace_dispatcher(dispatcher_weak);
|
||||
});
|
||||
|
||||
let nat_manager = Arc::new(NatManager::new(dispatcher.clone()));
|
||||
let inbound_manager =
|
||||
InboundManager::new(&config.inbounds, dispatcher, nat_manager).map_err(Error::Config)?;
|
||||
@@ -548,13 +558,15 @@ pub fn start(rt_id: RuntimeId, opts: StartOptions) -> Result<(), Error> {
|
||||
#[cfg(all(feature = "inbound-tun", any(target_os = "macos", target_os = "linux")))]
|
||||
sys::post_tun_completion_setup(&net_info);
|
||||
|
||||
rt.shutdown_background();
|
||||
drop(inbound_manager);
|
||||
|
||||
RUNTIME_MANAGER
|
||||
.lock()
|
||||
.map_err(|_| Error::RuntimeManager)?
|
||||
.remove(&rt_id);
|
||||
|
||||
rt.shutdown_background();
|
||||
|
||||
log::trace!("removed runtime {}", &rt_id);
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -15,6 +15,59 @@ use crate::{
|
||||
|
||||
use super::*;
|
||||
|
||||
/// An outbound datagram wraps a normal UDP socket and used as a normal UDP socket.
|
||||
pub struct StdOutboundDatagram {
|
||||
inner: UdpSocket,
|
||||
}
|
||||
|
||||
impl StdOutboundDatagram {
|
||||
pub fn new(inner: UdpSocket) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
}
|
||||
|
||||
impl OutboundDatagram for StdOutboundDatagram {
|
||||
fn split(
|
||||
self: Box<Self>,
|
||||
) -> (
|
||||
Box<dyn OutboundDatagramRecvHalf>,
|
||||
Box<dyn OutboundDatagramSendHalf>,
|
||||
) {
|
||||
let r = Arc::new(self.inner);
|
||||
let s = r.clone();
|
||||
(
|
||||
Box::new(StdOutboundDatagramRecvHalf(r)),
|
||||
Box::new(StdOutboundDatagramSendHalf(s)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct StdOutboundDatagramRecvHalf(Arc<UdpSocket>);
|
||||
|
||||
#[async_trait]
|
||||
impl OutboundDatagramRecvHalf for StdOutboundDatagramRecvHalf {
|
||||
async fn recv_from(&mut self, buf: &mut [u8]) -> io::Result<(usize, SocksAddr)> {
|
||||
match self.0.recv_from(buf).await {
|
||||
Ok((n, a)) => Ok((n, SocksAddr::Ip(unmapped_ipv4(a)))),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct StdOutboundDatagramSendHalf(Arc<UdpSocket>);
|
||||
|
||||
#[async_trait]
|
||||
impl OutboundDatagramSendHalf for StdOutboundDatagramSendHalf {
|
||||
async fn send_to(&mut self, buf: &[u8], target: &SocksAddr) -> io::Result<usize> {
|
||||
// The type does not accept domain name.
|
||||
self.0.send_to(buf, target.must_ip()).await
|
||||
}
|
||||
|
||||
async fn close(&mut self) -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// An outbound datagram simply wraps a UDP socket.
|
||||
pub struct SimpleOutboundDatagram {
|
||||
inner: UdpSocket,
|
||||
|
||||
@@ -84,6 +84,7 @@ pub mod ws;
|
||||
pub use datagram::{
|
||||
SimpleInboundDatagram, SimpleInboundDatagramRecvHalf, SimpleInboundDatagramSendHalf,
|
||||
SimpleOutboundDatagram, SimpleOutboundDatagramRecvHalf, SimpleOutboundDatagramSendHalf,
|
||||
StdOutboundDatagram,
|
||||
};
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
@@ -211,24 +212,20 @@ async fn bind_socket<T: BindSocket>(socket: &T, indicator: &SocketAddr) -> io::R
|
||||
}
|
||||
|
||||
let ret = match indicator {
|
||||
SocketAddr::V4(..) => {
|
||||
libc::setsockopt(
|
||||
socket.as_raw_fd(),
|
||||
libc::IPPROTO_IP,
|
||||
libc::IP_BOUND_IF,
|
||||
&ifidx as *const _ as *const libc::c_void,
|
||||
std::mem::size_of::<libc::c_uint>() as libc::socklen_t,
|
||||
)
|
||||
}
|
||||
SocketAddr::V6(..) => {
|
||||
libc::setsockopt(
|
||||
socket.as_raw_fd(),
|
||||
libc::IPPROTO_IPV6,
|
||||
libc::IPV6_BOUND_IF,
|
||||
&ifidx as *const _ as *const libc::c_void,
|
||||
std::mem::size_of::<libc::c_uint>() as libc::socklen_t,
|
||||
)
|
||||
}
|
||||
SocketAddr::V4(..) => libc::setsockopt(
|
||||
socket.as_raw_fd(),
|
||||
libc::IPPROTO_IP,
|
||||
libc::IP_BOUND_IF,
|
||||
&ifidx as *const _ as *const libc::c_void,
|
||||
std::mem::size_of::<libc::c_uint>() as libc::socklen_t,
|
||||
),
|
||||
SocketAddr::V6(..) => libc::setsockopt(
|
||||
socket.as_raw_fd(),
|
||||
libc::IPPROTO_IPV6,
|
||||
libc::IPV6_BOUND_IF,
|
||||
&ifidx as *const _ as *const libc::c_void,
|
||||
std::mem::size_of::<libc::c_uint>() as libc::socklen_t,
|
||||
),
|
||||
};
|
||||
if ret == -1 {
|
||||
last_err = Some(io::Error::last_os_error());
|
||||
|
||||
@@ -154,9 +154,9 @@ impl SocksAddr {
|
||||
Self::Ip("[::]:0".parse().unwrap())
|
||||
}
|
||||
|
||||
pub fn must_ip(self) -> SocketAddr {
|
||||
pub fn must_ip(&self) -> &SocketAddr {
|
||||
match self {
|
||||
SocksAddr::Ip(a) => a,
|
||||
SocksAddr::Ip(ref a) => a,
|
||||
_ => {
|
||||
panic!("assert SocksAddr as SocketAddr failed");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user