Fix MPTP
This commit is contained in:
@@ -281,3 +281,34 @@ impl Frame {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_handshake_encoding() {
|
||||
let req = HandshakeRequest {
|
||||
ver: VER,
|
||||
cid: uuid::Uuid::new_v4(),
|
||||
cmd: CMD_CONNECT,
|
||||
dst_addr: Address::Ipv4("127.0.0.1".parse().unwrap()),
|
||||
dst_port: 8080,
|
||||
};
|
||||
|
||||
let mut buf = BytesMut::new();
|
||||
req.encode(&mut buf);
|
||||
|
||||
let mut decode_buf = buf.clone();
|
||||
let decoded = HandshakeRequest::decode(&mut decode_buf).unwrap().unwrap();
|
||||
|
||||
assert_eq!(req.ver, decoded.ver);
|
||||
assert_eq!(req.cid, decoded.cid);
|
||||
assert_eq!(req.cmd, decoded.cmd);
|
||||
assert_eq!(req.dst_port, decoded.dst_port);
|
||||
match (req.dst_addr, decoded.dst_addr) {
|
||||
(Address::Ipv4(a), Address::Ipv4(b)) => assert_eq!(a, b),
|
||||
_ => panic!("Address mismatch"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,8 +15,7 @@ struct SubConnection<S> {
|
||||
stream: S,
|
||||
read_buf: BytesMut,
|
||||
write_buf: BytesMut,
|
||||
closed: bool, // Mark if this sub-connection is dead
|
||||
cid: uuid::Uuid, // Add CID for debugging
|
||||
closed: bool, // Mark if this sub-connection is dead
|
||||
}
|
||||
|
||||
impl<S> SubConnection<S> {
|
||||
@@ -26,17 +25,6 @@ impl<S> SubConnection<S> {
|
||||
read_buf: BytesMut::with_capacity(4096),
|
||||
write_buf: BytesMut::with_capacity(4096),
|
||||
closed: false,
|
||||
cid: uuid::Uuid::nil(), // Default or placeholder
|
||||
}
|
||||
}
|
||||
|
||||
fn new_with_cid(stream: S, cid: uuid::Uuid) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
read_buf: BytesMut::with_capacity(4096),
|
||||
write_buf: BytesMut::with_capacity(4096),
|
||||
closed: false,
|
||||
cid,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -55,11 +43,8 @@ pub struct MptpStream<S> {
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin> MptpStream<S> {
|
||||
pub fn new(streams: Vec<S>, cid: uuid::Uuid) -> Self {
|
||||
let subs = streams
|
||||
.into_iter()
|
||||
.map(|s| SubConnection::new_with_cid(s, cid))
|
||||
.collect();
|
||||
pub fn new(streams: Vec<S>) -> Self {
|
||||
let subs = streams.into_iter().map(|s| SubConnection::new(s)).collect();
|
||||
Self {
|
||||
subs,
|
||||
new_subs_rx: None,
|
||||
@@ -85,13 +70,9 @@ impl<S: AsyncRead + AsyncWrite + Unpin> MptpStream<S> {
|
||||
|
||||
pub fn new_with_receiver_and_initial(
|
||||
streams: Vec<S>,
|
||||
cid: uuid::Uuid,
|
||||
rx: mpsc::UnboundedReceiver<(S, Option<uuid::Uuid>)>,
|
||||
) -> Self {
|
||||
let subs = streams
|
||||
.into_iter()
|
||||
.map(|s| SubConnection::new_with_cid(s, cid))
|
||||
.collect();
|
||||
let subs = streams.into_iter().map(|s| SubConnection::new(s)).collect();
|
||||
Self {
|
||||
subs,
|
||||
new_subs_rx: Some(rx),
|
||||
@@ -107,12 +88,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin> MptpStream<S> {
|
||||
if let Some(rx) = &mut self.new_subs_rx {
|
||||
loop {
|
||||
match rx.poll_recv(cx) {
|
||||
Poll::Ready(Some((stream, cid_opt))) => {
|
||||
if let Some(cid) = cid_opt {
|
||||
self.subs.push(SubConnection::new_with_cid(stream, cid));
|
||||
} else {
|
||||
self.subs.push(SubConnection::new(stream));
|
||||
}
|
||||
Poll::Ready(Some((stream, _))) => {
|
||||
self.subs.push(SubConnection::new(stream));
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
// Channel closed, no more new subs
|
||||
@@ -179,19 +156,16 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for MptpStream<S> {
|
||||
all_eof = false;
|
||||
} else {
|
||||
// EOF for this sub
|
||||
warn!("Sub {} (CID={}) EOF, marking closed", i, sub.cid);
|
||||
warn!("Sub {} EOF, marking closed", i);
|
||||
sub.closed = true;
|
||||
// Don't return EOF yet, others might be alive
|
||||
}
|
||||
}
|
||||
Poll::Ready(Err(e)) => {
|
||||
if e.kind() == io::ErrorKind::UnexpectedEof {
|
||||
warn!("Sub {} (CID={}) UnexpectedEof, marking closed", i, sub.cid);
|
||||
warn!("Sub {} UnexpectedEof, marking closed", i);
|
||||
} else {
|
||||
error!(
|
||||
"Sub {} (CID={}) read error: {}, marking closed",
|
||||
i, sub.cid, e
|
||||
);
|
||||
error!("Sub {} read error: {}, marking closed", i, e);
|
||||
}
|
||||
sub.closed = true;
|
||||
// Don't return error, just close this sub
|
||||
@@ -271,10 +245,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for MptpStream<S> {
|
||||
|
||||
if pn < this.expected_read_pn {
|
||||
// Duplicate or old packet, ignore
|
||||
// log::trace!("Duplicate PN: {} (expected {})", pn, this.expected_read_pn);
|
||||
} else if pn == this.expected_read_pn {
|
||||
// Expected packet
|
||||
// log::trace!("Received expected PN: {}", pn);
|
||||
this.read_buffer.extend_from_slice(&payload);
|
||||
this.expected_read_pn += 1;
|
||||
any_progress = true;
|
||||
@@ -289,13 +261,10 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for MptpStream<S> {
|
||||
}
|
||||
} else {
|
||||
// Future packet, buffer it
|
||||
// log::trace!("Buffering future PN: {} (expected {})", pn, this.expected_read_pn);
|
||||
if !this.reorder_buffer.contains_key(&pn) {
|
||||
// Limit reorder buffer size to 1024 packets or ~4MB
|
||||
if this.reorder_buffer.len() < 1024 {
|
||||
this.reorder_buffer.insert(pn, payload);
|
||||
} else {
|
||||
warn!("Reorder buffer full, dropping PN {}", pn);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -384,7 +353,24 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MptpStream<S> {
|
||||
if all_full {
|
||||
// Try to flush - maybe it helps?
|
||||
let _ = Pin::new(&mut *this).poll_flush(cx);
|
||||
return Poll::Pending;
|
||||
|
||||
// Re-check if still full after flush attempt
|
||||
all_full = true;
|
||||
for sub in &this.subs {
|
||||
if sub.write_buf.len() <= 64 * 1024 {
|
||||
all_full = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if all_full {
|
||||
// If flush returned Ready but buffer is still full (unlikely if logic is correct),
|
||||
// or if flush returned Pending, we must return Pending.
|
||||
// But if flush returned Ready, we MUST have drained the buffer.
|
||||
// If flush returned Pending, waker is registered.
|
||||
return Poll::Pending;
|
||||
}
|
||||
// If not full anymore, continue to write!
|
||||
}
|
||||
|
||||
let pn = this.next_pn;
|
||||
@@ -401,19 +387,11 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MptpStream<S> {
|
||||
|
||||
// Broadcast to all non-full subs
|
||||
let mut sent_count = 0;
|
||||
for (i, sub) in this.subs.iter_mut().enumerate() {
|
||||
for sub in &mut this.subs {
|
||||
if !sub.closed {
|
||||
if sub.write_buf.len() <= 64 * 1024 {
|
||||
sub.write_buf.extend_from_slice(&encoded_bytes);
|
||||
sent_count += 1;
|
||||
} else {
|
||||
warn!(
|
||||
"Sub {} (CID={}) is full ({} bytes), skipping PN {}",
|
||||
i,
|
||||
sub.cid,
|
||||
sub.write_buf.len(),
|
||||
pn
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -450,10 +428,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for MptpStream<S> {
|
||||
sub.write_buf.advance(n);
|
||||
}
|
||||
Poll::Ready(Err(e)) => {
|
||||
error!(
|
||||
"Sub {} (CID={}) write error: {}, marking closed",
|
||||
i, sub.cid, e
|
||||
);
|
||||
error!("Sub {} write error: {}, marking closed", i, e);
|
||||
sub.closed = true;
|
||||
// Don't return error yet
|
||||
break;
|
||||
@@ -550,7 +525,7 @@ mod tests {
|
||||
async fn test_single_stream_write_read() {
|
||||
use super::super::protocol::MTYP_DATA;
|
||||
let (client, mut server) = tokio::io::duplex(1024);
|
||||
let mut mptp = MptpStream::new(vec![client], uuid::Uuid::new_v4());
|
||||
let mut mptp = MptpStream::new(vec![client]);
|
||||
|
||||
// Write to mptp
|
||||
mptp.write_all(b"hello").await.unwrap();
|
||||
@@ -573,7 +548,7 @@ mod tests {
|
||||
async fn test_deduplication() {
|
||||
let (c1, mut s1) = tokio::io::duplex(1024);
|
||||
let (c2, mut s2) = tokio::io::duplex(1024);
|
||||
let mut mptp = MptpStream::new(vec![c1, c2], uuid::Uuid::new_v4());
|
||||
let mut mptp = MptpStream::new(vec![c1, c2]);
|
||||
|
||||
// Construct a frame
|
||||
let frame = Frame::Data {
|
||||
@@ -628,7 +603,7 @@ mod tests {
|
||||
async fn test_resilience_to_stuck_sub() {
|
||||
let (c1, mut s1) = tokio::io::duplex(1024);
|
||||
let (c2, _s2) = tokio::io::duplex(1024); // s2 is never read, so c2 will become full
|
||||
let mut mptp = MptpStream::new(vec![c1, c2], uuid::Uuid::new_v4());
|
||||
let mut mptp = MptpStream::new(vec![c1, c2]);
|
||||
|
||||
// Read from s1 in a background task to keep c1 empty
|
||||
tokio::spawn(async move {
|
||||
|
||||
@@ -87,6 +87,10 @@ impl Handler {
|
||||
let mut buf = BytesMut::new();
|
||||
req.encode(&mut buf);
|
||||
|
||||
// Add a small delay to ensure handshake is sent as a distinct packet if possible
|
||||
// or to allow server to accept connection properly before data
|
||||
// tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
|
||||
|
||||
if let Err(e) = stream.write_all(&buf).await {
|
||||
tracing::warn!("Failed to send handshake for sub {}: {}", i, e);
|
||||
return;
|
||||
@@ -110,7 +114,6 @@ impl Handler {
|
||||
// 3. Create MptpStream with the first stream and the receiver for subsequent ones
|
||||
Ok(MptpStream::new_with_receiver_and_initial(
|
||||
vec![first_stream],
|
||||
cid,
|
||||
rx,
|
||||
))
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user