ping: Add needFilter

This commit is contained in:
wwqgtxx
2025-08-26 13:20:50 +08:00
parent 6e4e045620
commit 4fb5702443
2 changed files with 53 additions and 41 deletions

View File

@@ -101,20 +101,22 @@ func (d *Destination) loopRead() {
continue
}
icmpHdr := header.ICMPv4(ipHdr.Payload())
if icmpHdr.Type() != header.ICMPv4EchoReply {
continue
}
var requestExists bool
request := pingRequest{Source: ipHdr.DestinationAddr(), Destination: ipHdr.SourceAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()}
d.requestAccess.Lock()
_, loaded := d.requests[request]
if loaded {
requestExists = true
delete(d.requests, request)
}
d.requestAccess.Unlock()
if !requestExists {
continue
if d.needFilter() {
if icmpHdr.Type() != header.ICMPv4EchoReply {
continue
}
var requestExists bool
request := pingRequest{Source: ipHdr.DestinationAddr(), Destination: ipHdr.SourceAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()}
d.requestAccess.Lock()
_, loaded := d.requests[request]
if loaded {
requestExists = true
delete(d.requests, request)
}
d.requestAccess.Unlock()
if !requestExists {
continue
}
}
d.logger.TraceContext(d.ctx, "read ICMPv4 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
} else {
@@ -128,20 +130,22 @@ func (d *Destination) loopRead() {
continue
}
icmpHdr := header.ICMPv6(ipHdr.Payload())
if icmpHdr.Type() != header.ICMPv6EchoReply {
continue
}
var requestExists bool
request := pingRequest{Source: ipHdr.DestinationAddr(), Destination: ipHdr.SourceAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()}
d.requestAccess.Lock()
_, loaded := d.requests[request]
if loaded {
requestExists = true
delete(d.requests, request)
}
d.requestAccess.Unlock()
if !requestExists {
continue
if d.needFilter() {
if icmpHdr.Type() != header.ICMPv6EchoReply {
continue
}
var requestExists bool
request := pingRequest{Source: ipHdr.DestinationAddr(), Destination: ipHdr.SourceAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()}
d.requestAccess.Lock()
_, loaded := d.requests[request]
if loaded {
requestExists = true
delete(d.requests, request)
}
d.requestAccess.Unlock()
if !requestExists {
continue
}
}
d.logger.TraceContext(d.ctx, "read ICMPv6 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
}
@@ -163,7 +167,9 @@ func (d *Destination) WritePacket(packet *buf.Buffer) error {
return E.New("invalid ICMPv4 header")
}
icmpHdr := header.ICMPv4(ipHdr.Payload())
d.registerRequest(pingRequest{Source: ipHdr.SourceAddr(), Destination: ipHdr.DestinationAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()})
if d.needFilter() {
d.registerRequest(pingRequest{Source: ipHdr.SourceAddr(), Destination: ipHdr.DestinationAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()})
}
d.logger.TraceContext(d.ctx, "write ICMPv4 echo request from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
} else {
ipHdr := header.IPv6(packet.Bytes())
@@ -174,12 +180,18 @@ func (d *Destination) WritePacket(packet *buf.Buffer) error {
return E.New("invalid ICMPv6 header")
}
icmpHdr := header.ICMPv6(ipHdr.Payload())
d.registerRequest(pingRequest{Source: ipHdr.SourceAddr(), Destination: ipHdr.DestinationAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()})
if d.needFilter() {
d.registerRequest(pingRequest{Source: ipHdr.SourceAddr(), Destination: ipHdr.DestinationAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()})
}
d.logger.TraceContext(d.ctx, "write ICMPv6 echo request from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
}
return d.conn.WriteIP(packet)
}
func (d *Destination) needFilter() bool {
return runtime.GOOS != "windows" && !d.conn.isLinuxUnprivileged()
}
func (d *Destination) registerRequest(request pingRequest) {
const requestsLimit = 1024
d.requestAccess.Lock()

View File

@@ -44,7 +44,7 @@ func Connect(ctx context.Context, privileged bool, controlFunc control.Func, des
}
func (c *Conn) connect(controlFunc control.Func) (err error) {
if c.IsLinuxUnprivileged() {
if c.isLinuxUnprivileged() {
c.conn, err = newUnprivilegedConn(c.ctx, controlFunc, c.destination)
} else {
c.conn, err = connect(c.privileged, controlFunc, c.destination)
@@ -52,12 +52,12 @@ func (c *Conn) connect(controlFunc control.Func) (err error) {
return
}
func (c *Conn) IsLinuxUnprivileged() bool {
func (c *Conn) isLinuxUnprivileged() bool {
return (runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged
}
func (c *Conn) ReadIP(buffer *buf.Buffer) error {
if c.destination.Is6() || c.IsLinuxUnprivileged() {
if c.destination.Is6() || c.isLinuxUnprivileged() {
var readMsg func(b, oob []byte) (n, oobn int, addr netip.Addr, err error)
switch conn := c.conn.(type) {
case *net.IPConn:
@@ -104,7 +104,7 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
}
ttl = controlMessage.TTL
}
if !c.IsLinuxUnprivileged() {
if !c.isLinuxUnprivileged() {
icmpHdr := header.ICMPv4(buffer.Bytes())
icmpHdr.SetIdent(^icmpHdr.Ident())
icmpHdr.SetChecksum(0)
@@ -142,7 +142,7 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
trafficClass = controlMessage.TrafficClass
}
icmpHdr := header.ICMPv6(buffer.Bytes())
if !c.IsLinuxUnprivileged() {
if !c.isLinuxUnprivileged() {
icmpHdr.SetIdent(^icmpHdr.Ident())
}
icmpHdr.SetChecksum(0)
@@ -182,7 +182,7 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
icmpHdr := header.ICMPv4(ipHdr.Payload())
if !c.IsLinuxUnprivileged() {
if !c.isLinuxUnprivileged() {
icmpHdr.SetIdent(^icmpHdr.Ident())
}
icmpHdr.SetChecksum(0)
@@ -194,7 +194,7 @@ func (c *Conn) ReadIP(buffer *buf.Buffer) error {
}
ipHdr.SetDestinationAddr(c.source.Load())
icmpHdr := header.ICMPv6(ipHdr.Payload())
if !c.IsLinuxUnprivileged() {
if !c.isLinuxUnprivileged() {
icmpHdr.SetIdent(^icmpHdr.Ident())
}
icmpHdr.SetChecksum(0)
@@ -213,7 +213,7 @@ func (c *Conn) ReadICMP(buffer *buf.Buffer) error {
if err != nil {
return err
}
if !c.IsLinuxUnprivileged() {
if !c.isLinuxUnprivileged() {
if !c.destination.Is6() {
ipHdr := header.IPv4(buffer.Bytes())
buffer.Advance(int(ipHdr.HeaderLength()))
@@ -240,7 +240,7 @@ func (c *Conn) WriteIP(buffer *buf.Buffer) error {
defer buffer.Release()
if !c.destination.Is6() {
ipHdr := header.IPv4(buffer.Bytes())
if !c.IsLinuxUnprivileged() {
if !c.isLinuxUnprivileged() {
icmpHdr := header.ICMPv4(ipHdr.Payload())
icmpHdr.SetIdent(^icmpHdr.Ident())
icmpHdr.SetChecksum(0)
@@ -250,7 +250,7 @@ func (c *Conn) WriteIP(buffer *buf.Buffer) error {
return common.Error(c.conn.Write(ipHdr.Payload()))
} else {
ipHdr := header.IPv6(buffer.Bytes())
if !c.IsLinuxUnprivileged() {
if !c.isLinuxUnprivileged() {
icmpHdr := header.ICMPv6(ipHdr.Payload())
icmpHdr.SetIdent(^icmpHdr.Ident())
icmpHdr.SetChecksum(0)
@@ -267,7 +267,7 @@ func (c *Conn) WriteIP(buffer *buf.Buffer) error {
func (c *Conn) WriteICMP(buffer *buf.Buffer) error {
defer buffer.Release()
if !c.IsLinuxUnprivileged() {
if !c.isLinuxUnprivileged() {
if !c.destination.Is6() {
icmpHdr := header.ICMPv4(buffer.Bytes())
icmpHdr.SetIdent(^icmpHdr.Ident())