Dial MPTP sub streams concurrently

This commit is contained in:
eric
2026-02-24 19:41:46 +08:00
parent 0ff4794aa5
commit 760820ac52
2 changed files with 79 additions and 37 deletions

View File

@@ -83,6 +83,26 @@ impl<S: AsyncRead + AsyncWrite + Unpin> MptpStream<S> {
}
}
pub fn new_with_receiver_and_initial(
streams: Vec<S>,
cid: uuid::Uuid,
rx: mpsc::UnboundedReceiver<(S, Option<uuid::Uuid>)>,
) -> Self {
let subs = streams
.into_iter()
.map(|s| SubConnection::new_with_cid(s, cid))
.collect();
Self {
subs,
new_subs_rx: Some(rx),
read_buffer: BytesMut::new(),
next_pn: 1,
expected_read_pn: 1,
reorder_buffer: BTreeMap::new(),
closed: false,
}
}
fn poll_new_subs(&mut self, cx: &mut Context<'_>) {
if let Some(rx) = &mut self.new_subs_rx {
loop {

View File

@@ -7,6 +7,7 @@ use crate::proxy::mptp::mptp_conn::{MptpDatagram, MptpStream};
use async_trait::async_trait;
use bytes::BytesMut;
use tokio::io::AsyncWriteExt;
use tokio::sync::mpsc;
use uuid::Uuid;
use crate::{
@@ -32,14 +33,14 @@ impl Handler {
sess: &Session,
cmd: u8,
) -> io::Result<MptpStream<AnyStream>> {
let mut sub_streams = Vec::new();
let (tx, mut rx) = mpsc::unbounded_channel();
let cid = Uuid::new_v4();
// Clone session and set destination to MPTP server
let mut server_sess = sess.clone();
server_sess.destination = SocksAddr::try_from((self.address.clone(), self.port))
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
let cid = Uuid::new_v4();
let target_addr = match &sess.destination {
SocksAddr::Ip(addr) => match addr {
std::net::SocketAddr::V4(v4) => Address::Ipv4(*v4.ip()),
@@ -49,54 +50,75 @@ impl Handler {
};
let target_port = sess.destination.port();
// 1. Establish sub-connections
// 1. Establish sub-connections in parallel
for (i, actor) in self.actors.iter().enumerate() {
if let Ok(stream_handler) = actor.stream() {
// Try to connect if the actor requires a connection
let stream_opt: Option<AnyStream> =
connect_stream_outbound(&server_sess, self.dns_client.clone(), actor)
.await
.map_err(|e| {
io::Error::new(io::ErrorKind::Other, format!("connect failed: {}", e))
})?;
let actor = actor.clone();
let server_sess = server_sess.clone();
let dns_client = self.dns_client.clone();
let tx = tx.clone();
let target_addr = target_addr.clone();
let cid = cid;
let cmd = cmd;
let target_port = target_port;
match stream_handler.handle(&server_sess, None, stream_opt).await {
Ok(mut stream) => {
// 2. Perform Handshake on each sub-connection
let req = HandshakeRequest {
ver: VER,
cid,
cmd,
dst_addr: target_addr.clone(),
dst_port: target_port,
tokio::spawn(async move {
if let Ok(stream_handler) = actor.stream() {
// Try to connect if the actor requires a connection
let stream_opt =
match connect_stream_outbound(&server_sess, dns_client, &actor).await {
Ok(opt) => opt,
Err(e) => {
tracing::warn!("Failed to connect sub {}: {}", i, e);
return;
}
};
let mut buf = BytesMut::new();
req.encode(&mut buf);
match stream_handler.handle(&server_sess, None, stream_opt).await {
Ok(mut stream) => {
// 2. Perform Handshake on each sub-connection
let req = HandshakeRequest {
ver: VER,
cid,
cmd,
dst_addr: target_addr,
dst_port: target_port,
};
if let Err(e) = stream.write_all(&buf).await {
tracing::warn!("Failed to send handshake for sub {}: {}", i, e);
continue;
let mut buf = BytesMut::new();
req.encode(&mut buf);
if let Err(e) = stream.write_all(&buf).await {
tracing::warn!("Failed to send handshake for sub {}: {}", i, e);
return;
}
let _ = tx.send((stream, Some(cid)));
}
Err(e) => {
tracing::warn!("Failed to handle sub {}: {}", i, e);
}
sub_streams.push(stream);
}
Err(e) => {
tracing::warn!("Failed to connect sub {}: {}", i, e);
}
}
}
});
}
if sub_streams.is_empty() {
return Err(io::Error::new(
// Drop our local tx so that rx.recv() returns None when all tasks finish
drop(tx);
// Wait for the first successful connection
if let Some((first_stream, _)) = rx.recv().await {
// 3. Create MptpStream with the first stream and the receiver for subsequent ones
Ok(MptpStream::new_with_receiver_and_initial(
vec![first_stream],
cid,
rx,
))
} else {
Err(io::Error::new(
io::ErrorKind::Other,
"No available sub-connections",
));
))
}
// 3. Create MptpStream
Ok(MptpStream::new(sub_streams, cid))
}
}