Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7bd004f141 | ||
|
|
d44e0c68d4 | ||
|
|
bc23daa800 | ||
|
|
e6c219a61e | ||
|
|
ddc824fb9c | ||
|
|
e229d7041e | ||
|
|
e2503223dc | ||
|
|
92529635cb | ||
|
|
37e2523a36 | ||
|
|
2646115abb | ||
|
|
a68ba22714 | ||
|
|
a385766b3f |
2
go.mod
2
go.mod
@@ -9,7 +9,7 @@ require (
|
||||
github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff
|
||||
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a
|
||||
github.com/sagernet/nftables v0.3.0-beta.4
|
||||
github.com/sagernet/sing v0.7.0-beta.1
|
||||
github.com/sagernet/sing v0.7.6
|
||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
|
||||
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8
|
||||
golang.org/x/net v0.26.0
|
||||
|
||||
4
go.sum
4
go.sum
@@ -22,8 +22,8 @@ github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZN
|
||||
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM=
|
||||
github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I=
|
||||
github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8=
|
||||
github.com/sagernet/sing v0.7.0-beta.1 h1:2D44KzgeDZwD/R4Ts8jwSUHTRR238a1FpXDrl7l4tVw=
|
||||
github.com/sagernet/sing v0.7.0-beta.1/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
|
||||
github.com/sagernet/sing v0.7.6 h1:6LBfDH+aI/26J3r9UHlaxTNjJeMhBpU/wrk0JKDZYI4=
|
||||
github.com/sagernet/sing v0.7.6/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
||||
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
|
||||
@@ -458,7 +458,10 @@ func (b IPv4) SetDestinationAddress(addr tcpip.Address) {
|
||||
|
||||
// CalculateChecksum calculates the checksum of the IPv4 header.
|
||||
func (b IPv4) CalculateChecksum() uint16 {
|
||||
return checksum.Checksum(b[:b.HeaderLength()], 0)
|
||||
// return checksum.Checksum(b[:b.HeaderLength()], 0)
|
||||
xsum0 := checksum.Checksum(b[:xsum], 0)
|
||||
xsum0 = checksum.Checksum(b[xsum+2:b.HeaderLength()], xsum0)
|
||||
return xsum0
|
||||
}
|
||||
|
||||
// Encode encodes all the fields of the IPv4 header.
|
||||
@@ -550,7 +553,8 @@ func (b IPv4) IsChecksumValid() bool {
|
||||
// same set of octets, including the checksum field. If the result
|
||||
// is all 1 bits (-0 in 1's complement arithmetic), the check
|
||||
// succeeds.
|
||||
return b.CalculateChecksum() == 0xffff
|
||||
//return b.CalculateChecksum() == 0xffff
|
||||
return checksum.Checksum(b[:b.HeaderLength()], 0) == 0xffff
|
||||
}
|
||||
|
||||
// IsV4MulticastAddress determines if the provided address is an IPv4 multicast
|
||||
|
||||
@@ -351,14 +351,18 @@ func (b TCP) SetUrgentPointer(urgentPointer uint16) {
|
||||
// and the checksum of the segment data.
|
||||
func (b TCP) CalculateChecksum(partialChecksum uint16) uint16 {
|
||||
// Calculate the rest of the checksum.
|
||||
return checksum.Checksum(b[:b.DataOffset()], partialChecksum)
|
||||
// return checksum.Checksum(b[:b.DataOffset()], partialChecksum)
|
||||
xsum := checksum.Checksum(b[:TCPChecksumOffset], partialChecksum)
|
||||
xsum = checksum.Checksum(b[TCPChecksumOffset+2:b.DataOffset()], xsum)
|
||||
return xsum
|
||||
}
|
||||
|
||||
// IsChecksumValid returns true iff the TCP header's checksum is valid.
|
||||
func (b TCP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum, payloadLength uint16) bool {
|
||||
xsum := PseudoHeaderChecksum(TCPProtocolNumber, src.AsSlice(), dst.AsSlice(), uint16(b.DataOffset())+payloadLength)
|
||||
xsum = checksum.Combine(xsum, payloadChecksum)
|
||||
return b.CalculateChecksum(xsum) == 0xffff
|
||||
// return b.CalculateChecksum(xsum) == 0xffff
|
||||
return checksum.Checksum(b[:b.DataOffset()], xsum) == 0xffff
|
||||
}
|
||||
|
||||
// Options returns a slice that holds the unparsed TCP options in the segment.
|
||||
|
||||
@@ -113,15 +113,18 @@ func (b UDP) SetLength(length uint16) {
|
||||
// CalculateChecksum calculates the checksum of the UDP packet, given the
|
||||
// checksum of the network-layer pseudo-header and the checksum of the payload.
|
||||
func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 {
|
||||
// Calculate the rest of the checksum.
|
||||
return checksum.Checksum(b[:UDPMinimumSize], partialChecksum)
|
||||
// Calculate the rest of the checksum.\
|
||||
// return checksum.Checksum(b[:UDPMinimumSize], partialChecksum)
|
||||
xsum := checksum.Checksum(b[:udpChecksum], partialChecksum)
|
||||
xsum = checksum.Checksum(b[udpChecksum+2:UDPMinimumSize], xsum)
|
||||
return xsum
|
||||
}
|
||||
|
||||
// IsChecksumValid returns true iff the UDP header's checksum is valid.
|
||||
func (b UDP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum uint16) bool {
|
||||
xsum := PseudoHeaderChecksum(UDPProtocolNumber, dst.AsSlice(), src.AsSlice(), b.Length())
|
||||
xsum = checksum.Combine(xsum, payloadChecksum)
|
||||
return b.CalculateChecksum(xsum) == 0xffff
|
||||
return checksum.Checksum(b[:UDPMinimumSize], xsum) == 0xffff
|
||||
}
|
||||
|
||||
// Encode encodes all the fields of the UDP header.
|
||||
|
||||
@@ -39,6 +39,10 @@ func closeAdapter(wintun *Adapter) {
|
||||
// deterministically. If it is set to nil, the GUID is chosen by the system at random,
|
||||
// and hence a new NLA entry is created for each new adapter.
|
||||
func CreateAdapter(name string, tunnelType string, requestedGUID *windows.GUID) (wintun *Adapter, err error) {
|
||||
err = procWintunCloseAdapter.Find()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var name16 *uint16
|
||||
name16, err = windows.UTF16PtrFromString(name)
|
||||
if err != nil {
|
||||
|
||||
@@ -5,9 +5,9 @@ package tun
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
"github.com/sagernet/sing/common/control"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
|
||||
@@ -69,6 +69,7 @@ func (r *autoRedirect) Start() error {
|
||||
r.androidSu = true
|
||||
for _, suPath := range []string{
|
||||
"su",
|
||||
"/product/bin/su",
|
||||
"/system/bin/su",
|
||||
} {
|
||||
r.suPath, err = exec.LookPath(suPath)
|
||||
|
||||
@@ -143,12 +143,26 @@ func (r *autoRedirect) setupNFTables() error {
|
||||
}
|
||||
}
|
||||
chainPreRoutingUDP := nft.AddChain(&nftables.Chain{
|
||||
Name: "prerouting_udp",
|
||||
Name: "prerouting_udp_icmp",
|
||||
Table: table,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityRef(*nftables.ChainPriorityNATDest + 2),
|
||||
Type: nftables.ChainTypeFilter,
|
||||
})
|
||||
ipProto := &nftables.Set{
|
||||
Table: table,
|
||||
Anonymous: true,
|
||||
Constant: true,
|
||||
KeyType: nftables.TypeInetProto,
|
||||
}
|
||||
err = nft.AddSet(ipProto, []nftables.SetElement{
|
||||
{Key: []byte{unix.IPPROTO_UDP}},
|
||||
{Key: []byte{unix.IPPROTO_ICMP}},
|
||||
{Key: []byte{unix.IPPROTO_ICMPV6}},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chainPreRoutingUDP,
|
||||
@@ -157,10 +171,11 @@ func (r *autoRedirect) setupNFTables() error {
|
||||
Key: expr.MetaKeyL4PROTO,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{unix.IPPROTO_UDP},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetID: ipProto.ID,
|
||||
SetName: ipProto.Name,
|
||||
Invert: true,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictReturn,
|
||||
|
||||
@@ -7,9 +7,9 @@ import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
"github.com/sagernet/sing/common/control"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
|
||||
@@ -52,7 +52,7 @@ func (f *TCPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pac
|
||||
ipHdr.SetSourceAddressWithChecksumUpdate(inet4LoopbackAddress)
|
||||
tcpHdr := header.TCP(pkt.TransportHeader().Slice())
|
||||
tcpHdr.SetChecksum(0)
|
||||
tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum(
|
||||
tcpHdr.SetChecksum(^checksum.Combine(pkt.Data().Checksum(), tcpHdr.CalculateChecksum(
|
||||
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddress(), ipHdr.DestinationAddress(), ipHdr.PayloadLength()),
|
||||
)))
|
||||
f.tun.WritePacket(pkt)
|
||||
@@ -66,7 +66,7 @@ func (f *TCPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pac
|
||||
ipHdr.SetSourceAddress(inet6LoopbackAddress)
|
||||
tcpHdr := header.TCP(pkt.TransportHeader().Slice())
|
||||
tcpHdr.SetChecksum(0)
|
||||
tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum(
|
||||
tcpHdr.SetChecksum(^checksum.Combine(pkt.Data().Checksum(), tcpHdr.CalculateChecksum(
|
||||
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddress(), ipHdr.DestinationAddress(), ipHdr.PayloadLength()),
|
||||
)))
|
||||
f.tun.WritePacket(pkt)
|
||||
|
||||
@@ -422,14 +422,12 @@ func (s *System) processIPv4TCP(ipHdr header.IPv4, tcpHdr header.TCP) (bool, err
|
||||
}
|
||||
}
|
||||
if !s.txChecksumOffload {
|
||||
tcpHdr.SetChecksum(0)
|
||||
tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum(
|
||||
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
|
||||
)))
|
||||
} else {
|
||||
tcpHdr.SetChecksum(0)
|
||||
}
|
||||
ipHdr.SetChecksum(0)
|
||||
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||
return true, nil
|
||||
}
|
||||
@@ -470,7 +468,6 @@ func (s *System) resetIPv4TCP(origIPHdr header.IPv4, origTCPHdr header.TCP) erro
|
||||
if !s.txChecksumOffload {
|
||||
tcpHdr.SetChecksum(^tcpHdr.CalculateChecksum(header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), header.TCPMinimumSize)))
|
||||
}
|
||||
ipHdr.SetChecksum(0)
|
||||
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||
if PacketOffset > 0 {
|
||||
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version)
|
||||
@@ -520,7 +517,6 @@ func (s *System) processIPv6TCP(ipHdr header.IPv6, tcpHdr header.TCP) (bool, err
|
||||
}
|
||||
}
|
||||
if !s.txChecksumOffload {
|
||||
tcpHdr.SetChecksum(0)
|
||||
tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum(
|
||||
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
|
||||
)))
|
||||
@@ -651,8 +647,7 @@ func (s *System) processIPv4ICMP(ipHdr header.IPv4, icmpHdr header.ICMPv4) error
|
||||
sourceAddress := ipHdr.SourceAddr()
|
||||
ipHdr.SetSourceAddr(ipHdr.DestinationAddr())
|
||||
ipHdr.SetDestinationAddr(sourceAddress)
|
||||
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
|
||||
ipHdr.SetChecksum(0)
|
||||
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
|
||||
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||
return nil
|
||||
}
|
||||
@@ -686,7 +681,7 @@ func (s *System) rejectIPv4WithICMP(ipHdr header.IPv4, code header.ICMPv4Code) e
|
||||
icmpHdr := header.ICMPv4(newIPHdr.Payload())
|
||||
icmpHdr.SetType(header.ICMPv4DstUnreachable)
|
||||
icmpHdr.SetCode(code)
|
||||
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(payload, 0)))
|
||||
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
|
||||
copy(icmpHdr.Payload(), payload)
|
||||
if PacketOffset > 0 {
|
||||
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET
|
||||
@@ -779,14 +774,12 @@ func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.S
|
||||
udpHdr.SetSourcePort(destination.Port)
|
||||
udpHdr.SetLength(uint16(buffer.Len() + header.UDPMinimumSize))
|
||||
if !w.txChecksumOffload {
|
||||
udpHdr.SetChecksum(0)
|
||||
udpHdr.SetChecksum(^checksum.Checksum(udpHdr.Payload(), udpHdr.CalculateChecksum(
|
||||
header.PseudoHeaderChecksum(header.UDPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
|
||||
)))
|
||||
} else {
|
||||
udpHdr.SetChecksum(0)
|
||||
}
|
||||
ipHdr.SetChecksum(0)
|
||||
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||
if PacketOffset > 0 {
|
||||
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version)
|
||||
@@ -820,7 +813,6 @@ func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.S
|
||||
udpHdr.SetSourcePort(destination.Port)
|
||||
udpHdr.SetLength(udpLen)
|
||||
if !w.txChecksumOffload {
|
||||
udpHdr.SetChecksum(0)
|
||||
udpHdr.SetChecksum(^checksum.Checksum(udpHdr.Payload(), udpHdr.CalculateChecksum(
|
||||
header.PseudoHeaderChecksum(header.UDPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
|
||||
)))
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
)
|
||||
|
||||
type TCPNat struct {
|
||||
timeout time.Duration
|
||||
portIndex uint16
|
||||
portAccess sync.RWMutex
|
||||
addrAccess sync.RWMutex
|
||||
@@ -19,6 +20,7 @@ type TCPNat struct {
|
||||
}
|
||||
|
||||
type TCPSession struct {
|
||||
sync.Mutex
|
||||
Source netip.AddrPort
|
||||
Destination netip.AddrPort
|
||||
LastActive time.Time
|
||||
@@ -26,38 +28,41 @@ type TCPSession struct {
|
||||
|
||||
func NewNat(ctx context.Context, timeout time.Duration) *TCPNat {
|
||||
natMap := &TCPNat{
|
||||
timeout: timeout,
|
||||
portIndex: 10000,
|
||||
addrMap: make(map[netip.AddrPort]uint16),
|
||||
portMap: make(map[uint16]*TCPSession),
|
||||
}
|
||||
go natMap.loopCheckTimeout(ctx, timeout)
|
||||
go natMap.loopCheckTimeout(ctx)
|
||||
return natMap
|
||||
}
|
||||
|
||||
func (n *TCPNat) loopCheckTimeout(ctx context.Context, timeout time.Duration) {
|
||||
ticker := time.NewTicker(timeout)
|
||||
func (n *TCPNat) loopCheckTimeout(ctx context.Context) {
|
||||
ticker := time.NewTicker(n.timeout)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
n.checkTimeout(timeout)
|
||||
n.checkTimeout()
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (n *TCPNat) checkTimeout(timeout time.Duration) {
|
||||
func (n *TCPNat) checkTimeout() {
|
||||
now := time.Now()
|
||||
n.portAccess.Lock()
|
||||
defer n.portAccess.Unlock()
|
||||
n.addrAccess.Lock()
|
||||
defer n.addrAccess.Unlock()
|
||||
for natPort, session := range n.portMap {
|
||||
if now.Sub(session.LastActive) > timeout {
|
||||
session.Lock()
|
||||
if now.Sub(session.LastActive) > n.timeout {
|
||||
delete(n.addrMap, session.Source)
|
||||
delete(n.portMap, natPort)
|
||||
}
|
||||
session.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,7 +71,11 @@ func (n *TCPNat) LookupBack(port uint16) *TCPSession {
|
||||
session := n.portMap[port]
|
||||
n.portAccess.RUnlock()
|
||||
if session != nil {
|
||||
session.LastActive = time.Now()
|
||||
session.Lock()
|
||||
if time.Since(session.LastActive) > time.Second {
|
||||
session.LastActive = time.Now()
|
||||
}
|
||||
session.Unlock()
|
||||
}
|
||||
return session
|
||||
}
|
||||
|
||||
@@ -152,7 +152,10 @@ func (t *NativeTun) Start() error {
|
||||
|
||||
func (t *NativeTun) Close() error {
|
||||
defer flushDNSCache()
|
||||
return E.Errors(t.unsetRoutes(), t.tunFile.Close())
|
||||
t.stopFd.Stop()
|
||||
err := E.Errors(t.unsetRoutes(), t.tunFile.Close())
|
||||
t.stopFd.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *NativeTun) Read(p []byte) (n int, err error) {
|
||||
@@ -347,6 +350,9 @@ func (t *NativeTun) BatchRead() ([]*buf.Buffer, error) {
|
||||
t.buffers = t.buffers[:0]
|
||||
return nil, errno
|
||||
}
|
||||
if n < 0 {
|
||||
return nil, os.ErrClosed
|
||||
}
|
||||
if n < 1 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
41
tun_linux.go
41
tun_linux.go
@@ -130,7 +130,7 @@ func (t *NativeTun) configure(tunLink netlink.Link) error {
|
||||
for _, address := range t.options.Inet4Address {
|
||||
addr4, _ := netlink.ParseAddr(address.String())
|
||||
err = netlink.AddrAdd(tunLink, addr4)
|
||||
if err != nil {
|
||||
if err != nil && !errors.Is(err, unix.EEXIST) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -139,7 +139,7 @@ func (t *NativeTun) configure(tunLink netlink.Link) error {
|
||||
for _, address := range t.options.Inet6Address {
|
||||
addr6, _ := netlink.ParseAddr(address.String())
|
||||
err = netlink.AddrAdd(tunLink, addr6)
|
||||
if err != nil {
|
||||
if err != nil && !errors.Is(err, unix.EEXIST) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -315,6 +315,7 @@ func (t *NativeTun) Close() error {
|
||||
if t.interfaceCallback != nil {
|
||||
t.options.InterfaceMonitor.UnregisterCallback(t.interfaceCallback)
|
||||
}
|
||||
t.unsetAddresses()
|
||||
return E.Errors(t.unsetRoute(), t.unsetRules(), common.Close(common.PtrOrNil(t.tunFile)))
|
||||
}
|
||||
|
||||
@@ -816,14 +817,6 @@ func (t *NativeTun) rules() []*netlink.Rule {
|
||||
it.Family = unix.AF_INET
|
||||
rules = append(rules, it)
|
||||
}
|
||||
if p4 && !t.options.StrictRoute {
|
||||
it = netlink.NewRule()
|
||||
it.Priority = priority
|
||||
it.IPProto = syscall.IPPROTO_ICMP
|
||||
it.Goto = nopPriority
|
||||
it.Family = unix.AF_INET
|
||||
rules = append(rules, it)
|
||||
}
|
||||
if p6 {
|
||||
it = netlink.NewRule()
|
||||
it.Priority = priority6
|
||||
@@ -834,16 +827,6 @@ func (t *NativeTun) rules() []*netlink.Rule {
|
||||
it.Family = unix.AF_INET6
|
||||
rules = append(rules, it)
|
||||
}
|
||||
|
||||
if p6 && !t.options.StrictRoute {
|
||||
it = netlink.NewRule()
|
||||
it.Priority = priority6
|
||||
it.IPProto = syscall.IPPROTO_ICMPV6
|
||||
it.Goto = nopPriority
|
||||
it.Family = unix.AF_INET6
|
||||
rules = append(rules, it)
|
||||
priority6++
|
||||
}
|
||||
}
|
||||
if p4 {
|
||||
it = netlink.NewRule()
|
||||
@@ -1021,6 +1004,24 @@ func (t *NativeTun) unsetRules() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *NativeTun) unsetAddresses() {
|
||||
if t.options.FileDescriptor > 0 {
|
||||
return
|
||||
}
|
||||
tunLink, err := netlink.LinkByName(t.options.Name)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, address := range t.options.Inet4Address {
|
||||
addr, _ := netlink.ParseAddr(address.String())
|
||||
_ = netlink.AddrDel(tunLink, addr)
|
||||
}
|
||||
for _, address := range t.options.Inet6Address {
|
||||
addr, _ := netlink.ParseAddr(address.String())
|
||||
_ = netlink.AddrDel(tunLink, addr)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *NativeTun) resetRules() error {
|
||||
t.unsetRules()
|
||||
return t.setRules()
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
@@ -16,7 +17,6 @@ import (
|
||||
"github.com/sagernet/sing-tun/internal/winsys"
|
||||
"github.com/sagernet/sing-tun/internal/wintun"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/windnsapi"
|
||||
|
||||
@@ -181,6 +181,13 @@ func (t *NativeTun) Start() error {
|
||||
return err
|
||||
}
|
||||
if t.options.StrictRoute {
|
||||
major, _, _ := windows.RtlGetNtVersionNumbers()
|
||||
if major < 10 {
|
||||
if t.options.Logger != nil {
|
||||
t.options.Logger.Warn("strict routing is not supported on Windows versions below 10")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
var engine uintptr
|
||||
session := &winsys.FWPM_SESSION0{Flags: winsys.FWPM_SESSION_FLAG_DYNAMIC}
|
||||
err := winsys.FwpmEngineOpen0(nil, winsys.RPC_C_AUTHN_DEFAULT, nil, session, unsafe.Pointer(&engine))
|
||||
@@ -395,15 +402,16 @@ retry:
|
||||
|
||||
func (t *NativeTun) ReadPacket() ([]byte, func(), error) {
|
||||
t.running.Add(1)
|
||||
defer t.running.Done()
|
||||
retry:
|
||||
if t.close.Load() == 1 {
|
||||
t.running.Done()
|
||||
return nil, nil, os.ErrClosed
|
||||
}
|
||||
start := nanotime()
|
||||
shouldSpin := t.rate.current.Load() >= spinloopRateThreshold && uint64(start-t.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
|
||||
for {
|
||||
if t.close.Load() == 1 {
|
||||
t.running.Done()
|
||||
return nil, nil, os.ErrClosed
|
||||
}
|
||||
packet, err := t.session.ReceivePacket()
|
||||
@@ -411,7 +419,10 @@ retry:
|
||||
case nil:
|
||||
packetSize := len(packet)
|
||||
t.rate.update(uint64(packetSize))
|
||||
return packet, func() { t.session.ReleaseReceivePacket(packet) }, nil
|
||||
return packet, func() {
|
||||
t.session.ReleaseReceivePacket(packet)
|
||||
t.running.Done()
|
||||
}, nil
|
||||
case windows.ERROR_NO_MORE_ITEMS:
|
||||
if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
|
||||
windows.WaitForSingleObject(t.readWait, windows.INFINITE)
|
||||
@@ -420,10 +431,13 @@ retry:
|
||||
procyield(1)
|
||||
continue
|
||||
case windows.ERROR_HANDLE_EOF:
|
||||
t.running.Done()
|
||||
return nil, nil, os.ErrClosed
|
||||
case windows.ERROR_INVALID_DATA:
|
||||
t.running.Done()
|
||||
return nil, nil, errors.New("send ring corrupt")
|
||||
}
|
||||
t.running.Done()
|
||||
return nil, nil, fmt.Errorf("read failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user