This commit is contained in:
eric
2022-04-10 15:28:09 +08:00
parent 796c22dd84
commit 38ebca111d
8 changed files with 47 additions and 199 deletions

View File

@@ -109,7 +109,7 @@ impl Dispatcher {
let outbound = {
let router = self.router.read().await;
let outbound = match router.pick_route(sess).await {
match router.pick_route(sess).await {
Ok(tag) => {
debug!(
"picked route [{}] for {} -> {}",
@@ -136,8 +136,7 @@ impl Dispatcher {
return;
}
}
};
outbound
}
};
let h = if let Some(h) = self.outbound_manager.read().await.get(&outbound) {
@@ -233,7 +232,7 @@ impl Dispatcher {
pub async fn dispatch_udp(&self, sess: &Session) -> io::Result<Box<dyn OutboundDatagram>> {
let outbound = {
let router = self.router.read().await;
let outbound = match router.pick_route(sess).await {
match router.pick_route(sess).await {
Ok(tag) => {
debug!(
"picked route [{}] for {} -> {}",
@@ -254,8 +253,7 @@ impl Dispatcher {
return Err(io::Error::new(ErrorKind::Other, "no available handler"));
}
}
};
outbound
}
};
let h = if let Some(h) = self.outbound_manager.read().await.get(&outbound) {

View File

@@ -27,36 +27,15 @@ async fn handle_inbound_datagram(
tokio::spawn(async move {
while let Some(pkt) = client_ch_rx.recv().await {
let dst_addr = match pkt.dst_addr {
Some(a) => a,
None => {
warn!("ignore udp pkt with unexpected empty dst addr");
continue;
}
};
let dst_addr = match dst_addr {
SocksAddr::Ip(a) => a,
_ => {
error!("unexpected domain address");
continue;
}
};
let src_addr = match pkt.src_addr {
Some(a) => a,
None => {
warn!("ignore udp pkt with unexpected empty src addr");
continue;
}
};
let dst_addr = pkt.dst_addr.must_ip();
if let Err(e) = client_sock_send
.send_to(&pkt.data[..], Some(&src_addr), &dst_addr)
.send_to(&pkt.data[..], Some(&pkt.src_addr), &dst_addr)
.await
{
warn!("send udp pkt failed: {}", e);
return;
}
}
debug!("udp downlink ended");
});
let mut buf = [0u8; 2 * 1024];
@@ -89,8 +68,8 @@ async fn handle_inbound_datagram(
let pkt = UdpPacket {
data: (&buf[..n]).to_vec(),
src_addr: Some(SocksAddr::from(dgram_src.address)),
dst_addr: Some(dst_addr.clone()),
src_addr: SocksAddr::from(dgram_src.address),
dst_addr: dst_addr.clone(),
};
nat_manager

View File

@@ -16,21 +16,19 @@ use crate::session::{DatagramSource, Network, Session, SocksAddr};
#[derive(Debug)]
pub struct UdpPacket {
pub data: Vec<u8>,
pub src_addr: Option<SocksAddr>,
pub dst_addr: Option<SocksAddr>,
pub src_addr: SocksAddr,
pub dst_addr: SocksAddr,
}
impl std::fmt::Display for UdpPacket {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
let src = match self.src_addr {
None => "None".to_string(),
Some(ref addr) => addr.to_string(),
};
let dst = match self.dst_addr {
None => "None".to_string(),
Some(ref addr) => addr.to_string(),
};
write!(f, "{} <-> {}, {} bytes", src, dst, self.data.len())
write!(
f,
"{} <-> {}, {} bytes",
self.src_addr,
self.dst_addr,
self.data.len()
)
}
}
@@ -194,27 +192,23 @@ impl NatManager {
loop {
match target_sock_recv.recv_from(&mut buf).await {
Err(err) => {
debug!("udp downlink error: {}", err);
sessions.lock().await.remove(&raddr);
break;
}
Ok((0, _)) => {
debug!("receive zero-len udp packet");
sessions.lock().await.remove(&raddr);
debug!(
"Failed to receive downlink packets on session {}: {}",
&raddr, err
);
break;
}
Ok((n, addr)) => {
let pkt = UdpPacket {
data: (&buf[..n]).to_vec(),
src_addr: Some(addr.clone()),
dst_addr: Some(SocksAddr::from(raddr.address)),
src_addr: addr.clone(),
dst_addr: SocksAddr::from(raddr.address),
};
if let Err(err) = client_ch_tx.send(pkt).await {
debug!(
"send downlink packet failed {} -> {}: {}",
&addr, &raddr, err
"Failed to send downlink packets on session {} to {}: {}",
&raddr, &addr, err
);
sessions.lock().await.remove(&raddr);
break;
}
@@ -237,6 +231,7 @@ impl NatManager {
}
}
}
sessions.lock().await.remove(&raddr);
};
let (downlink_task, downlink_task_handle) = abortable(downlink_task);
@@ -244,39 +239,18 @@ impl NatManager {
// Runs a task to receive the abort signal.
tokio::spawn(async move {
if let Err(e) = downlink_abort_rx.await {
debug!(
"failed to receive abort signal on session {}: {}",
&raddr, e
);
};
let _ = downlink_abort_rx.await;
downlink_task_handle.abort();
});
// uplink
tokio::spawn(async move {
while let Some(pkt) = target_ch_rx.recv().await {
if pkt.dst_addr.is_none() {
warn!("unexpected none dst addr in uplink pkts");
continue;
}
let addr = match pkt.dst_addr {
Some(a) => a,
None => {
warn!("unexpected none addr");
continue;
}
};
match target_sock_send.send_to(&pkt.data, &addr).await {
Ok(0) => {
debug!("uplink send zero bytes");
}
Ok(_) => {
continue;
}
Err(err) => {
debug!("uplink send error {:?}", err);
}
if let Err(e) = target_sock_send.send_to(&pkt.data, &pkt.dst_addr).await {
debug!(
"Failed to send uplink packets on session {} to {}: {:?}",
&raddr, &pkt.dst_addr, e
);
}
}
});

View File

@@ -34,7 +34,6 @@ use crate::{
pub mod datagram;
pub mod inbound;
pub mod outbound;
pub mod stream;
pub mod null;
@@ -85,7 +84,6 @@ pub use datagram::{
SimpleInboundDatagram, SimpleInboundDatagramRecvHalf, SimpleInboundDatagramSendHalf,
SimpleOutboundDatagram, SimpleOutboundDatagramRecvHalf, SimpleOutboundDatagramSendHalf,
};
pub use stream::BufHeadProxyStream;
#[derive(Clone, Copy, PartialEq, Debug)]
pub enum DatagramTransportType {

View File

@@ -35,7 +35,6 @@ impl TcpOutboundHandler for Handler {
let mut buf = BytesMut::new();
sess.destination
.write_buf(&mut buf, SocksAddrWireType::PortLast)?;
// FIXME combine header and first payload
stream.write_all(&buf).await?;
Ok(Box::new(stream))
}

View File

@@ -1,78 +0,0 @@
use std::{io, pin::Pin};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures::ready;
use futures::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
/// A proxy stream writes a header along with the first payload.
pub struct BufHeadProxyStream<T> {
inner: T,
head: Option<Bytes>,
first_payload: BytesMut,
}
impl<T> BufHeadProxyStream<T> {
pub fn new(inner: T, head: Bytes) -> Self {
Self {
inner,
head: Some(head),
first_payload: BytesMut::new(),
}
}
}
impl<T> AsyncRead for BufHeadProxyStream<T>
where
T: AsyncRead + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_read(cx, buf)
}
}
fn early_eof() -> io::Error {
io::Error::new(io::ErrorKind::Interrupted, "early eof")
}
impl<T> AsyncWrite for BufHeadProxyStream<T>
where
T: AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let me = &mut *self;
// Combine header and first payload.
if let Some(head) = me.head.take() {
me.first_payload.put_slice(&head);
me.first_payload.put_slice(buf);
}
while !me.first_payload.is_empty() {
let n = ready!(Pin::new(&mut me.inner).poll_write(cx, &me.first_payload))?;
if n == 0 {
return Poll::Ready(Err(early_eof()));
}
me.first_payload.advance(n);
if me.first_payload.is_empty() {
me.first_payload = BytesMut::new(); // shadow to free
return Poll::Ready(Ok(buf.len()));
}
}
Pin::new(&mut me.inner).poll_write(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_flush(cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(cx)
}
}

View File

@@ -106,6 +106,7 @@ pub fn new(
let mut futs: Vec<Runner> = Vec::new();
// Reads packet from stack and sends to TUN.
let s2t = Box::pin(async move {
while let Some(pkt) = stack_stream.next().await {
if let Ok(pkt) = pkt {
@@ -115,6 +116,7 @@ pub fn new(
});
futs.push(s2t);
// Reads packet from TUN and sends to stack.
let t2s = Box::pin(async move {
while let Some(pkt) = tun_stream.next().await {
if let Ok(pkt) = pkt {
@@ -124,6 +126,7 @@ pub fn new(
});
futs.push(t2s);
// Extracts TCP connections from stack and sends them to the dispatcher.
let fakedns_cloned = fakedns.clone();
let lwip_mutex_cloned = lwip_mutex.clone();
let inbound_tag_cloned = inbound_tag.clone();
@@ -173,12 +176,12 @@ pub fn new(
});
futs.push(tcp_incoming);
let nat_manager = nat_manager.clone();
let fakedns = fakedns.clone();
// Extracts UDP packets from stack and sends them to the NAT manager, which would
// maintain UDP sessions and send them to the dispatcher.
let udp_incoming = Box::pin(async move {
let mut listener = netstack::UdpListener::new();
let nat_manager = nat_manager.clone();
let fakedns2 = fakedns.clone();
let fakedns_cloned = fakedns.clone();
let pcb = listener.pcb();
// Sending packets to TUN should be very fast.
@@ -191,27 +194,8 @@ pub fn new(
let lwip_mutex_cloned = lwip_mutex.clone();
tokio::spawn(async move {
while let Some(pkt) = client_ch_rx.recv().await {
let socks_src_addr = match pkt.src_addr {
Some(a) => a,
None => {
warn!("unexpected none src addr");
continue;
}
};
let dst_addr = match pkt.dst_addr {
Some(a) => match a {
SocksAddr::Ip(a) => a,
_ => {
warn!("unexpected domain addr");
continue;
}
},
None => {
warn!("unexpected dst addr");
continue;
}
};
let src_addr = match socks_src_addr {
let dst_addr = pkt.dst_addr.must_ip();
let src_addr = match pkt.src_addr {
SocksAddr::Ip(a) => a,
// If the socket gives us a domain source address,
@@ -220,7 +204,7 @@ pub fn new(
SocksAddr::Domain(domain, port) => {
// TODO we're doing this for every packet! optimize needed
// trace!("downlink querying fake ip for domain {}", &domain);
if let Some(ip) = fakedns2.lock().await.query_fake_ip(&domain) {
if let Some(ip) = fakedns_cloned.lock().await.query_fake_ip(&domain) {
SocketAddr::new(ip, port)
} else {
warn!(
@@ -239,15 +223,11 @@ pub fn new(
&pkt.data[..],
);
}
error!("unexpected udp downlink ended");
});
let fakedns2 = fakedns.clone();
while let Some(pkt) = listener.next().await {
if pkt.2.port() == 53 {
match fakedns2.lock().await.generate_fake_response(&pkt.0) {
match fakedns.lock().await.generate_fake_response(&pkt.0) {
Ok(resp) => {
netstack::send_udp(
lwip_mutex.clone(),
@@ -268,10 +248,10 @@ pub fn new(
// that said, the application connects a UDP socket with a domain address.
// It also means the back packets on this UDP session shall only come from a
// single source address.
let socks_dst_addr = if fakedns2.lock().await.is_fake_ip(&pkt.2.ip()) {
let socks_dst_addr = if fakedns.lock().await.is_fake_ip(&pkt.2.ip()) {
// TODO we're doing this for every packet! optimize needed
// trace!("uplink querying domain for fake ip {}", &dst_addr.ip(),);
if let Some(domain) = fakedns2.lock().await.query_domain(&pkt.2.ip()) {
if let Some(domain) = fakedns.lock().await.query_domain(&pkt.2.ip()) {
SocksAddr::Domain(domain, pkt.2.port())
} else {
// Skip this packet. Requests targeting fake IPs are
@@ -286,8 +266,8 @@ pub fn new(
let pkt = UdpPacket {
data: pkt.0,
src_addr: Some(SocksAddr::Ip(dgram_src.address)),
dst_addr: Some(socks_dst_addr.clone()),
src_addr: SocksAddr::Ip(dgram_src.address),
dst_addr: socks_dst_addr.clone(),
};
nat_manager

View File

@@ -7,7 +7,6 @@ use std::{
use byteorder::{BigEndian, ByteOrder};
use bytes::BufMut;
use log::*;
use tokio::io::{AsyncRead, AsyncReadExt};
#[derive(PartialEq, Eq, Hash, Clone, Copy, Debug)]
@@ -146,8 +145,7 @@ impl SocksAddr {
match self {
SocksAddr::Ip(a) => a,
_ => {
error!("assert SocksAddr as SocketAddr failed");
panic!("");
panic!("assert SocksAddr as SocketAddr failed");
}
}
}