Refactor
This commit is contained in:
@@ -109,7 +109,7 @@ impl Dispatcher {
|
||||
|
||||
let outbound = {
|
||||
let router = self.router.read().await;
|
||||
let outbound = match router.pick_route(sess).await {
|
||||
match router.pick_route(sess).await {
|
||||
Ok(tag) => {
|
||||
debug!(
|
||||
"picked route [{}] for {} -> {}",
|
||||
@@ -136,8 +136,7 @@ impl Dispatcher {
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
outbound
|
||||
}
|
||||
};
|
||||
|
||||
let h = if let Some(h) = self.outbound_manager.read().await.get(&outbound) {
|
||||
@@ -233,7 +232,7 @@ impl Dispatcher {
|
||||
pub async fn dispatch_udp(&self, sess: &Session) -> io::Result<Box<dyn OutboundDatagram>> {
|
||||
let outbound = {
|
||||
let router = self.router.read().await;
|
||||
let outbound = match router.pick_route(sess).await {
|
||||
match router.pick_route(sess).await {
|
||||
Ok(tag) => {
|
||||
debug!(
|
||||
"picked route [{}] for {} -> {}",
|
||||
@@ -254,8 +253,7 @@ impl Dispatcher {
|
||||
return Err(io::Error::new(ErrorKind::Other, "no available handler"));
|
||||
}
|
||||
}
|
||||
};
|
||||
outbound
|
||||
}
|
||||
};
|
||||
|
||||
let h = if let Some(h) = self.outbound_manager.read().await.get(&outbound) {
|
||||
|
||||
@@ -27,36 +27,15 @@ async fn handle_inbound_datagram(
|
||||
|
||||
tokio::spawn(async move {
|
||||
while let Some(pkt) = client_ch_rx.recv().await {
|
||||
let dst_addr = match pkt.dst_addr {
|
||||
Some(a) => a,
|
||||
None => {
|
||||
warn!("ignore udp pkt with unexpected empty dst addr");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let dst_addr = match dst_addr {
|
||||
SocksAddr::Ip(a) => a,
|
||||
_ => {
|
||||
error!("unexpected domain address");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let src_addr = match pkt.src_addr {
|
||||
Some(a) => a,
|
||||
None => {
|
||||
warn!("ignore udp pkt with unexpected empty src addr");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let dst_addr = pkt.dst_addr.must_ip();
|
||||
if let Err(e) = client_sock_send
|
||||
.send_to(&pkt.data[..], Some(&src_addr), &dst_addr)
|
||||
.send_to(&pkt.data[..], Some(&pkt.src_addr), &dst_addr)
|
||||
.await
|
||||
{
|
||||
warn!("send udp pkt failed: {}", e);
|
||||
return;
|
||||
}
|
||||
}
|
||||
debug!("udp downlink ended");
|
||||
});
|
||||
|
||||
let mut buf = [0u8; 2 * 1024];
|
||||
@@ -89,8 +68,8 @@ async fn handle_inbound_datagram(
|
||||
|
||||
let pkt = UdpPacket {
|
||||
data: (&buf[..n]).to_vec(),
|
||||
src_addr: Some(SocksAddr::from(dgram_src.address)),
|
||||
dst_addr: Some(dst_addr.clone()),
|
||||
src_addr: SocksAddr::from(dgram_src.address),
|
||||
dst_addr: dst_addr.clone(),
|
||||
};
|
||||
|
||||
nat_manager
|
||||
|
||||
@@ -16,21 +16,19 @@ use crate::session::{DatagramSource, Network, Session, SocksAddr};
|
||||
#[derive(Debug)]
|
||||
pub struct UdpPacket {
|
||||
pub data: Vec<u8>,
|
||||
pub src_addr: Option<SocksAddr>,
|
||||
pub dst_addr: Option<SocksAddr>,
|
||||
pub src_addr: SocksAddr,
|
||||
pub dst_addr: SocksAddr,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for UdpPacket {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
let src = match self.src_addr {
|
||||
None => "None".to_string(),
|
||||
Some(ref addr) => addr.to_string(),
|
||||
};
|
||||
let dst = match self.dst_addr {
|
||||
None => "None".to_string(),
|
||||
Some(ref addr) => addr.to_string(),
|
||||
};
|
||||
write!(f, "{} <-> {}, {} bytes", src, dst, self.data.len())
|
||||
write!(
|
||||
f,
|
||||
"{} <-> {}, {} bytes",
|
||||
self.src_addr,
|
||||
self.dst_addr,
|
||||
self.data.len()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -194,27 +192,23 @@ impl NatManager {
|
||||
loop {
|
||||
match target_sock_recv.recv_from(&mut buf).await {
|
||||
Err(err) => {
|
||||
debug!("udp downlink error: {}", err);
|
||||
sessions.lock().await.remove(&raddr);
|
||||
break;
|
||||
}
|
||||
Ok((0, _)) => {
|
||||
debug!("receive zero-len udp packet");
|
||||
sessions.lock().await.remove(&raddr);
|
||||
debug!(
|
||||
"Failed to receive downlink packets on session {}: {}",
|
||||
&raddr, err
|
||||
);
|
||||
break;
|
||||
}
|
||||
Ok((n, addr)) => {
|
||||
let pkt = UdpPacket {
|
||||
data: (&buf[..n]).to_vec(),
|
||||
src_addr: Some(addr.clone()),
|
||||
dst_addr: Some(SocksAddr::from(raddr.address)),
|
||||
src_addr: addr.clone(),
|
||||
dst_addr: SocksAddr::from(raddr.address),
|
||||
};
|
||||
if let Err(err) = client_ch_tx.send(pkt).await {
|
||||
debug!(
|
||||
"send downlink packet failed {} -> {}: {}",
|
||||
&addr, &raddr, err
|
||||
"Failed to send downlink packets on session {} to {}: {}",
|
||||
&raddr, &addr, err
|
||||
);
|
||||
sessions.lock().await.remove(&raddr);
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -237,6 +231,7 @@ impl NatManager {
|
||||
}
|
||||
}
|
||||
}
|
||||
sessions.lock().await.remove(&raddr);
|
||||
};
|
||||
|
||||
let (downlink_task, downlink_task_handle) = abortable(downlink_task);
|
||||
@@ -244,39 +239,18 @@ impl NatManager {
|
||||
|
||||
// Runs a task to receive the abort signal.
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = downlink_abort_rx.await {
|
||||
debug!(
|
||||
"failed to receive abort signal on session {}: {}",
|
||||
&raddr, e
|
||||
);
|
||||
};
|
||||
let _ = downlink_abort_rx.await;
|
||||
downlink_task_handle.abort();
|
||||
});
|
||||
|
||||
// uplink
|
||||
tokio::spawn(async move {
|
||||
while let Some(pkt) = target_ch_rx.recv().await {
|
||||
if pkt.dst_addr.is_none() {
|
||||
warn!("unexpected none dst addr in uplink pkts");
|
||||
continue;
|
||||
}
|
||||
let addr = match pkt.dst_addr {
|
||||
Some(a) => a,
|
||||
None => {
|
||||
warn!("unexpected none addr");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
match target_sock_send.send_to(&pkt.data, &addr).await {
|
||||
Ok(0) => {
|
||||
debug!("uplink send zero bytes");
|
||||
}
|
||||
Ok(_) => {
|
||||
continue;
|
||||
}
|
||||
Err(err) => {
|
||||
debug!("uplink send error {:?}", err);
|
||||
}
|
||||
if let Err(e) = target_sock_send.send_to(&pkt.data, &pkt.dst_addr).await {
|
||||
debug!(
|
||||
"Failed to send uplink packets on session {} to {}: {:?}",
|
||||
&raddr, &pkt.dst_addr, e
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
@@ -34,7 +34,6 @@ use crate::{
|
||||
pub mod datagram;
|
||||
pub mod inbound;
|
||||
pub mod outbound;
|
||||
pub mod stream;
|
||||
|
||||
pub mod null;
|
||||
|
||||
@@ -85,7 +84,6 @@ pub use datagram::{
|
||||
SimpleInboundDatagram, SimpleInboundDatagramRecvHalf, SimpleInboundDatagramSendHalf,
|
||||
SimpleOutboundDatagram, SimpleOutboundDatagramRecvHalf, SimpleOutboundDatagramSendHalf,
|
||||
};
|
||||
pub use stream::BufHeadProxyStream;
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Debug)]
|
||||
pub enum DatagramTransportType {
|
||||
|
||||
@@ -35,7 +35,6 @@ impl TcpOutboundHandler for Handler {
|
||||
let mut buf = BytesMut::new();
|
||||
sess.destination
|
||||
.write_buf(&mut buf, SocksAddrWireType::PortLast)?;
|
||||
// FIXME combine header and first payload
|
||||
stream.write_all(&buf).await?;
|
||||
Ok(Box::new(stream))
|
||||
}
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
use std::{io, pin::Pin};
|
||||
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use futures::ready;
|
||||
use futures::task::{Context, Poll};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
|
||||
/// A proxy stream writes a header along with the first payload.
|
||||
pub struct BufHeadProxyStream<T> {
|
||||
inner: T,
|
||||
head: Option<Bytes>,
|
||||
first_payload: BytesMut,
|
||||
}
|
||||
|
||||
impl<T> BufHeadProxyStream<T> {
|
||||
pub fn new(inner: T, head: Bytes) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
head: Some(head),
|
||||
first_payload: BytesMut::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> AsyncRead for BufHeadProxyStream<T>
|
||||
where
|
||||
T: AsyncRead + Unpin,
|
||||
{
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
fn early_eof() -> io::Error {
|
||||
io::Error::new(io::ErrorKind::Interrupted, "early eof")
|
||||
}
|
||||
|
||||
impl<T> AsyncWrite for BufHeadProxyStream<T>
|
||||
where
|
||||
T: AsyncWrite + Unpin,
|
||||
{
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
let me = &mut *self;
|
||||
// Combine header and first payload.
|
||||
if let Some(head) = me.head.take() {
|
||||
me.first_payload.put_slice(&head);
|
||||
me.first_payload.put_slice(buf);
|
||||
}
|
||||
while !me.first_payload.is_empty() {
|
||||
let n = ready!(Pin::new(&mut me.inner).poll_write(cx, &me.first_payload))?;
|
||||
if n == 0 {
|
||||
return Poll::Ready(Err(early_eof()));
|
||||
}
|
||||
me.first_payload.advance(n);
|
||||
if me.first_payload.is_empty() {
|
||||
me.first_payload = BytesMut::new(); // shadow to free
|
||||
return Poll::Ready(Ok(buf.len()));
|
||||
}
|
||||
}
|
||||
Pin::new(&mut me.inner).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.inner).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
@@ -106,6 +106,7 @@ pub fn new(
|
||||
|
||||
let mut futs: Vec<Runner> = Vec::new();
|
||||
|
||||
// Reads packet from stack and sends to TUN.
|
||||
let s2t = Box::pin(async move {
|
||||
while let Some(pkt) = stack_stream.next().await {
|
||||
if let Ok(pkt) = pkt {
|
||||
@@ -115,6 +116,7 @@ pub fn new(
|
||||
});
|
||||
futs.push(s2t);
|
||||
|
||||
// Reads packet from TUN and sends to stack.
|
||||
let t2s = Box::pin(async move {
|
||||
while let Some(pkt) = tun_stream.next().await {
|
||||
if let Ok(pkt) = pkt {
|
||||
@@ -124,6 +126,7 @@ pub fn new(
|
||||
});
|
||||
futs.push(t2s);
|
||||
|
||||
// Extracts TCP connections from stack and sends them to the dispatcher.
|
||||
let fakedns_cloned = fakedns.clone();
|
||||
let lwip_mutex_cloned = lwip_mutex.clone();
|
||||
let inbound_tag_cloned = inbound_tag.clone();
|
||||
@@ -173,12 +176,12 @@ pub fn new(
|
||||
});
|
||||
futs.push(tcp_incoming);
|
||||
|
||||
let nat_manager = nat_manager.clone();
|
||||
let fakedns = fakedns.clone();
|
||||
// Extracts UDP packets from stack and sends them to the NAT manager, which would
|
||||
// maintain UDP sessions and send them to the dispatcher.
|
||||
let udp_incoming = Box::pin(async move {
|
||||
let mut listener = netstack::UdpListener::new();
|
||||
let nat_manager = nat_manager.clone();
|
||||
let fakedns2 = fakedns.clone();
|
||||
let fakedns_cloned = fakedns.clone();
|
||||
let pcb = listener.pcb();
|
||||
|
||||
// Sending packets to TUN should be very fast.
|
||||
@@ -191,27 +194,8 @@ pub fn new(
|
||||
let lwip_mutex_cloned = lwip_mutex.clone();
|
||||
tokio::spawn(async move {
|
||||
while let Some(pkt) = client_ch_rx.recv().await {
|
||||
let socks_src_addr = match pkt.src_addr {
|
||||
Some(a) => a,
|
||||
None => {
|
||||
warn!("unexpected none src addr");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let dst_addr = match pkt.dst_addr {
|
||||
Some(a) => match a {
|
||||
SocksAddr::Ip(a) => a,
|
||||
_ => {
|
||||
warn!("unexpected domain addr");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
None => {
|
||||
warn!("unexpected dst addr");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let src_addr = match socks_src_addr {
|
||||
let dst_addr = pkt.dst_addr.must_ip();
|
||||
let src_addr = match pkt.src_addr {
|
||||
SocksAddr::Ip(a) => a,
|
||||
|
||||
// If the socket gives us a domain source address,
|
||||
@@ -220,7 +204,7 @@ pub fn new(
|
||||
SocksAddr::Domain(domain, port) => {
|
||||
// TODO we're doing this for every packet! optimize needed
|
||||
// trace!("downlink querying fake ip for domain {}", &domain);
|
||||
if let Some(ip) = fakedns2.lock().await.query_fake_ip(&domain) {
|
||||
if let Some(ip) = fakedns_cloned.lock().await.query_fake_ip(&domain) {
|
||||
SocketAddr::new(ip, port)
|
||||
} else {
|
||||
warn!(
|
||||
@@ -239,15 +223,11 @@ pub fn new(
|
||||
&pkt.data[..],
|
||||
);
|
||||
}
|
||||
|
||||
error!("unexpected udp downlink ended");
|
||||
});
|
||||
|
||||
let fakedns2 = fakedns.clone();
|
||||
|
||||
while let Some(pkt) = listener.next().await {
|
||||
if pkt.2.port() == 53 {
|
||||
match fakedns2.lock().await.generate_fake_response(&pkt.0) {
|
||||
match fakedns.lock().await.generate_fake_response(&pkt.0) {
|
||||
Ok(resp) => {
|
||||
netstack::send_udp(
|
||||
lwip_mutex.clone(),
|
||||
@@ -268,10 +248,10 @@ pub fn new(
|
||||
// that said, the application connects a UDP socket with a domain address.
|
||||
// It also means the back packets on this UDP session shall only come from a
|
||||
// single source address.
|
||||
let socks_dst_addr = if fakedns2.lock().await.is_fake_ip(&pkt.2.ip()) {
|
||||
let socks_dst_addr = if fakedns.lock().await.is_fake_ip(&pkt.2.ip()) {
|
||||
// TODO we're doing this for every packet! optimize needed
|
||||
// trace!("uplink querying domain for fake ip {}", &dst_addr.ip(),);
|
||||
if let Some(domain) = fakedns2.lock().await.query_domain(&pkt.2.ip()) {
|
||||
if let Some(domain) = fakedns.lock().await.query_domain(&pkt.2.ip()) {
|
||||
SocksAddr::Domain(domain, pkt.2.port())
|
||||
} else {
|
||||
// Skip this packet. Requests targeting fake IPs are
|
||||
@@ -286,8 +266,8 @@ pub fn new(
|
||||
|
||||
let pkt = UdpPacket {
|
||||
data: pkt.0,
|
||||
src_addr: Some(SocksAddr::Ip(dgram_src.address)),
|
||||
dst_addr: Some(socks_dst_addr.clone()),
|
||||
src_addr: SocksAddr::Ip(dgram_src.address),
|
||||
dst_addr: socks_dst_addr.clone(),
|
||||
};
|
||||
|
||||
nat_manager
|
||||
|
||||
@@ -7,7 +7,6 @@ use std::{
|
||||
|
||||
use byteorder::{BigEndian, ByteOrder};
|
||||
use bytes::BufMut;
|
||||
use log::*;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt};
|
||||
|
||||
#[derive(PartialEq, Eq, Hash, Clone, Copy, Debug)]
|
||||
@@ -146,8 +145,7 @@ impl SocksAddr {
|
||||
match self {
|
||||
SocksAddr::Ip(a) => a,
|
||||
_ => {
|
||||
error!("assert SocksAddr as SocketAddr failed");
|
||||
panic!("");
|
||||
panic!("assert SocksAddr as SocketAddr failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user