Dial MPTP sub streams concurrently
This commit is contained in:
@@ -83,6 +83,26 @@ impl<S: AsyncRead + AsyncWrite + Unpin> MptpStream<S> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_with_receiver_and_initial(
|
||||
streams: Vec<S>,
|
||||
cid: uuid::Uuid,
|
||||
rx: mpsc::UnboundedReceiver<(S, Option<uuid::Uuid>)>,
|
||||
) -> Self {
|
||||
let subs = streams
|
||||
.into_iter()
|
||||
.map(|s| SubConnection::new_with_cid(s, cid))
|
||||
.collect();
|
||||
Self {
|
||||
subs,
|
||||
new_subs_rx: Some(rx),
|
||||
read_buffer: BytesMut::new(),
|
||||
next_pn: 1,
|
||||
expected_read_pn: 1,
|
||||
reorder_buffer: BTreeMap::new(),
|
||||
closed: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_new_subs(&mut self, cx: &mut Context<'_>) {
|
||||
if let Some(rx) = &mut self.new_subs_rx {
|
||||
loop {
|
||||
|
||||
@@ -7,6 +7,7 @@ use crate::proxy::mptp::mptp_conn::{MptpDatagram, MptpStream};
|
||||
use async_trait::async_trait;
|
||||
use bytes::BytesMut;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::sync::mpsc;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
@@ -32,14 +33,14 @@ impl Handler {
|
||||
sess: &Session,
|
||||
cmd: u8,
|
||||
) -> io::Result<MptpStream<AnyStream>> {
|
||||
let mut sub_streams = Vec::new();
|
||||
let (tx, mut rx) = mpsc::unbounded_channel();
|
||||
let cid = Uuid::new_v4();
|
||||
|
||||
// Clone session and set destination to MPTP server
|
||||
let mut server_sess = sess.clone();
|
||||
server_sess.destination = SocksAddr::try_from((self.address.clone(), self.port))
|
||||
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
|
||||
|
||||
let cid = Uuid::new_v4();
|
||||
let target_addr = match &sess.destination {
|
||||
SocksAddr::Ip(addr) => match addr {
|
||||
std::net::SocketAddr::V4(v4) => Address::Ipv4(*v4.ip()),
|
||||
@@ -49,54 +50,75 @@ impl Handler {
|
||||
};
|
||||
let target_port = sess.destination.port();
|
||||
|
||||
// 1. Establish sub-connections
|
||||
// 1. Establish sub-connections in parallel
|
||||
for (i, actor) in self.actors.iter().enumerate() {
|
||||
if let Ok(stream_handler) = actor.stream() {
|
||||
// Try to connect if the actor requires a connection
|
||||
let stream_opt: Option<AnyStream> =
|
||||
connect_stream_outbound(&server_sess, self.dns_client.clone(), actor)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
io::Error::new(io::ErrorKind::Other, format!("connect failed: {}", e))
|
||||
})?;
|
||||
let actor = actor.clone();
|
||||
let server_sess = server_sess.clone();
|
||||
let dns_client = self.dns_client.clone();
|
||||
let tx = tx.clone();
|
||||
let target_addr = target_addr.clone();
|
||||
let cid = cid;
|
||||
let cmd = cmd;
|
||||
let target_port = target_port;
|
||||
|
||||
match stream_handler.handle(&server_sess, None, stream_opt).await {
|
||||
Ok(mut stream) => {
|
||||
// 2. Perform Handshake on each sub-connection
|
||||
let req = HandshakeRequest {
|
||||
ver: VER,
|
||||
cid,
|
||||
cmd,
|
||||
dst_addr: target_addr.clone(),
|
||||
dst_port: target_port,
|
||||
tokio::spawn(async move {
|
||||
if let Ok(stream_handler) = actor.stream() {
|
||||
// Try to connect if the actor requires a connection
|
||||
let stream_opt =
|
||||
match connect_stream_outbound(&server_sess, dns_client, &actor).await {
|
||||
Ok(opt) => opt,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to connect sub {}: {}", i, e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
req.encode(&mut buf);
|
||||
match stream_handler.handle(&server_sess, None, stream_opt).await {
|
||||
Ok(mut stream) => {
|
||||
// 2. Perform Handshake on each sub-connection
|
||||
let req = HandshakeRequest {
|
||||
ver: VER,
|
||||
cid,
|
||||
cmd,
|
||||
dst_addr: target_addr,
|
||||
dst_port: target_port,
|
||||
};
|
||||
|
||||
if let Err(e) = stream.write_all(&buf).await {
|
||||
tracing::warn!("Failed to send handshake for sub {}: {}", i, e);
|
||||
continue;
|
||||
let mut buf = BytesMut::new();
|
||||
req.encode(&mut buf);
|
||||
|
||||
if let Err(e) = stream.write_all(&buf).await {
|
||||
tracing::warn!("Failed to send handshake for sub {}: {}", i, e);
|
||||
return;
|
||||
}
|
||||
|
||||
let _ = tx.send((stream, Some(cid)));
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to handle sub {}: {}", i, e);
|
||||
}
|
||||
|
||||
sub_streams.push(stream);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to connect sub {}: {}", i, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if sub_streams.is_empty() {
|
||||
return Err(io::Error::new(
|
||||
// Drop our local tx so that rx.recv() returns None when all tasks finish
|
||||
drop(tx);
|
||||
|
||||
// Wait for the first successful connection
|
||||
if let Some((first_stream, _)) = rx.recv().await {
|
||||
// 3. Create MptpStream with the first stream and the receiver for subsequent ones
|
||||
Ok(MptpStream::new_with_receiver_and_initial(
|
||||
vec![first_stream],
|
||||
cid,
|
||||
rx,
|
||||
))
|
||||
} else {
|
||||
Err(io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
"No available sub-connections",
|
||||
));
|
||||
))
|
||||
}
|
||||
|
||||
// 3. Create MptpStream
|
||||
Ok(MptpStream::new(sub_streams, cid))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user