This commit is contained in:
eric
2026-02-27 19:44:27 +08:00
parent 4cb248086c
commit be043cd96d
3 changed files with 67 additions and 58 deletions

View File

@@ -281,3 +281,34 @@ impl Frame {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_handshake_encoding() {
let req = HandshakeRequest {
ver: VER,
cid: uuid::Uuid::new_v4(),
cmd: CMD_CONNECT,
dst_addr: Address::Ipv4("127.0.0.1".parse().unwrap()),
dst_port: 8080,
};
let mut buf = BytesMut::new();
req.encode(&mut buf);
let mut decode_buf = buf.clone();
let decoded = HandshakeRequest::decode(&mut decode_buf).unwrap().unwrap();
assert_eq!(req.ver, decoded.ver);
assert_eq!(req.cid, decoded.cid);
assert_eq!(req.cmd, decoded.cmd);
assert_eq!(req.dst_port, decoded.dst_port);
match (req.dst_addr, decoded.dst_addr) {
(Address::Ipv4(a), Address::Ipv4(b)) => assert_eq!(a, b),
_ => panic!("Address mismatch"),
}
}
}

View File

@@ -15,8 +15,7 @@ struct SubConnection<S> {
stream: S,
read_buf: BytesMut,
write_buf: BytesMut,
closed: bool, // Mark if this sub-connection is dead
cid: uuid::Uuid, // Add CID for debugging
closed: bool, // Mark if this sub-connection is dead
}
impl<S> SubConnection<S> {
@@ -26,17 +25,6 @@ impl<S> SubConnection<S> {
read_buf: BytesMut::with_capacity(4096),
write_buf: BytesMut::with_capacity(4096),
closed: false,
cid: uuid::Uuid::nil(), // Default or placeholder
}
}
fn new_with_cid(stream: S, cid: uuid::Uuid) -> Self {
Self {
stream,
read_buf: BytesMut::with_capacity(4096),
write_buf: BytesMut::with_capacity(4096),
closed: false,
cid,
}
}
}
@@ -55,11 +43,8 @@ pub struct MptpStream<S> {
}
impl<S: AsyncRead + AsyncWrite + Unpin> MptpStream<S> {
pub fn new(streams: Vec<S>, cid: uuid::Uuid) -> Self {
let subs = streams
.into_iter()
.map(|s| SubConnection::new_with_cid(s, cid))
.collect();
pub fn new(streams: Vec<S>) -> Self {
let subs = streams.into_iter().map(|s| SubConnection::new(s)).collect();
Self {
subs,
new_subs_rx: None,
@@ -85,13 +70,9 @@ impl<S: AsyncRead + AsyncWrite + Unpin> MptpStream<S> {
pub fn new_with_receiver_and_initial(
streams: Vec<S>,
cid: uuid::Uuid,
rx: mpsc::UnboundedReceiver<(S, Option<uuid::Uuid>)>,
) -> Self {
let subs = streams
.into_iter()
.map(|s| SubConnection::new_with_cid(s, cid))
.collect();
let subs = streams.into_iter().map(|s| SubConnection::new(s)).collect();
Self {
subs,
new_subs_rx: Some(rx),
@@ -107,12 +88,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin> MptpStream<S> {
if let Some(rx) = &mut self.new_subs_rx {
loop {
match rx.poll_recv(cx) {
Poll::Ready(Some((stream, cid_opt))) => {
if let Some(cid) = cid_opt {
self.subs.push(SubConnection::new_with_cid(stream, cid));
} else {
self.subs.push(SubConnection::new(stream));
}
Poll::Ready(Some((stream, _))) => {
self.subs.push(SubConnection::new(stream));
}
Poll::Ready(None) => {
// Channel closed, no more new subs
@@ -179,19 +156,16 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for MptpStream<S> {
all_eof = false;
} else {
// EOF for this sub
warn!("Sub {} (CID={}) EOF, marking closed", i, sub.cid);
warn!("Sub {} EOF, marking closed", i);
sub.closed = true;
// Don't return EOF yet, others might be alive
}
}
Poll::Ready(Err(e)) => {
if e.kind() == io::ErrorKind::UnexpectedEof {
warn!("Sub {} (CID={}) UnexpectedEof, marking closed", i, sub.cid);
warn!("Sub {} UnexpectedEof, marking closed", i);
} else {
error!(
"Sub {} (CID={}) read error: {}, marking closed",
i, sub.cid, e
);
error!("Sub {} read error: {}, marking closed", i, e);
}
sub.closed = true;
// Don't return error, just close this sub
@@ -271,10 +245,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for MptpStream<S> {
if pn < this.expected_read_pn {
// Duplicate or old packet, ignore
// log::trace!("Duplicate PN: {} (expected {})", pn, this.expected_read_pn);
} else if pn == this.expected_read_pn {
// Expected packet
// log::trace!("Received expected PN: {}", pn);
this.read_buffer.extend_from_slice(&payload);
this.expected_read_pn += 1;
any_progress = true;
@@ -289,13 +261,10 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for MptpStream<S> {
}
} else {
// Future packet, buffer it
// log::trace!("Buffering future PN: {} (expected {})", pn, this.expected_read_pn);
if !this.reorder_buffer.contains_key(&pn) {
// Limit reorder buffer size to 1024 packets or ~4MB
if this.reorder_buffer.len() < 1024 {
this.reorder_buffer.insert(pn, payload);
} else {
warn!("Reorder buffer full, dropping PN {}", pn);
}
}
}
@@ -384,7 +353,24 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MptpStream<S> {
if all_full {
// Try to flush - maybe it helps?
let _ = Pin::new(&mut *this).poll_flush(cx);
return Poll::Pending;
// Re-check if still full after flush attempt
all_full = true;
for sub in &this.subs {
if sub.write_buf.len() <= 64 * 1024 {
all_full = false;
break;
}
}
if all_full {
// If flush returned Ready but buffer is still full (unlikely if logic is correct),
// or if flush returned Pending, we must return Pending.
// But if flush returned Ready, we MUST have drained the buffer.
// If flush returned Pending, waker is registered.
return Poll::Pending;
}
// If not full anymore, continue to write!
}
let pn = this.next_pn;
@@ -401,19 +387,11 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MptpStream<S> {
// Broadcast to all non-full subs
let mut sent_count = 0;
for (i, sub) in this.subs.iter_mut().enumerate() {
for sub in &mut this.subs {
if !sub.closed {
if sub.write_buf.len() <= 64 * 1024 {
sub.write_buf.extend_from_slice(&encoded_bytes);
sent_count += 1;
} else {
warn!(
"Sub {} (CID={}) is full ({} bytes), skipping PN {}",
i,
sub.cid,
sub.write_buf.len(),
pn
);
}
}
}
@@ -450,10 +428,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MptpStream<S> {
sub.write_buf.advance(n);
}
Poll::Ready(Err(e)) => {
error!(
"Sub {} (CID={}) write error: {}, marking closed",
i, sub.cid, e
);
error!("Sub {} write error: {}, marking closed", i, e);
sub.closed = true;
// Don't return error yet
break;
@@ -550,7 +525,7 @@ mod tests {
async fn test_single_stream_write_read() {
use super::super::protocol::MTYP_DATA;
let (client, mut server) = tokio::io::duplex(1024);
let mut mptp = MptpStream::new(vec![client], uuid::Uuid::new_v4());
let mut mptp = MptpStream::new(vec![client]);
// Write to mptp
mptp.write_all(b"hello").await.unwrap();
@@ -573,7 +548,7 @@ mod tests {
async fn test_deduplication() {
let (c1, mut s1) = tokio::io::duplex(1024);
let (c2, mut s2) = tokio::io::duplex(1024);
let mut mptp = MptpStream::new(vec![c1, c2], uuid::Uuid::new_v4());
let mut mptp = MptpStream::new(vec![c1, c2]);
// Construct a frame
let frame = Frame::Data {
@@ -628,7 +603,7 @@ mod tests {
async fn test_resilience_to_stuck_sub() {
let (c1, mut s1) = tokio::io::duplex(1024);
let (c2, _s2) = tokio::io::duplex(1024); // s2 is never read, so c2 will become full
let mut mptp = MptpStream::new(vec![c1, c2], uuid::Uuid::new_v4());
let mut mptp = MptpStream::new(vec![c1, c2]);
// Read from s1 in a background task to keep c1 empty
tokio::spawn(async move {

View File

@@ -87,6 +87,10 @@ impl Handler {
let mut buf = BytesMut::new();
req.encode(&mut buf);
// Add a small delay to ensure handshake is sent as a distinct packet if possible
// or to allow server to accept connection properly before data
// tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
if let Err(e) = stream.write_all(&buf).await {
tracing::warn!("Failed to send handshake for sub {}: {}", i, e);
return;
@@ -110,7 +114,6 @@ impl Handler {
// 3. Create MptpStream with the first stream and the receiver for subsequent ones
Ok(MptpStream::new_with_receiver_and_initial(
vec![first_stream],
cid,
rx,
))
} else {