Compare commits

...

29 Commits

Author SHA1 Message Date
世界
7bd004f141 Skip strict routing on Windows versions below 10 2026-01-17 18:38:29 +08:00
世界
d44e0c68d4 Fix darwin batch read not exit on stop 2026-01-15 02:03:38 +08:00
世界
bc23daa800 Fix TUN interface restart fails with existing addresses 2026-01-04 23:24:54 +08:00
世界
e6c219a61e Fix race condition in ReadPacket 2025-12-17 19:45:13 +08:00
世界
ddc824fb9c redirect: Fix compatibility with /product/bin/su 2025-10-21 21:26:31 +08:00
世界
e229d7041e Fix race codes 2025-09-12 18:02:37 +08:00
世界
e2503223dc Prevent panic when wintun dll fails to load 2025-09-09 16:35:02 +08:00
世界
92529635cb Fix gvisor loopback address 2025-09-02 17:54:00 +08:00
世界
37e2523a36 Fix checksum changes 2025-09-02 17:54:00 +08:00
世界
2646115abb Improve checksum usages 2025-09-02 17:54:00 +08:00
世界
a68ba22714 ping: Fix linux route rules 2025-09-02 17:54:00 +08:00
世界
a385766b3f Replace usages of common/atomic 2025-09-02 17:52:13 +08:00
世界
4a56d47035 test: Add log for fw4 reload error 2025-08-03 00:12:24 +08:00
世界
07e21b9170 Fix redirect panic 2025-08-01 16:46:29 +08:00
世界
ebbe32588c Fix system stack 2025-07-21 09:44:17 +08:00
世界
0310956cc0 Fix darwin writev 2025-07-20 18:28:34 +08:00
世界
3af7305b85 Fix darwin WritePacket 2025-07-18 11:00:19 +08:00
世界
aa1fd4d994 Improve darwin tun performance 2025-07-17 23:46:08 +08:00
世界
7812930a48 Minor fixes 2025-07-13 15:58:13 +08:00
世界
4c81c8a62a Fix usages of readmsg_x 2025-07-09 00:14:47 +08:00
世界
a0881ada32 Improve darwin tun performance 2025-07-03 20:17:32 +08:00
世界
8763c24e49 Improve nftables rules for openwrt 2025-06-30 18:00:36 +08:00
世界
5e343c4b66 Add loopback address support 2025-06-20 13:14:58 +08:00
世界
f57754918d Add DefaultInterfaceMonitor.MyInterface 2025-06-20 13:14:37 +08:00
世界
2121bc3f01 Fix error usages 2025-06-20 12:47:57 +08:00
世界
bea26198e7 Fix "Fix gLazyConn race" 2025-06-16 14:01:32 +08:00
世界
3df19f464e Fix gLazyConn race 2025-06-13 18:18:53 +08:00
世界
494b0ef858 redirect: Fix unreachable 2025-06-13 18:18:53 +08:00
世界
f13cd94aa0 redirect: Fix counter position 2025-04-28 11:06:02 +08:00
37 changed files with 2767 additions and 612 deletions

2
go.mod
View File

@@ -9,7 +9,7 @@ require (
github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a
github.com/sagernet/nftables v0.3.0-beta.4
github.com/sagernet/sing v0.6.0-beta.2
github.com/sagernet/sing v0.7.6
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8
golang.org/x/net v0.26.0

4
go.sum
View File

@@ -22,8 +22,8 @@ github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZN
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM=
github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I=
github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8=
github.com/sagernet/sing v0.6.0-beta.2 h1:Dcutp3kxrsZes9q3oTiHQhYYjQvDn5rwp1OI9fDLYwQ=
github.com/sagernet/sing v0.6.0-beta.2/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/sagernet/sing v0.7.6 h1:6LBfDH+aI/26J3r9UHlaxTNjJeMhBpU/wrk0JKDZYI4=
github.com/sagernet/sing v0.7.6/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=

View File

@@ -0,0 +1,668 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package fdbased provides the implementation of data-link layer endpoints
// backed by boundary-preserving file descriptors (e.g., TUN devices,
// seqpacket/datagram sockets).
//
// FD based endpoints can be used in the networking stack by calling New() to
// create a new endpoint, and then passing it as an argument to
// Stack.CreateNIC().
//
// FD based endpoints can use more than one file descriptor to read incoming
// packets. If there are more than one FDs specified and the underlying FD is an
// AF_PACKET then the endpoint will enable FANOUT mode on the socket so that the
// host kernel will consistently hash the packets to the sockets. This ensures
// that packets for the same TCP streams are not reordered.
//
// Similarly if more than one FD's are specified where the underlying FD is not
// AF_PACKET then it's the caller's responsibility to ensure that all inbound
// packets on the descriptors are consistently 5 tuple hashed to one of the
// descriptors to prevent TCP reordering.
//
// Since netstack today does not compute 5 tuple hashes for outgoing packets we
// only use the first FD to write outbound packets. Once 5 tuple hashes for
// all outbound packets are available we will make use of all underlying FD's to
// write outbound packets.
package fdbased
import (
"fmt"
"runtime"
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/sync"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/sing-tun/internal/rawfile_darwin"
"github.com/sagernet/sing/common"
"golang.org/x/sys/unix"
)
// linkDispatcher reads packets from the link FD and dispatches them to the
// NetworkDispatcher.
type linkDispatcher interface {
Stop()
dispatch() (bool, tcpip.Error)
release()
}
// PacketDispatchMode are the various supported methods of receiving and
// dispatching packets from the underlying FD.
type PacketDispatchMode int
// BatchSize is the number of packets to write in each syscall. It is 47
// because when GVisorGSO is in use then a single 65KB TCP segment can get
// split into 46 segments of 1420 bytes and a single 216 byte segment.
const BatchSize = 47
const (
// Readv is the default dispatch mode and is the least performant of the
// dispatch options but the one that is supported by all underlying FD
// types.
Readv PacketDispatchMode = iota
// RecvMMsg enables use of recvmmsg() syscall instead of readv() to
// read inbound packets. This reduces # of syscalls needed to process
// packets.
//
// NOTE: recvmmsg() is only supported for sockets, so if the underlying
// FD is not a socket then the code will still fall back to the readv()
// path.
RecvMMsg
// PacketMMap enables use of PACKET_RX_RING to receive packets from the
// NIC. PacketMMap requires that the underlying FD be an AF_PACKET. The
// primary use-case for this is runsc which uses an AF_PACKET FD to
// receive packets from the veth device.
PacketMMap
)
func (p PacketDispatchMode) String() string {
switch p {
case Readv:
return "Readv"
case RecvMMsg:
return "RecvMMsg"
case PacketMMap:
return "PacketMMap"
default:
return fmt.Sprintf("unknown packet dispatch mode '%d'", p)
}
}
var (
_ stack.LinkEndpoint = (*endpoint)(nil)
_ stack.GSOEndpoint = (*endpoint)(nil)
)
// +stateify savable
type fdInfo struct {
fd int
isSocket bool
}
// +stateify savable
type endpoint struct {
// fds is the set of file descriptors each identifying one inbound/outbound
// channel. The endpoint will dispatch from all inbound channels as well as
// hash outbound packets to specific channels based on the packet hash.
fds []fdInfo
// hdrSize specifies the link-layer header size. If set to 0, no header
// is added/removed; otherwise an ethernet header is used.
hdrSize int
// caps holds the endpoint capabilities.
caps stack.LinkEndpointCapabilities
// closed is a function to be called when the FD's peer (if any) closes
// its end of the communication pipe.
closed func(tcpip.Error) `state:"nosave"`
inboundDispatchers []linkDispatcher
mu endpointRWMutex `state:"nosave"`
// +checklocks:mu
dispatcher stack.NetworkDispatcher
// packetDispatchMode controls the packet dispatcher used by this
// endpoint.
packetDispatchMode PacketDispatchMode
// wg keeps track of running goroutines.
wg sync.WaitGroup `state:"nosave"`
// maxSyscallHeaderBytes has the same meaning as
// Options.MaxSyscallHeaderBytes.
maxSyscallHeaderBytes uintptr
// writevMaxIovs is the maximum number of iovecs that may be passed to
// rawfile.NonBlockingWriteIovec, as possibly limited by
// maxSyscallHeaderBytes. (No analogous limit is defined for
// rawfile.NonBlockingSendMMsg, since in that case the maximum number of
// iovecs also depends on the number of mmsghdrs. Instead, if sendBatch
// encounters a packet whose iovec count is limited by
// maxSyscallHeaderBytes, it falls back to writing the packet using writev
// via WritePacket.)
writevMaxIovs int
// addr is the address of the endpoint.
//
// +checklocks:mu
addr tcpip.LinkAddress
// mtu (maximum transmission unit) is the maximum size of a packet.
// +checklocks:mu
mtu uint32
batchSize int
sendMsgX bool
}
// Options specify the details about the fd-based endpoint to be created.
//
// +stateify savable
type Options struct {
// FDs is a set of FDs used to read/write packets.
FDs []int
// MTU is the mtu to use for this endpoint.
MTU uint32
// EthernetHeader if true, indicates that the endpoint should read/write
// ethernet frames instead of IP packets.
EthernetHeader bool
// ClosedFunc is a function to be called when an endpoint's peer (if
// any) closes its end of the communication pipe.
ClosedFunc func(tcpip.Error)
// Address is the link address for this endpoint. Only used if
// EthernetHeader is true.
Address tcpip.LinkAddress
// SaveRestore if true, indicates that this NIC capability set should
// include CapabilitySaveRestore
SaveRestore bool
// DisconnectOk if true, indicates that this NIC capability set should
// include CapabilityDisconnectOk.
DisconnectOk bool
// PacketDispatchMode specifies the type of inbound dispatcher to be
// used for this endpoint.
PacketDispatchMode PacketDispatchMode
// TXChecksumOffload if true, indicates that this endpoints capability
// set should include CapabilityTXChecksumOffload.
TXChecksumOffload bool
// RXChecksumOffload if true, indicates that this endpoints capability
// set should include CapabilityRXChecksumOffload.
RXChecksumOffload bool
// If MaxSyscallHeaderBytes is non-zero, it is the maximum number of bytes
// of struct iovec, msghdr, and mmsghdr that may be passed by each host
// system call.
MaxSyscallHeaderBytes int
// InterfaceIndex is the interface index of the underlying device.
InterfaceIndex int
// ProcessorsPerChannel is the number of goroutines used to handle packets
// from each FD.
ProcessorsPerChannel int
MultiPendingPackets bool
SendMsgX bool
}
// New creates a new fd-based endpoint.
//
// Makes fd non-blocking, but does not take ownership of fd, which must remain
// open for the lifetime of the returned endpoint (until after the endpoint has
// stopped being using and Wait returns).
func New(opts *Options) (stack.LinkEndpoint, error) {
caps := stack.LinkEndpointCapabilities(0)
if opts.RXChecksumOffload {
caps |= stack.CapabilityRXChecksumOffload
}
if opts.TXChecksumOffload {
caps |= stack.CapabilityTXChecksumOffload
}
hdrSize := 0
if opts.EthernetHeader {
hdrSize = header.EthernetMinimumSize
caps |= stack.CapabilityResolutionRequired
}
if opts.SaveRestore {
caps |= stack.CapabilitySaveRestore
}
if opts.DisconnectOk {
caps |= stack.CapabilityDisconnectOk
}
if len(opts.FDs) == 0 {
return nil, fmt.Errorf("opts.FD is empty, at least one FD must be specified")
}
if opts.MaxSyscallHeaderBytes < 0 {
return nil, fmt.Errorf("opts.MaxSyscallHeaderBytes is negative")
}
var batchSize int
if opts.MultiPendingPackets {
batchSize = int((512*1024)/(opts.MTU)) + 1
} else {
batchSize = 1
}
e := &endpoint{
mtu: opts.MTU,
caps: caps,
closed: opts.ClosedFunc,
addr: opts.Address,
hdrSize: hdrSize,
packetDispatchMode: opts.PacketDispatchMode,
maxSyscallHeaderBytes: uintptr(opts.MaxSyscallHeaderBytes),
writevMaxIovs: rawfile.MaxIovs,
batchSize: batchSize,
sendMsgX: opts.SendMsgX,
}
if e.maxSyscallHeaderBytes != 0 {
if max := int(e.maxSyscallHeaderBytes / rawfile.SizeofIovec); max < e.writevMaxIovs {
e.writevMaxIovs = max
}
}
// Create per channel dispatchers.
for _, fd := range opts.FDs {
if err := unix.SetNonblock(fd, true); err != nil {
return nil, fmt.Errorf("unix.SetNonblock(%v) failed: %v", fd, err)
}
e.fds = append(e.fds, fdInfo{fd: fd, isSocket: true})
if opts.ProcessorsPerChannel == 0 {
opts.ProcessorsPerChannel = common.Max(1, runtime.GOMAXPROCS(0)/len(opts.FDs))
}
inboundDispatcher, err := newRecvMMsgDispatcher(fd, e, opts)
if err != nil {
return nil, fmt.Errorf("createInboundDispatcher(...) = %v", err)
}
e.inboundDispatchers = append(e.inboundDispatchers, inboundDispatcher)
}
return e, nil
}
// Attach launches the goroutine that reads packets from the file descriptor and
// dispatches them via the provided dispatcher. If one is already attached,
// then nothing happens.
//
// Attach implements stack.LinkEndpoint.Attach.
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.mu.Lock()
// nil means the NIC is being removed.
if dispatcher == nil && e.dispatcher != nil {
for _, dispatcher := range e.inboundDispatchers {
dispatcher.Stop()
}
e.dispatcher = nil
// NOTE(gvisor.dev/issue/11456): Unlock e.mu before e.Wait().
e.mu.Unlock()
e.Wait()
return
}
defer e.mu.Unlock()
if dispatcher != nil && e.dispatcher == nil {
e.dispatcher = dispatcher
// Link endpoints are not savable. When transportation endpoints are
// saved, they stop sending outgoing packets and all incoming packets
// are rejected.
for i := range e.inboundDispatchers {
e.wg.Add(1)
go func(i int) { // S/R-SAFE: See above.
e.dispatchLoop(e.inboundDispatchers[i])
e.wg.Done()
}(i)
}
}
}
// IsAttached implements stack.LinkEndpoint.IsAttached.
func (e *endpoint) IsAttached() bool {
e.mu.RLock()
defer e.mu.RUnlock()
return e.dispatcher != nil
}
// MTU implements stack.LinkEndpoint.MTU.
func (e *endpoint) MTU() uint32 {
e.mu.RLock()
defer e.mu.RUnlock()
return e.mtu
}
// SetMTU implements stack.LinkEndpoint.SetMTU.
func (e *endpoint) SetMTU(mtu uint32) {
e.mu.Lock()
defer e.mu.Unlock()
e.mtu = mtu
}
// Capabilities implements stack.LinkEndpoint.Capabilities.
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
return e.caps
}
// MaxHeaderLength returns the maximum size of the link-layer header.
func (e *endpoint) MaxHeaderLength() uint16 {
return uint16(e.hdrSize)
}
// LinkAddress returns the link address of this endpoint.
func (e *endpoint) LinkAddress() tcpip.LinkAddress {
e.mu.RLock()
defer e.mu.RUnlock()
return e.addr
}
// SetLinkAddress implements stack.LinkEndpoint.SetLinkAddress.
func (e *endpoint) SetLinkAddress(addr tcpip.LinkAddress) {
e.mu.Lock()
defer e.mu.Unlock()
e.addr = addr
}
// Wait implements stack.LinkEndpoint.Wait. It waits for the endpoint to stop
// reading from its FD.
func (e *endpoint) Wait() {
e.wg.Wait()
}
// AddHeader implements stack.LinkEndpoint.AddHeader.
func (e *endpoint) AddHeader(pkt *stack.PacketBuffer) {
if e.hdrSize > 0 {
// Add ethernet header if needed.
eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize))
eth.Encode(&header.EthernetFields{
SrcAddr: pkt.EgressRoute.LocalLinkAddress,
DstAddr: pkt.EgressRoute.RemoteLinkAddress,
Type: pkt.NetworkProtocolNumber,
})
}
}
func (e *endpoint) parseHeader(pkt *stack.PacketBuffer) (header.Ethernet, bool) {
if e.hdrSize <= 0 {
return nil, true
}
hdrBytes, ok := pkt.LinkHeader().Consume(e.hdrSize)
if !ok {
return nil, false
}
hdr := header.Ethernet(hdrBytes)
pkt.NetworkProtocolNumber = hdr.Type()
return hdr, true
}
// parseInboundHeader parses the link header of pkt and returns true if the
// header is well-formed and sent to this endpoint's MAC or the broadcast
// address.
func (e *endpoint) parseInboundHeader(pkt *stack.PacketBuffer, wantAddr tcpip.LinkAddress) bool {
hdr, ok := e.parseHeader(pkt)
if !ok || e.hdrSize <= 0 {
return ok
}
dstAddr := hdr.DestinationAddress()
// Per RFC 9542 2.1 on the least significant bit of the first octet of
// a MAC address: "If it is zero, the MAC address is unicast. If it is
// a one, the address is groupcast (multicast or broadcast)." Multicast
// and broadcast are the same thing to ethernet; they are both sent to
// everyone.
return dstAddr == wantAddr || byte(dstAddr[0])&0x01 == 1
}
// ParseHeader implements stack.LinkEndpoint.ParseHeader.
func (e *endpoint) ParseHeader(pkt *stack.PacketBuffer) bool {
_, ok := e.parseHeader(pkt)
return ok
}
var (
packetHeader4 = []byte{0x00, 0x00, 0x00, unix.AF_INET}
packetHeader6 = []byte{0x00, 0x00, 0x00, unix.AF_INET6}
)
// writePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
func (e *endpoint) writePacket(pkt *stack.PacketBuffer) tcpip.Error {
fdInfo := e.fds[pkt.Hash%uint32(len(e.fds))]
fd := fdInfo.fd
var vnetHdrBuf []byte
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
vnetHdrBuf = packetHeader4
} else {
vnetHdrBuf = packetHeader6
}
views := pkt.AsSlices()
numIovecs := len(views)
if len(vnetHdrBuf) != 0 {
numIovecs++
}
if numIovecs > e.writevMaxIovs {
numIovecs = e.writevMaxIovs
}
// Allocate small iovec arrays on the stack.
var iovecsArr [8]unix.Iovec
iovecs := iovecsArr[:0]
if numIovecs > len(iovecsArr) {
iovecs = make([]unix.Iovec, 0, numIovecs)
}
iovecs = rawfile.AppendIovecFromBytes(iovecs, vnetHdrBuf, numIovecs)
for _, v := range views {
iovecs = rawfile.AppendIovecFromBytes(iovecs, v, numIovecs)
}
if errno := rawfile.NonBlockingWriteIovec(fd, iovecs); errno != 0 {
return TranslateErrno(errno)
}
return nil
}
func (e *endpoint) sendBatch(batchFDInfo fdInfo, pkts []*stack.PacketBuffer) (int, tcpip.Error) {
// Degrade to writePacket if underlying fd is not a socket.
if !batchFDInfo.isSocket || !e.sendMsgX {
var written int
var err tcpip.Error
for written < len(pkts) {
if err = e.writePacket(pkts[written]); err != nil {
break
}
written++
}
return written, err
}
// Send a batch of packets through batchFD.
batchFD := batchFDInfo.fd
mmsgHdrsStorage := make([]rawfile.MsgHdrX, 0, len(pkts))
packets := 0
for packets < len(pkts) {
mmsgHdrs := mmsgHdrsStorage
batch := pkts[packets:]
syscallHeaderBytes := uintptr(0)
for _, pkt := range batch {
var vnetHdrBuf []byte
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
vnetHdrBuf = packetHeader4
} else {
vnetHdrBuf = packetHeader6
}
views, offset := pkt.AsViewList()
var skipped int
var view *buffer.View
for view = views.Front(); view != nil && offset >= view.Size(); view = view.Next() {
offset -= view.Size()
skipped++
}
// We've made it to the usable views.
numIovecs := views.Len() - skipped
if len(vnetHdrBuf) != 0 {
numIovecs++
}
if numIovecs > rawfile.MaxIovs {
numIovecs = rawfile.MaxIovs
}
if e.maxSyscallHeaderBytes != 0 {
syscallHeaderBytes += rawfile.SizeofMsgHdrX + uintptr(numIovecs)*rawfile.SizeofIovec
if syscallHeaderBytes > e.maxSyscallHeaderBytes {
// We can't fit this packet into this call to sendmmsg().
// We could potentially do so if we reduced numIovecs
// further, but this might incur considerable extra
// copying. Leave it to the next batch instead.
break
}
}
// We can't easily allocate iovec arrays on the stack here since
// they will escape this loop iteration via mmsgHdrs.
iovecs := make([]unix.Iovec, 0, numIovecs)
iovecs = rawfile.AppendIovecFromBytes(iovecs, vnetHdrBuf, numIovecs)
// At most one slice has a non-zero offset.
iovecs = rawfile.AppendIovecFromBytes(iovecs, view.AsSlice()[offset:], numIovecs)
for view = view.Next(); view != nil; view = view.Next() {
iovecs = rawfile.AppendIovecFromBytes(iovecs, view.AsSlice(), numIovecs)
}
var mmsgHdr rawfile.MsgHdrX
mmsgHdr.Msg.Iov = &iovecs[0]
mmsgHdr.Msg.SetIovlen(len(iovecs))
// mmsgHdr.DataLen = uint32(len(iovecs))
mmsgHdrs = append(mmsgHdrs, mmsgHdr)
}
if len(mmsgHdrs) == 0 {
// We can't fit batch[0] into a mmsghdr while staying under
// e.maxSyscallHeaderBytes. Use WritePacket, which will avoid the
// mmsghdr (by using writev) and re-buffer iovecs more aggressively
// if necessary (by using e.writevMaxIovs instead of
// rawfile.MaxIovs).
pkt := batch[0]
if err := e.writePacket(pkt); err != nil {
return packets, err
}
packets++
} else {
for len(mmsgHdrs) > 0 {
sent, errno := rawfile.NonBlockingSendMMsg(batchFD, mmsgHdrs)
if errno != 0 {
return packets, TranslateErrno(errno)
}
packets += sent
mmsgHdrs = mmsgHdrs[sent:]
}
}
}
return packets, nil
}
// WritePackets writes outbound packets to the underlying file descriptors. If
// one is not currently writable, the packet is dropped.
//
// Being a batch API, each packet in pkts should have the following
// fields populated:
// - pkt.EgressRoute
// - pkt.GSOOptions
// - pkt.NetworkProtocolNumber
func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
// Preallocate to avoid repeated reallocation as we append to batch.
batch := make([]*stack.PacketBuffer, 0, e.batchSize)
batchFDInfo := fdInfo{fd: -1, isSocket: false}
sentPackets := 0
for _, pkt := range pkts.AsSlice() {
if len(batch) == 0 {
batchFDInfo = e.fds[pkt.Hash%uint32(len(e.fds))]
}
pktFDInfo := e.fds[pkt.Hash%uint32(len(e.fds))]
if sendNow := pktFDInfo != batchFDInfo; !sendNow {
batch = append(batch, pkt)
continue
}
n, err := e.sendBatch(batchFDInfo, batch)
sentPackets += n
if err != nil {
return sentPackets, err
}
batch = batch[:0]
batch = append(batch, pkt)
batchFDInfo = pktFDInfo
}
if len(batch) != 0 {
n, err := e.sendBatch(batchFDInfo, batch)
sentPackets += n
if err != nil {
return sentPackets, err
}
}
return sentPackets, nil
}
// dispatchLoop reads packets from the file descriptor in a loop and dispatches
// them to the network stack.
func (e *endpoint) dispatchLoop(inboundDispatcher linkDispatcher) tcpip.Error {
for {
cont, err := inboundDispatcher.dispatch()
if err != nil || !cont {
if e.closed != nil {
e.closed(err)
}
inboundDispatcher.release()
return err
}
}
}
// GSOMaxSize implements stack.GSOEndpoint.
func (e *endpoint) GSOMaxSize() uint32 {
return 0
}
// SupportedGSO implements stack.GSOEndpoint.
func (e *endpoint) SupportedGSO() stack.SupportedGSO {
return stack.GSONotSupported
}
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
if e.hdrSize > 0 {
return header.ARPHardwareEther
}
return header.ARPHardwareNone
}
// Close implements stack.LinkEndpoint.
func (e *endpoint) Close() {}
// SetOnCloseAction implements stack.LinkEndpoint.
func (*endpoint) SetOnCloseAction(func()) {}

View File

@@ -0,0 +1,96 @@
package fdbased
import (
"reflect"
"github.com/sagernet/gvisor/pkg/sync"
"github.com/sagernet/gvisor/pkg/sync/locking"
)
// RWMutex is sync.RWMutex with the correctness validator.
type endpointRWMutex struct {
mu sync.RWMutex
}
// lockNames is a list of user-friendly lock names.
// Populated in init.
var endpointlockNames []string
// lockNameIndex is used as an index passed to NestedLock and NestedUnlock,
// referring to an index within lockNames.
// Values are specified using the "consts" field of go_template_instance.
type endpointlockNameIndex int
// DO NOT REMOVE: The following function automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()
// Lock locks m.
// +checklocksignore
func (m *endpointRWMutex) Lock() {
locking.AddGLock(endpointprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *endpointRWMutex) NestedLock(i endpointlockNameIndex) {
locking.AddGLock(endpointprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m.
// +checklocksignore
func (m *endpointRWMutex) Unlock() {
m.mu.Unlock()
locking.DelGLock(endpointprefixIndex, -1)
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *endpointRWMutex) NestedUnlock(i endpointlockNameIndex) {
m.mu.Unlock()
locking.DelGLock(endpointprefixIndex, int(i))
}
// RLock locks m for reading.
// +checklocksignore
func (m *endpointRWMutex) RLock() {
locking.AddGLock(endpointprefixIndex, -1)
m.mu.RLock()
}
// RUnlock undoes a single RLock call.
// +checklocksignore
func (m *endpointRWMutex) RUnlock() {
m.mu.RUnlock()
locking.DelGLock(endpointprefixIndex, -1)
}
// RLockBypass locks m for reading without executing the validator.
// +checklocksignore
func (m *endpointRWMutex) RLockBypass() {
m.mu.RLock()
}
// RUnlockBypass undoes a single RLockBypass call.
// +checklocksignore
func (m *endpointRWMutex) RUnlockBypass() {
m.mu.RUnlock()
}
// DowngradeLock atomically unlocks rw for writing and locks it for reading.
// +checklocksignore
func (m *endpointRWMutex) DowngradeLock() {
m.mu.DowngradeLock()
}
var endpointprefixIndex *locking.MutexClass
// DO NOT REMOVE: The following function is automatically replaced.
func endpointinitLockNames() {}
func init() {
endpointinitLockNames()
endpointprefixIndex = locking.NewMutexClass(reflect.TypeOf(endpointRWMutex{}), endpointlockNames)
}

View File

@@ -0,0 +1,54 @@
package fdbased
import (
"github.com/sagernet/gvisor/pkg/tcpip"
"golang.org/x/sys/unix"
)
func TranslateErrno(e unix.Errno) tcpip.Error {
switch e {
case unix.EEXIST:
return &tcpip.ErrDuplicateAddress{}
case unix.ENETUNREACH:
return &tcpip.ErrHostUnreachable{}
case unix.EINVAL:
return &tcpip.ErrInvalidEndpointState{}
case unix.EALREADY:
return &tcpip.ErrAlreadyConnecting{}
case unix.EISCONN:
return &tcpip.ErrAlreadyConnected{}
case unix.EADDRINUSE:
return &tcpip.ErrPortInUse{}
case unix.EADDRNOTAVAIL:
return &tcpip.ErrBadLocalAddress{}
case unix.EPIPE:
return &tcpip.ErrClosedForSend{}
case unix.EWOULDBLOCK:
return &tcpip.ErrWouldBlock{}
case unix.ECONNREFUSED:
return &tcpip.ErrConnectionRefused{}
case unix.ETIMEDOUT:
return &tcpip.ErrTimeout{}
case unix.EINPROGRESS:
return &tcpip.ErrConnectStarted{}
case unix.EDESTADDRREQ:
return &tcpip.ErrDestinationRequired{}
case unix.ENOTSUP:
return &tcpip.ErrNotSupported{}
case unix.ENOTTY:
return &tcpip.ErrQueueSizeNotSupported{}
case unix.ENOTCONN:
return &tcpip.ErrNotConnected{}
case unix.ECONNRESET:
return &tcpip.ErrConnectionReset{}
case unix.ECONNABORTED:
return &tcpip.ErrConnectionAborted{}
case unix.EMSGSIZE:
return &tcpip.ErrMessageTooLong{}
case unix.ENOBUFS:
return &tcpip.ErrNoBufferSpace{}
default:
return &tcpip.ErrInvalidEndpointState{}
}
}

View File

@@ -0,0 +1,199 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package fdbased
import (
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/stack/gro"
"github.com/sagernet/sing-tun/internal/rawfile_darwin"
"github.com/sagernet/sing-tun/internal/stopfd_darwin"
"golang.org/x/sys/unix"
)
type iovecBuffer struct {
mtu int
views []*buffer.View
iovecs []unix.Iovec `state:"nosave"`
}
func newIovecBuffer(mtu uint32) *iovecBuffer {
b := &iovecBuffer{
mtu: int(mtu),
views: make([]*buffer.View, 2),
iovecs: make([]unix.Iovec, 2),
}
return b
}
func (b *iovecBuffer) nextIovecs() []unix.Iovec {
if b.views[0] == nil {
b.views[0] = buffer.NewViewSize(4)
b.iovecs[0] = unix.Iovec{Base: b.views[0].BasePtr()}
b.iovecs[0].SetLen(4)
}
if b.views[1] == nil {
b.views[1] = buffer.NewViewSize(b.mtu)
b.iovecs[1] = unix.Iovec{Base: b.views[1].BasePtr()}
b.iovecs[1].SetLen(b.mtu)
}
return b.iovecs
}
// pullBuffer extracts the enough underlying storage from b.buffer to hold n
// bytes. It removes this storage from b.buffer, returns a new buffer
// that holds the storage, and updates pulledIndex to indicate which part
// of b.buffer's storage must be reallocated during the next call to
// nextIovecs.
func (b *iovecBuffer) pullBuffer(n int) buffer.Buffer {
pulled := buffer.Buffer{}
pulled.Append(b.views[0])
pulled.Append(b.views[1])
pulled.Truncate(int64(n))
pulled.TrimFront(4)
b.views[0] = nil
b.views[1] = nil
return pulled
}
func (b *iovecBuffer) release() {
for _, v := range b.views {
if v != nil {
v.Release()
v = nil
}
}
}
// recvMMsgDispatcher uses the recvmmsg system call to read inbound packets and
// dispatches them.
//
// +stateify savable
type recvMMsgDispatcher struct {
stopfd.StopFD
// fd is the file descriptor used to send and receive packets.
fd int
// e is the endpoint this dispatcher is attached to.
e *endpoint
// bufs is an array of iovec buffers that contain packet contents.
bufs []*iovecBuffer
// msgHdrs is an array of MMsgHdr objects where each MMsghdr is used to
// reference an array of iovecs in the iovecs field defined above. This
// array is passed as the parameter to recvmmsg call to retrieve
// potentially more than 1 packet per unix.
msgHdrs []rawfile.MsgHdrX `state:"nosave"`
// pkts is reused to avoid allocations.
pkts stack.PacketBufferList
// gro coalesces incoming packets to increase throughput.
gro gro.GRO
// mgr is the processor goroutine manager.
mgr *processorManager
}
func newRecvMMsgDispatcher(fd int, e *endpoint, opts *Options) (linkDispatcher, error) {
stopFD, err := stopfd.New()
if err != nil {
return nil, err
}
var batchSize int
if opts.MultiPendingPackets {
batchSize = int((512*1024)/(opts.MTU)) + 1
} else {
batchSize = 1
}
d := &recvMMsgDispatcher{
StopFD: stopFD,
fd: fd,
e: e,
bufs: make([]*iovecBuffer, batchSize),
msgHdrs: make([]rawfile.MsgHdrX, batchSize),
}
for i := range d.bufs {
d.bufs[i] = newIovecBuffer(opts.MTU)
}
d.gro.Init(false)
d.mgr = newProcessorManager(opts, e)
d.mgr.start()
return d, nil
}
func (d *recvMMsgDispatcher) release() {
for _, iov := range d.bufs {
iov.release()
}
d.mgr.close()
}
// recvMMsgDispatch reads more than one packet at a time from the file
// descriptor and dispatches it.
func (d *recvMMsgDispatcher) dispatch() (bool, tcpip.Error) {
// Fill message headers.
for k := range d.msgHdrs {
iovecs := d.bufs[k].nextIovecs()
iovLen := len(iovecs)
// Cannot clear only the length field. Older versions of the darwin kernel will check whether other data is empty.
// https://github.com/Darm64/XNU/blob/xnu-2782.40.9/bsd/kern/uipc_syscalls.c#L2026-L2048
d.msgHdrs[k] = rawfile.MsgHdrX{}
d.msgHdrs[k].Msg.Iov = &iovecs[0]
d.msgHdrs[k].Msg.SetIovlen(iovLen)
}
nMsgs, errno := rawfile.BlockingRecvMMsgUntilStopped(d.ReadFD, d.fd, d.msgHdrs)
if errno != 0 {
return false, TranslateErrno(errno)
}
if nMsgs == -1 {
return false, nil
}
// Process each of received packets.
d.e.mu.RLock()
addr := d.e.addr
dsp := d.e.dispatcher
d.e.mu.RUnlock()
d.gro.Dispatcher = dsp
defer d.pkts.Reset()
for k := 0; k < nMsgs; k++ {
n := int(d.msgHdrs[k].DataLen)
payload := d.bufs[k].pullBuffer(n)
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: payload,
})
d.pkts.PushBack(pkt)
// Mark that this iovec has been processed.
d.msgHdrs[k].Msg.Iovlen = 0
if d.e.parseInboundHeader(pkt, addr) {
pkt.RXChecksumValidated = d.e.caps&stack.CapabilityRXChecksumOffload != 0
d.mgr.queuePacket(pkt, d.e.hdrSize > 0)
}
}
d.mgr.wakeReady()
return true, nil
}

View File

@@ -0,0 +1,64 @@
package fdbased
import (
"reflect"
"github.com/sagernet/gvisor/pkg/sync"
"github.com/sagernet/gvisor/pkg/sync/locking"
)
// Mutex is sync.Mutex with the correctness validator.
type processorMutex struct {
mu sync.Mutex
}
var processorprefixIndex *locking.MutexClass
// lockNames is a list of user-friendly lock names.
// Populated in init.
var processorlockNames []string
// lockNameIndex is used as an index passed to NestedLock and NestedUnlock,
// referring to an index within lockNames.
// Values are specified using the "consts" field of go_template_instance.
type processorlockNameIndex int
// DO NOT REMOVE: The following function automatically replaced with lock index constants.
// LOCK_NAME_INDEX_CONSTANTS
const ()
// Lock locks m.
// +checklocksignore
func (m *processorMutex) Lock() {
locking.AddGLock(processorprefixIndex, -1)
m.mu.Lock()
}
// NestedLock locks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *processorMutex) NestedLock(i processorlockNameIndex) {
locking.AddGLock(processorprefixIndex, int(i))
m.mu.Lock()
}
// Unlock unlocks m.
// +checklocksignore
func (m *processorMutex) Unlock() {
locking.DelGLock(processorprefixIndex, -1)
m.mu.Unlock()
}
// NestedUnlock unlocks m knowing that another lock of the same type is held.
// +checklocksignore
func (m *processorMutex) NestedUnlock(i processorlockNameIndex) {
locking.DelGLock(processorprefixIndex, int(i))
m.mu.Unlock()
}
// DO NOT REMOVE: The following function is automatically replaced.
func processorinitLockNames() {}
func init() {
processorinitLockNames()
processorprefixIndex = locking.NewMutexClass(reflect.TypeOf(processorMutex{}), processorlockNames)
}

View File

@@ -0,0 +1,275 @@
// Copyright 2024 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package fdbased
import (
"context"
"encoding/binary"
"github.com/sagernet/gvisor/pkg/rand"
"github.com/sagernet/gvisor/pkg/sleep"
"github.com/sagernet/gvisor/pkg/sync"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/hash/jenkins"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/stack/gro"
)
// +stateify savable
type processor struct {
mu processorMutex `state:"nosave"`
// +checklocks:mu
pkts stack.PacketBufferList
e *endpoint
gro gro.GRO
sleeper sleep.Sleeper
packetWaker sleep.Waker
closeWaker sleep.Waker
}
func (p *processor) start(wg *sync.WaitGroup) {
defer wg.Done()
defer p.sleeper.Done()
for {
switch w := p.sleeper.Fetch(true); {
case w == &p.packetWaker:
p.deliverPackets()
case w == &p.closeWaker:
p.mu.Lock()
p.pkts.Reset()
p.mu.Unlock()
return
}
}
}
func (p *processor) deliverPackets() {
p.e.mu.RLock()
p.gro.Dispatcher = p.e.dispatcher
p.e.mu.RUnlock()
if p.gro.Dispatcher == nil {
p.mu.Lock()
p.pkts.Reset()
p.mu.Unlock()
return
}
p.mu.Lock()
for p.pkts.Len() > 0 {
pkt := p.pkts.PopFront()
p.mu.Unlock()
p.gro.Enqueue(pkt)
pkt.DecRef()
p.mu.Lock()
}
p.mu.Unlock()
p.gro.Flush()
}
// processorManager handles starting, closing, and queuing packets on processor
// goroutines.
//
// +stateify savable
type processorManager struct {
processors []processor
seed uint32
wg sync.WaitGroup `state:"nosave"`
e *endpoint
ready []bool
}
// newProcessorManager creates a new processor manager.
func newProcessorManager(opts *Options, e *endpoint) *processorManager {
m := &processorManager{}
m.seed = rand.Uint32()
m.ready = make([]bool, opts.ProcessorsPerChannel)
m.processors = make([]processor, opts.ProcessorsPerChannel)
m.e = e
m.wg.Add(opts.ProcessorsPerChannel)
for i := range m.processors {
p := &m.processors[i]
p.sleeper.AddWaker(&p.packetWaker)
p.sleeper.AddWaker(&p.closeWaker)
p.gro.Init(false)
p.e = e
}
return m
}
// start starts the processor goroutines if the processor manager is configured
// with more than one processor.
func (m *processorManager) start() {
for i := range m.processors {
p := &m.processors[i]
// Only start processor in a separate goroutine if we have multiple of them.
if len(m.processors) > 1 {
go p.start(&m.wg)
}
}
}
// afterLoad is invoked by stateify.
func (m *processorManager) afterLoad(context.Context) {
m.wg.Add(len(m.processors))
m.start()
}
func (m *processorManager) connectionHash(cid *connectionID) uint32 {
var payload [4]byte
binary.LittleEndian.PutUint16(payload[0:], cid.srcPort)
binary.LittleEndian.PutUint16(payload[2:], cid.dstPort)
h := jenkins.Sum32(m.seed)
h.Write(payload[:])
h.Write(cid.srcAddr)
h.Write(cid.dstAddr)
return h.Sum32()
}
// queuePacket queues a packet to be delivered to the appropriate processor.
func (m *processorManager) queuePacket(pkt *stack.PacketBuffer, hasEthHeader bool) {
var pIdx uint32
cid, nonConnectionPkt := tcpipConnectionID(pkt)
if !hasEthHeader {
if nonConnectionPkt {
// If there's no eth header this should be a standard tcpip packet. If
// it isn't the packet is invalid so drop it.
return
}
pkt.NetworkProtocolNumber = cid.proto
}
if len(m.processors) == 1 || nonConnectionPkt {
// If the packet is not associated with an active connection, use the
// first processor.
pIdx = 0
} else {
pIdx = m.connectionHash(&cid) % uint32(len(m.processors))
}
p := &m.processors[pIdx]
p.mu.Lock()
defer p.mu.Unlock()
p.pkts.PushBack(pkt.IncRef())
m.ready[pIdx] = true
}
type connectionID struct {
srcAddr, dstAddr []byte
srcPort, dstPort uint16
proto tcpip.NetworkProtocolNumber
}
// tcpipConnectionID returns a tcpip connection id tuple based on the data found
// in the packet. It returns true if the packet is not associated with an active
// connection (e.g ARP, NDP, etc). The method assumes link headers have already
// been processed if they were present.
func tcpipConnectionID(pkt *stack.PacketBuffer) (connectionID, bool) {
var cid connectionID
h, ok := pkt.Data().PullUp(1)
if !ok {
// Skip this packet.
return cid, true
}
const tcpSrcDstPortLen = 4
switch header.IPVersion(h) {
case header.IPv4Version:
hdrLen := header.IPv4(h).HeaderLength()
h, ok = pkt.Data().PullUp(int(hdrLen) + tcpSrcDstPortLen)
if !ok {
return cid, true
}
ipHdr := header.IPv4(h[:hdrLen])
tcpHdr := header.TCP(h[hdrLen:][:tcpSrcDstPortLen])
cid.srcAddr = ipHdr.SourceAddressSlice()
cid.dstAddr = ipHdr.DestinationAddressSlice()
// All fragment packets need to be processed by the same goroutine, so
// only record the TCP ports if this is not a fragment packet.
if ipHdr.IsValid(pkt.Data().Size()) && !ipHdr.More() && ipHdr.FragmentOffset() == 0 {
cid.srcPort = tcpHdr.SourcePort()
cid.dstPort = tcpHdr.DestinationPort()
}
cid.proto = header.IPv4ProtocolNumber
case header.IPv6Version:
h, ok = pkt.Data().PullUp(header.IPv6FixedHeaderSize + tcpSrcDstPortLen)
if !ok {
return cid, true
}
ipHdr := header.IPv6(h)
var tcpHdr header.TCP
if tcpip.TransportProtocolNumber(ipHdr.NextHeader()) == header.TCPProtocolNumber {
tcpHdr = header.TCP(h[header.IPv6FixedHeaderSize:][:tcpSrcDstPortLen])
} else {
// Slow path for IPv6 extension headers :(.
dataBuf := pkt.Data().ToBuffer()
dataBuf.TrimFront(header.IPv6MinimumSize)
it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataBuf)
defer it.Release()
for {
hdr, done, err := it.Next()
if done || err != nil {
break
}
hdr.Release()
}
h, ok = pkt.Data().PullUp(int(it.HeaderOffset()) + tcpSrcDstPortLen)
if !ok {
return cid, true
}
tcpHdr = header.TCP(h[it.HeaderOffset():][:tcpSrcDstPortLen])
}
cid.srcAddr = ipHdr.SourceAddressSlice()
cid.dstAddr = ipHdr.DestinationAddressSlice()
cid.srcPort = tcpHdr.SourcePort()
cid.dstPort = tcpHdr.DestinationPort()
cid.proto = header.IPv6ProtocolNumber
default:
return cid, true
}
return cid, false
}
func (m *processorManager) close() {
if len(m.processors) < 2 {
return
}
for i := range m.processors {
p := &m.processors[i]
p.closeWaker.Assert()
}
}
// wakeReady wakes up all processors that have a packet queued. If there is only
// one processor, the method delivers the packet inline without waking a
// goroutine.
func (m *processorManager) wakeReady() {
for i, ready := range m.ready {
if !ready {
continue
}
p := &m.processors[i]
if len(m.processors) > 1 {
p.packetWaker.Assert()
} else {
p.deliverPackets()
}
m.ready[i] = false
}
}

View File

@@ -458,7 +458,10 @@ func (b IPv4) SetDestinationAddress(addr tcpip.Address) {
// CalculateChecksum calculates the checksum of the IPv4 header.
func (b IPv4) CalculateChecksum() uint16 {
return checksum.Checksum(b[:b.HeaderLength()], 0)
// return checksum.Checksum(b[:b.HeaderLength()], 0)
xsum0 := checksum.Checksum(b[:xsum], 0)
xsum0 = checksum.Checksum(b[xsum+2:b.HeaderLength()], xsum0)
return xsum0
}
// Encode encodes all the fields of the IPv4 header.
@@ -550,7 +553,8 @@ func (b IPv4) IsChecksumValid() bool {
// same set of octets, including the checksum field. If the result
// is all 1 bits (-0 in 1's complement arithmetic), the check
// succeeds.
return b.CalculateChecksum() == 0xffff
//return b.CalculateChecksum() == 0xffff
return checksum.Checksum(b[:b.HeaderLength()], 0) == 0xffff
}
// IsV4MulticastAddress determines if the provided address is an IPv4 multicast

View File

@@ -351,14 +351,18 @@ func (b TCP) SetUrgentPointer(urgentPointer uint16) {
// and the checksum of the segment data.
func (b TCP) CalculateChecksum(partialChecksum uint16) uint16 {
// Calculate the rest of the checksum.
return checksum.Checksum(b[:b.DataOffset()], partialChecksum)
// return checksum.Checksum(b[:b.DataOffset()], partialChecksum)
xsum := checksum.Checksum(b[:TCPChecksumOffset], partialChecksum)
xsum = checksum.Checksum(b[TCPChecksumOffset+2:b.DataOffset()], xsum)
return xsum
}
// IsChecksumValid returns true iff the TCP header's checksum is valid.
func (b TCP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum, payloadLength uint16) bool {
xsum := PseudoHeaderChecksum(TCPProtocolNumber, src.AsSlice(), dst.AsSlice(), uint16(b.DataOffset())+payloadLength)
xsum = checksum.Combine(xsum, payloadChecksum)
return b.CalculateChecksum(xsum) == 0xffff
// return b.CalculateChecksum(xsum) == 0xffff
return checksum.Checksum(b[:b.DataOffset()], xsum) == 0xffff
}
// Options returns a slice that holds the unparsed TCP options in the segment.

View File

@@ -113,15 +113,18 @@ func (b UDP) SetLength(length uint16) {
// CalculateChecksum calculates the checksum of the UDP packet, given the
// checksum of the network-layer pseudo-header and the checksum of the payload.
func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 {
// Calculate the rest of the checksum.
return checksum.Checksum(b[:UDPMinimumSize], partialChecksum)
// Calculate the rest of the checksum.\
// return checksum.Checksum(b[:UDPMinimumSize], partialChecksum)
xsum := checksum.Checksum(b[:udpChecksum], partialChecksum)
xsum = checksum.Checksum(b[udpChecksum+2:UDPMinimumSize], xsum)
return xsum
}
// IsChecksumValid returns true iff the UDP header's checksum is valid.
func (b UDP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum uint16) bool {
xsum := PseudoHeaderChecksum(UDPProtocolNumber, dst.AsSlice(), src.AsSlice(), b.Length())
xsum = checksum.Combine(xsum, payloadChecksum)
return b.CalculateChecksum(xsum) == 0xffff
return checksum.Checksum(b[:UDPMinimumSize], xsum) == 0xffff
}
// Encode encodes all the fields of the UDP header.

View File

@@ -0,0 +1,188 @@
package rawfile
import (
"reflect"
"unsafe"
"golang.org/x/sys/unix"
)
// SizeofIovec is the size of a unix.Iovec in bytes.
const SizeofIovec = unsafe.Sizeof(unix.Iovec{})
// MaxIovs is UIO_MAXIOV, the maximum number of iovecs that may be passed to a
// host system call in a single array.
const MaxIovs = 1024
// IovecFromBytes returns a unix.Iovec representing bs.
//
// Preconditions: len(bs) > 0.
func IovecFromBytes(bs []byte) unix.Iovec {
iov := unix.Iovec{
Base: &bs[0],
}
iov.SetLen(len(bs))
return iov
}
func bytesFromIovec(iov unix.Iovec) (bs []byte) {
sh := (*reflect.SliceHeader)(unsafe.Pointer(&bs))
sh.Data = uintptr(unsafe.Pointer(iov.Base))
sh.Len = int(iov.Len)
sh.Cap = int(iov.Len)
return
}
// AppendIovecFromBytes returns append(iovs, IovecFromBytes(bs)). If len(bs) ==
// 0, AppendIovecFromBytes returns iovs without modification. If len(iovs) >=
// max, AppendIovecFromBytes replaces the final iovec in iovs with one that
// also includes the contents of bs. Note that this implies that
// AppendIovecFromBytes is only usable when the returned iovec slice is used as
// the source of a write.
func AppendIovecFromBytes(iovs []unix.Iovec, bs []byte, max int) []unix.Iovec {
if len(bs) == 0 {
return iovs
}
if len(iovs) < max {
return append(iovs, IovecFromBytes(bs))
}
iovs[len(iovs)-1] = IovecFromBytes(append(bytesFromIovec(iovs[len(iovs)-1]), bs...))
return iovs
}
type MsgHdrX struct {
Msg unix.Msghdr
DataLen uint32
}
func NonBlockingSendMMsg(fd int, msgHdrs []MsgHdrX) (int, unix.Errno) {
n, _, e := unix.RawSyscall6(unix.SYS_SENDMSG_X, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), unix.MSG_DONTWAIT, 0, 0)
return int(n), e
}
const SizeofMsgHdrX = unsafe.Sizeof(MsgHdrX{})
// NonBlockingWriteIovec writes iovec to a file descriptor in a single unix.
// It fails if partial data is written.
func NonBlockingWriteIovec(fd int, iovec []unix.Iovec) unix.Errno {
iovecLen := uintptr(len(iovec))
_, _, e := unix.RawSyscall(unix.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), iovecLen)
return e
}
func BlockingReadvUntilStopped(efd int, fd int, iovecs []unix.Iovec) (int, unix.Errno) {
for {
n, _, e := unix.RawSyscall(unix.SYS_READV, uintptr(fd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(len(iovecs)))
if e == 0 {
return int(n), 0
}
if e != 0 && e != unix.EWOULDBLOCK {
return 0, e
}
stopped, e := BlockingPollUntilStopped(efd, fd, unix.POLLIN)
if stopped {
return -1, e
}
if e != 0 && e != unix.EINTR {
return 0, e
}
}
}
func BlockingRecvMMsgUntilStopped(efd int, fd int, msgHdrs []MsgHdrX) (int, unix.Errno) {
for {
n, _, e := unix.RawSyscall6(unix.SYS_RECVMSG_X, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), unix.MSG_DONTWAIT, 0, 0)
if e == 0 {
return int(n), e
}
if e != 0 && e != unix.EWOULDBLOCK {
return 0, e
}
stopped, e := BlockingPollUntilStopped(efd, fd, unix.POLLIN)
if stopped {
return -1, e
}
if e != 0 && e != unix.EINTR {
return 0, e
}
}
}
func BlockingPollUntilStopped(efd int, fd int, events int16) (bool, unix.Errno) {
// Create kqueue
kq, err := unix.Kqueue()
if err != nil {
return false, unix.Errno(err.(unix.Errno))
}
defer unix.Close(kq)
// Prepare kevents for registration
var kevents []unix.Kevent_t
// Always monitor efd for read events
kevents = append(kevents, unix.Kevent_t{
Ident: uint64(efd),
Filter: unix.EVFILT_READ,
Flags: unix.EV_ADD | unix.EV_ENABLE,
})
// Monitor fd based on requested events
// Convert poll events to kqueue filters
if events&unix.POLLIN != 0 {
kevents = append(kevents, unix.Kevent_t{
Ident: uint64(fd),
Filter: unix.EVFILT_READ,
Flags: unix.EV_ADD | unix.EV_ENABLE,
})
}
if events&unix.POLLOUT != 0 {
kevents = append(kevents, unix.Kevent_t{
Ident: uint64(fd),
Filter: unix.EVFILT_WRITE,
Flags: unix.EV_ADD | unix.EV_ENABLE,
})
}
// Register events
_, err = unix.Kevent(kq, kevents, nil, nil)
if err != nil {
return false, unix.Errno(err.(unix.Errno))
}
// Wait for events (blocking)
revents := make([]unix.Kevent_t, len(kevents))
n, err := unix.Kevent(kq, nil, revents, nil)
if err != nil {
return false, unix.Errno(err.(unix.Errno))
}
// Check results
var efdHasData bool
var errno unix.Errno
for i := 0; i < n; i++ {
ev := &revents[i]
if int(ev.Ident) == efd && ev.Filter == unix.EVFILT_READ {
efdHasData = true
}
if int(ev.Ident) == fd {
// Check for errors or EOF
if ev.Flags&unix.EV_EOF != 0 {
errno = unix.ECONNRESET
} else if ev.Flags&unix.EV_ERROR != 0 {
// Extract error from Data field
if ev.Data != 0 {
errno = unix.Errno(ev.Data)
} else {
errno = unix.ECONNRESET
}
}
}
}
return efdHasData, errno
}

View File

@@ -0,0 +1,61 @@
package stopfd
import (
"fmt"
"golang.org/x/sys/unix"
)
type StopFD struct {
ReadFD int
WriteFD int
}
func New() (StopFD, error) {
fds := make([]int, 2)
err := unix.Pipe(fds)
if err != nil {
return StopFD{ReadFD: -1, WriteFD: -1}, fmt.Errorf("failed to create pipe: %w", err)
}
if err := unix.SetNonblock(fds[0], true); err != nil {
unix.Close(fds[0])
unix.Close(fds[1])
return StopFD{ReadFD: -1, WriteFD: -1}, fmt.Errorf("failed to set read end non-blocking: %w", err)
}
if err := unix.SetNonblock(fds[1], true); err != nil {
unix.Close(fds[0])
unix.Close(fds[1])
return StopFD{ReadFD: -1, WriteFD: -1}, fmt.Errorf("failed to set write end non-blocking: %w", err)
}
return StopFD{ReadFD: fds[0], WriteFD: fds[1]}, nil
}
func (sf *StopFD) Stop() {
signal := []byte{1}
if n, err := unix.Write(sf.WriteFD, signal); n != len(signal) || err != nil {
panic(fmt.Sprintf("write(WriteFD) = (%d, %s), want (%d, nil)", n, err, len(signal)))
}
}
func (sf *StopFD) Close() error {
var err1, err2 error
if sf.ReadFD != -1 {
err1 = unix.Close(sf.ReadFD)
sf.ReadFD = -1
}
if sf.WriteFD != -1 {
err2 = unix.Close(sf.WriteFD)
sf.WriteFD = -1
}
if err1 != nil {
return err1
}
return err2
}
func (sf *StopFD) EFD() int {
return sf.ReadFD
}

View File

@@ -39,6 +39,10 @@ func closeAdapter(wintun *Adapter) {
// deterministically. If it is set to nil, the GUID is chosen by the system at random,
// and hence a new NLA entry is created for each new adapter.
func CreateAdapter(name string, tunnelType string, requestedGUID *windows.GUID) (wintun *Adapter, err error) {
err = procWintunCloseAdapter.Find()
if err != nil {
return
}
var name16 *uint16
name16, err = windows.UTF16PtrFromString(name)
if err != nil {

View File

@@ -30,6 +30,8 @@ type DefaultInterfaceMonitor interface {
AndroidVPNEnabled() bool
RegisterCallback(callback DefaultInterfaceUpdateCallback) *list.Element[DefaultInterfaceUpdateCallback]
UnregisterCallback(element *list.Element[DefaultInterfaceUpdateCallback])
RegisterMyInterface(interfaceName string)
MyInterface() string
}
type DefaultInterfaceMonitorOptions struct {

View File

@@ -5,9 +5,9 @@ package tun
import (
"errors"
"sync"
"sync/atomic"
"time"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/control"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/common/x/list"
@@ -42,11 +42,12 @@ type defaultInterfaceMonitor struct {
androidVPNEnabled bool
noRoute bool
networkMonitor NetworkUpdateMonitor
logger logger.Logger
checkUpdateTimer *time.Timer
element *list.Element[NetworkUpdateCallback]
access sync.Mutex
callbacks list.List[DefaultInterfaceUpdateCallback]
logger logger.Logger
myInterface string
}
func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor, logger logger.Logger, options DefaultInterfaceMonitorOptions) (DefaultInterfaceMonitor, error) {
@@ -132,3 +133,15 @@ func (m *defaultInterfaceMonitor) emit(defaultInterface *control.Interface, flag
callback(defaultInterface, flags)
}
}
func (m *defaultInterfaceMonitor) RegisterMyInterface(interfaceName string) {
m.access.Lock()
defer m.access.Unlock()
m.myInterface = interfaceName
}
func (m *defaultInterfaceMonitor) MyInterface() string {
m.access.Lock()
defer m.access.Unlock()
return m.myInterface
}

View File

@@ -69,6 +69,7 @@ func (r *autoRedirect) Start() error {
r.androidSu = true
for _, suPath := range []string{
"su",
"/product/bin/su",
"/system/bin/su",
} {
r.suPath, err = exec.LookPath(suPath)

View File

@@ -46,6 +46,11 @@ func (r *autoRedirect) setupNFTables() error {
return err
}
err = r.nftablesCreateLoopbackAddressSets(nft, table)
if err != nil {
return err
}
skipOutput := len(r.tunOptions.IncludeInterface) > 0 && !common.Contains(r.tunOptions.IncludeInterface, "lo") || common.Contains(r.tunOptions.ExcludeInterface, "lo")
if !skipOutput {
chainOutput := nft.AddChain(&nftables.Chain{
@@ -61,8 +66,23 @@ func (r *autoRedirect) setupNFTables() error {
return err
}
r.nftablesCreateUnreachable(nft, table, chainOutput)
r.nftablesCreateRedirect(nft, table, chainOutput)
err = r.nftablesCreateRedirect(nft, table, chainOutput)
if err != nil {
return err
}
if len(r.tunOptions.Inet4LoopbackAddress) > 0 || len(r.tunOptions.Inet6LoopbackAddress) > 0 {
chainOutputRoute := nft.AddChain(&nftables.Chain{
Name: "output_route",
Table: table,
Hooknum: nftables.ChainHookOutput,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeRoute,
})
err = r.nftablesCreateLoopbackReroute(nft, table, chainOutputRoute)
if err != nil {
return err
}
}
chainOutputUDP := nft.AddChain(&nftables.Chain{
Name: "output_udp_icmp",
Table: table,
@@ -77,7 +97,7 @@ func (r *autoRedirect) setupNFTables() error {
r.nftablesCreateUnreachable(nft, table, chainOutputUDP)
r.nftablesCreateMark(nft, table, chainOutputUDP)
} else {
r.nftablesCreateRedirect(nft, table, chainOutput, &expr.Meta{
err = r.nftablesCreateRedirect(nft, table, chainOutput, &expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
}, &expr.Cmp{
@@ -85,6 +105,9 @@ func (r *autoRedirect) setupNFTables() error {
Register: 1,
Data: nftablesIfname(r.tunOptions.Name),
})
if err != nil {
return err
}
}
}
@@ -100,19 +123,46 @@ func (r *autoRedirect) setupNFTables() error {
return err
}
r.nftablesCreateUnreachable(nft, table, chainPreRouting)
r.nftablesCreateRedirect(nft, table, chainPreRouting)
err = r.nftablesCreateRedirect(nft, table, chainPreRouting)
if err != nil {
return err
}
if r.tunOptions.AutoRedirectMarkMode {
r.nftablesCreateMark(nft, table, chainPreRouting)
}
if r.tunOptions.AutoRedirectMarkMode {
if len(r.tunOptions.Inet4LoopbackAddress) > 0 || len(r.tunOptions.Inet6LoopbackAddress) > 0 {
chainPreRoutingFilter := nft.AddChain(&nftables.Chain{
Name: "prerouting_filter",
Table: table,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityRef(*nftables.ChainPriorityNATDest + 1),
Type: nftables.ChainTypeFilter,
})
err = r.nftablesCreateLoopbackReroute(nft, table, chainPreRoutingFilter)
if err != nil {
return err
}
}
chainPreRoutingUDP := nft.AddChain(&nftables.Chain{
Name: "prerouting_udp",
Name: "prerouting_udp_icmp",
Table: table,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityRef(*nftables.ChainPriorityNATDest + 2),
Type: nftables.ChainTypeFilter,
})
ipProto := &nftables.Set{
Table: table,
Anonymous: true,
Constant: true,
KeyType: nftables.TypeInetProto,
}
err = nft.AddSet(ipProto, []nftables.SetElement{
{Key: []byte{unix.IPPROTO_UDP}},
{Key: []byte{unix.IPPROTO_ICMP}},
{Key: []byte{unix.IPPROTO_ICMPV6}},
})
if err != nil {
return err
}
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chainPreRoutingUDP,
@@ -121,10 +171,11 @@ func (r *autoRedirect) setupNFTables() error {
Key: expr.MetaKeyL4PROTO,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{unix.IPPROTO_UDP},
&expr.Lookup{
SourceRegister: 1,
SetID: ipProto.ID,
SetName: ipProto.Name,
Invert: true,
},
&expr.Verdict{
Kind: expr.VerdictReturn,
@@ -272,7 +323,7 @@ func (r *autoRedirect) cleanupNFTables() {
Name: r.tableName,
Family: nftables.TableFamilyINet,
})
common.Must(r.configureOpenWRTFirewall4(nft, true))
_ = r.configureOpenWRTFirewall4(nft, true)
_ = nft.Flush()
_ = nft.CloseLasting()
}

View File

@@ -7,6 +7,7 @@ import (
"github.com/sagernet/nftables"
"github.com/sagernet/nftables/expr"
"github.com/sagernet/sing/common"
"go4.org/netipx"
)
@@ -21,6 +22,20 @@ func nftablesCreateExcludeDestinationIPSet(
nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain,
id uint32, name string, family nftables.TableFamily, invert bool,
) {
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: append(
nftablesCreateDestinationIPSetExprs(id, name, family, invert),
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictReturn,
},
),
})
}
func nftablesCreateDestinationIPSetExprs(id uint32, name string, family nftables.TableFamily, invert bool) []expr.Any {
exprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyNFPROTO,
@@ -53,22 +68,63 @@ func nftablesCreateExcludeDestinationIPSet(
},
)
}
exprs = append(exprs,
&expr.Lookup{
SourceRegister: 1,
SetID: id,
SetName: name,
Invert: invert,
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictReturn,
})
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: exprs,
exprs = append(exprs, &expr.Lookup{
SourceRegister: 1,
SetID: id,
SetName: name,
Invert: invert,
})
return exprs
}
func nftablesCreateIPConst(
nft *nftables.Conn, table *nftables.Table, id uint32, name string, family nftables.TableFamily, addressList []netip.Addr,
) (*nftables.Set, error) {
var keyType nftables.SetDatatype
if family == nftables.TableFamilyIPv4 {
keyType = nftables.TypeIPAddr
} else {
keyType = nftables.TypeIP6Addr
}
mySet := &nftables.Set{
Table: table,
ID: id,
Name: name,
KeyType: keyType,
Constant: true,
}
if id == 0 {
mySet.Anonymous = true
}
setElements := common.Map(addressList, func(addr netip.Addr) nftables.SetElement { return nftables.SetElement{Key: addr.AsSlice()} })
if id == 0 {
err := nft.AddSet(mySet, setElements)
if err != nil {
return nil, err
}
return mySet, nil
} else {
err := nft.AddSet(mySet, nil)
if err != nil {
return nil, err
}
}
for len(setElements) > 0 {
toAdd := setElements
if len(toAdd) > 1000 {
toAdd = toAdd[:1000]
}
setElements = setElements[len(toAdd):]
err := nft.SetAddElements(mySet, toAdd)
if err != nil {
return nil, err
}
err = nft.Flush()
if err != nil {
return nil, err
}
}
return mySet, nil
}
func nftablesCreateIPSet(

View File

@@ -117,8 +117,61 @@ func (r *autoRedirect) nftablesCreateLocalAddressSets(
return nil
}
func (r *autoRedirect) nftablesCreateLoopbackAddressSets(
nft *nftables.Conn, table *nftables.Table,
) error {
if r.enableIPv4 && len(r.tunOptions.Inet4LoopbackAddress) > 0 {
_, err := nftablesCreateIPConst(nft, table, 7, "inet4_local_redirect_address_set", nftables.TableFamilyIPv4, r.tunOptions.Inet4LoopbackAddress)
if err != nil {
return err
}
}
if r.enableIPv6 && len(r.tunOptions.Inet6LoopbackAddress) > 0 {
_, err := nftablesCreateIPConst(nft, table, 8, "inet6_local_redirect_address_set", nftables.TableFamilyIPv6, r.tunOptions.Inet6LoopbackAddress)
if err != nil {
return err
}
}
return nil
}
func (r *autoRedirect) nftablesCreateExcludeRules(nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain) error {
if r.tunOptions.AutoRedirectMarkMode && chain.Hooknum == nftables.ChainHookOutput {
if chain.Type == nftables.ChainTypeRoute {
ipProto := &nftables.Set{
Table: table,
Anonymous: true,
Constant: true,
KeyType: nftables.TypeInetProto,
}
err := nft.AddSet(ipProto, []nftables.SetElement{
{Key: []byte{unix.IPPROTO_UDP}},
{Key: []byte{unix.IPPROTO_ICMP}},
{Key: []byte{unix.IPPROTO_ICMPV6}},
})
if err != nil {
return err
}
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyL4PROTO,
Register: 1,
},
&expr.Lookup{
SourceRegister: 1,
SetID: ipProto.ID,
SetName: ipProto.Name,
Invert: true,
},
&expr.Verdict{
Kind: expr.VerdictReturn,
},
},
})
}
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
@@ -161,6 +214,25 @@ func (r *autoRedirect) nftablesCreateExcludeRules(nft *nftables.Conn, table *nft
}
}
if chain.Hooknum == nftables.ChainHookPrerouting {
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: nftablesIfname(r.tunOptions.Name),
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictReturn,
},
},
})
if len(r.tunOptions.IncludeInterface) > 0 {
if len(r.tunOptions.IncludeInterface) > 1 {
includeInterface := &nftables.Set{
@@ -436,44 +508,6 @@ func (r *autoRedirect) nftablesCreateExcludeRules(nft *nftables.Conn, table *nft
}
}
if r.tunOptions.AutoRedirectMarkMode &&
((chain.Hooknum == nftables.ChainHookOutput && chain.Type == nftables.ChainTypeRoute) ||
(chain.Hooknum == nftables.ChainHookPrerouting && chain.Type == nftables.ChainTypeFilter)) {
ipProto := &nftables.Set{
Table: table,
Anonymous: true,
Constant: true,
KeyType: nftables.TypeInetProto,
}
err := nft.AddSet(ipProto, []nftables.SetElement{
{Key: []byte{unix.IPPROTO_UDP}},
{Key: []byte{unix.IPPROTO_ICMP}},
{Key: []byte{unix.IPPROTO_ICMPV6}},
})
if err != nil {
return err
}
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyL4PROTO,
Register: 1,
},
&expr.Lookup{
SourceRegister: 1,
SetID: ipProto.ID,
SetName: ipProto.Name,
Invert: true,
},
&expr.Verdict{
Kind: expr.VerdictReturn,
},
},
})
}
if r.enableIPv4 {
nftablesCreateExcludeDestinationIPSet(nft, table, chain, 5, "inet4_local_address_set", nftables.TableFamilyIPv4, false)
}
@@ -527,6 +561,9 @@ func (r *autoRedirect) nftablesCreateMark(nft *nftables.Conn, table *nftables.Ta
SourceRegister: true,
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictReturn,
},
},
})
}
@@ -534,57 +571,193 @@ func (r *autoRedirect) nftablesCreateMark(nft *nftables.Conn, table *nftables.Ta
func (r *autoRedirect) nftablesCreateRedirect(
nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain,
exprs ...expr.Any,
) {
if r.enableIPv4 && !r.enableIPv6 {
exprs = append(exprs,
&expr.Meta{
Key: expr.MetaKeyNFPROTO,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{uint8(nftables.TableFamilyIPv4)},
})
} else if !r.enableIPv4 && r.enableIPv6 {
exprs = append(exprs,
&expr.Meta{
Key: expr.MetaKeyNFPROTO,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{uint8(nftables.TableFamilyIPv6)},
})
) error {
exprsRedirect := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyL4PROTO,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{unix.IPPROTO_TCP},
},
&expr.Counter{},
&expr.Immediate{
Register: 1,
Data: binaryutil.BigEndian.PutUint16(r.redirectPort()),
},
&expr.Redir{
RegisterProtoMin: 1,
Flags: unix.NF_NAT_RANGE_PROTO_SPECIFIED,
},
&expr.Verdict{
Kind: expr.VerdictReturn,
},
}
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: append(exprs,
&expr.Meta{
Key: expr.MetaKeyL4PROTO,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{unix.IPPROTO_TCP},
},
&expr.Counter{},
if len(r.tunOptions.Inet4LoopbackAddress) == 0 && len(r.tunOptions.Inet6LoopbackAddress) == 0 {
if r.enableIPv4 && !r.enableIPv6 {
exprs = append(exprs,
&expr.Meta{
Key: expr.MetaKeyNFPROTO,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{uint8(nftables.TableFamilyIPv4)},
})
} else if !r.enableIPv4 && r.enableIPv6 {
exprs = append(exprs,
&expr.Meta{
Key: expr.MetaKeyNFPROTO,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{uint8(nftables.TableFamilyIPv6)},
})
}
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: append(exprs, exprsRedirect...),
})
} else {
if r.enableIPv4 {
exprs4 := exprs
if len(r.tunOptions.Inet4LoopbackAddress) > 0 {
exprs4 = append(exprs4, nftablesCreateDestinationIPSetExprs(7, "inet4_local_redirect_address_set", nftables.TableFamilyIPv4, true)...)
} else {
exprs4 = append(exprs4, &expr.Meta{
Key: expr.MetaKeyNFPROTO,
Register: 1,
}, &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{uint8(nftables.TableFamilyIPv4)},
})
}
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: append(exprs4, exprsRedirect...),
})
}
if r.enableIPv6 {
exprs6 := exprs
if len(r.tunOptions.Inet6LoopbackAddress) > 0 {
exprs6 = append(exprs6, nftablesCreateDestinationIPSetExprs(8, "inet6_local_redirect_address_set", nftables.TableFamilyIPv6, true)...)
} else {
exprs6 = append(exprs6, &expr.Meta{
Key: expr.MetaKeyNFPROTO,
Register: 1,
}, &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{uint8(nftables.TableFamilyIPv6)},
})
}
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: append(exprs6, exprsRedirect...),
})
}
}
return nil
}
func (r *autoRedirect) nftablesCreateLoopbackReroute(
nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain,
) error {
exprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyL4PROTO,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{unix.IPPROTO_TCP},
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectInputMark),
},
}
var exprs4 []expr.Any
if r.enableIPv4 && len(r.tunOptions.Inet4LoopbackAddress) > 0 {
exprs4 = append(exprs, nftablesCreateDestinationIPSetExprs(7, "inet4_local_redirect_address_set", nftables.TableFamilyIPv4, false)...)
}
var exprs6 []expr.Any
if r.enableIPv6 && len(r.tunOptions.Inet6LoopbackAddress) > 0 {
exprs6 = append(exprs, nftablesCreateDestinationIPSetExprs(8, "inet6_local_redirect_address_set", nftables.TableFamilyIPv6, false)...)
}
var exprsCreateMark []expr.Any
if chain.Hooknum == nftables.ChainHookPrerouting {
exprsCreateMark = []expr.Any{
&expr.Immediate{
Register: 1,
Data: binaryutil.BigEndian.PutUint16(r.redirectPort()),
Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectInputMark),
},
&expr.Redir{
RegisterProtoMin: 1,
Flags: unix.NF_NAT_RANGE_PROTO_SPECIFIED,
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
&expr.Verdict{
Kind: expr.VerdictReturn,
&expr.Counter{},
}
} else {
exprsCreateMark = []expr.Any{
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectInputMark),
},
),
})
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Ct{
Key: expr.CtKeyMARK,
Register: 1,
SourceRegister: true,
},
&expr.Counter{},
}
}
if len(exprs4) > 0 {
exprs4 = append(exprs4, exprsCreateMark...)
}
if len(exprs6) > 0 {
exprs6 = append(exprs6, exprsCreateMark...)
}
if len(exprs4) > 0 {
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: exprs4,
})
}
if len(exprs6) > 0 {
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: exprs6,
})
}
return nil
}
func (r *autoRedirect) nftablesCreateDNSHijackRulesForFamily(
@@ -697,6 +870,7 @@ func (r *autoRedirect) nftablesCreateDNSHijackRulesForFamily(
Register: 1,
Data: binaryutil.BigEndian.PutUint16(53),
},
&expr.Counter{},
&expr.Immediate{
Register: 1,
Data: dnsServer.AsSlice(),
@@ -706,7 +880,6 @@ func (r *autoRedirect) nftablesCreateDNSHijackRulesForFamily(
Family: uint32(family),
RegAddrMin: 1,
},
&expr.Counter{},
)
nft.AddRule(&nftables.Rule{
Table: table,
@@ -742,9 +915,7 @@ func (r *autoRedirect) nftablesCreateUnreachable(
Data: []byte{uint8(nfProto)},
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictDrop,
},
&expr.Reject{},
},
})
}

View File

@@ -3,101 +3,50 @@
package tun
import (
"github.com/sagernet/nftables"
"github.com/sagernet/nftables/expr"
"os"
"os/exec"
"golang.org/x/exp/slices"
"github.com/sagernet/nftables"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/shell"
)
func (r *autoRedirect) configureOpenWRTFirewall4(nft *nftables.Conn, cleanup bool) error {
tableFW4, err := nft.ListTableOfFamily("fw4", nftables.TableFamilyINet)
_, err := nft.ListTableOfFamily("fw4", nftables.TableFamilyINet)
if err != nil {
return nil
}
if !cleanup {
ruleIif := &nftables.Rule{
Table: tableFW4,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: nftablesIfname(r.tunOptions.Name),
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
}
ruleOif := &nftables.Rule{
Table: tableFW4,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: nftablesIfname(r.tunOptions.Name),
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
}
chainForward := &nftables.Chain{
Name: "forward",
}
ruleIif.Chain = chainForward
ruleOif.Chain = chainForward
nft.InsertRule(ruleOif)
nft.InsertRule(ruleIif)
chainInput := &nftables.Chain{
Name: "input",
}
ruleIif.Chain = chainInput
ruleOif.Chain = chainInput
nft.InsertRule(ruleOif)
nft.InsertRule(ruleIif)
fw4Path, err := exec.LookPath("fw4")
if err != nil {
return nil
}
for _, chainName := range []string{"input", "forward"} {
var rules []*nftables.Rule
rules, err = nft.GetRules(tableFW4, &nftables.Chain{
Name: chainName,
})
rulePath := "/etc/nftables.d/0-" + r.tableName + "-auto-redirect.nft"
if !cleanup {
err = os.WriteFile(rulePath, []byte(`chain input {
type filter hook input priority filter; policy accept;
iifname "`+r.tunOptions.Name+`" counter accept comment "!`+r.tableName+`: Accept traffic from tun"
oifname "`+r.tunOptions.Name+`" counter accept comment "!`+r.tableName+`: Accept traffic from tun"
}
chain forward {
type filter hook forward priority filter; policy accept;
iifname "`+r.tunOptions.Name+`" counter accept comment "!`+r.tableName+`: Accept traffic from tun"
oifname "`+r.tunOptions.Name+`" counter accept comment "!`+r.tableName+`: Accept traffic from tun"
}
`), 0o644)
if err != nil {
return err
return E.Cause(err, "write fw4 rules")
}
for _, rule := range rules {
if len(rule.Exprs) != 4 {
continue
}
exprMeta, isMeta := rule.Exprs[0].(*expr.Meta)
if !isMeta {
continue
}
if exprMeta.Key != expr.MetaKeyIIFNAME && exprMeta.Key != expr.MetaKeyOIFNAME {
continue
}
exprCmp, isCmp := rule.Exprs[1].(*expr.Cmp)
if !isCmp {
continue
}
if !slices.Equal(exprCmp.Data, nftablesIfname(r.tunOptions.Name)) {
continue
}
err = nft.DelRule(rule)
if err != nil {
return err
}
} else if _, err = os.Stat(rulePath); os.IsNotExist(err) {
return nil
} else {
err = os.Remove(rulePath)
if err != nil {
return E.Cause(err, "clean fw4 rules")
}
}
output, err := shell.Exec(fw4Path, "reload").Read()
if err != nil {
return E.Extend(E.Cause(err, "reload fw4 rules"), output)
}
return nil
}

View File

@@ -7,9 +7,9 @@ import (
"errors"
"net"
"net/netip"
"sync/atomic"
"time"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/control"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"

View File

@@ -26,19 +26,22 @@ const WithGVisor = true
const DefaultNIC tcpip.NICID = 1
type GVisor struct {
ctx context.Context
tun GVisorTun
udpTimeout time.Duration
broadcastAddr netip.Addr
handler Handler
logger logger.Logger
stack *stack.Stack
endpoint stack.LinkEndpoint
ctx context.Context
tun GVisorTun
inet4LoopbackAddress []netip.Addr
inet6LoopbackAddress []netip.Addr
udpTimeout time.Duration
broadcastAddr netip.Addr
handler Handler
logger logger.Logger
stack *stack.Stack
endpoint stack.LinkEndpoint
}
type GVisorTun interface {
Tun
NewEndpoint() (stack.LinkEndpoint, error)
WritePacket(pkt *stack.PacketBuffer) (int, error)
NewEndpoint() (stack.LinkEndpoint, stack.NICOptions, error)
}
func NewGVisor(
@@ -50,27 +53,29 @@ func NewGVisor(
}
gStack := &GVisor{
ctx: options.Context,
tun: gTun,
udpTimeout: options.UDPTimeout,
broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address),
handler: options.Handler,
logger: options.Logger,
ctx: options.Context,
tun: gTun,
inet4LoopbackAddress: options.TunOptions.Inet4LoopbackAddress,
inet6LoopbackAddress: options.TunOptions.Inet6LoopbackAddress,
udpTimeout: options.UDPTimeout,
broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address),
handler: options.Handler,
logger: options.Logger,
}
return gStack, nil
}
func (t *GVisor) Start() error {
linkEndpoint, err := t.tun.NewEndpoint()
linkEndpoint, nicOptions, err := t.tun.NewEndpoint()
if err != nil {
return err
}
linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, t.tun}
ipStack, err := NewGVisorStack(linkEndpoint)
ipStack, err := NewGVisorStackWithOptions(linkEndpoint, nicOptions)
if err != nil {
return err
}
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, NewTCPForwarder(t.ctx, ipStack, t.handler).HandlePacket)
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, NewTCPForwarderWithLoopback(t.ctx, ipStack, t.handler, t.inet4LoopbackAddress, t.inet6LoopbackAddress, t.tun).HandlePacket)
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket)
t.stack = ipStack
t.endpoint = linkEndpoint
@@ -106,6 +111,10 @@ func AddrFromAddress(address tcpip.Address) netip.Addr {
}
func NewGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
return NewGVisorStackWithOptions(ep, stack.NICOptions{})
}
func NewGVisorStackWithOptions(ep stack.LinkEndpoint, opts stack.NICOptions) (*stack.Stack, error) {
ipStack := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,
@@ -118,7 +127,7 @@ func NewGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
icmp.NewProtocol6,
},
})
err := ipStack.CreateNIC(DefaultNIC, ep)
err := ipStack.CreateNICWithOptions(DefaultNIC, ep, opts)
if err != nil {
return nil, gonet.TranslateNetstackError(err)
}

View File

@@ -8,8 +8,6 @@ import (
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/sing/common/bufio"
N "github.com/sagernet/sing/common/network"
)
var _ stack.LinkEndpoint = (*LinkEndpointFilter)(nil)
@@ -17,7 +15,7 @@ var _ stack.LinkEndpoint = (*LinkEndpointFilter)(nil)
type LinkEndpointFilter struct {
stack.LinkEndpoint
BroadcastAddress netip.Addr
Writer N.VectorisedWriter
Writer GVisorTun
}
func (w *LinkEndpointFilter) Attach(dispatcher stack.NetworkDispatcher) {
@@ -29,7 +27,7 @@ var _ stack.NetworkDispatcher = (*networkDispatcherFilter)(nil)
type networkDispatcherFilter struct {
stack.NetworkDispatcher
broadcastAddress netip.Addr
writer N.VectorisedWriter
writer GVisorTun
}
func (w *networkDispatcherFilter) DeliverNetworkPacket(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
@@ -49,7 +47,7 @@ func (w *networkDispatcherFilter) DeliverNetworkPacket(protocol tcpip.NetworkPro
}
destination := AddrFromAddress(network.DestinationAddress())
if destination == w.broadcastAddress || !destination.IsGlobalUnicast() {
_, _ = bufio.WriteVectorised(w.writer, pkt.AsSlices())
w.writer.WritePacket(pkt)
return
}
w.NetworkDispatcher.DeliverNetworkPacket(protocol, pkt)

View File

@@ -4,8 +4,10 @@ package tun
import (
"context"
"errors"
"net"
"os"
"sync"
"time"
"github.com/sagernet/gvisor/pkg/tcpip"
@@ -17,19 +19,25 @@ import (
)
type gLazyConn struct {
tcpConn *gonet.TCPConn
parentCtx context.Context
stack *stack.Stack
request *tcp.ForwarderRequest
localAddr net.Addr
remoteAddr net.Addr
handshakeDone bool
handshakeErr error
tcpConn *gonet.TCPConn
parentCtx context.Context
stack *stack.Stack
request *tcp.ForwarderRequest
localAddr net.Addr
remoteAddr net.Addr
handshakeAccess sync.Mutex
handshakeDone bool
handshakeErr error
}
func (c *gLazyConn) HandshakeContext(ctx context.Context) error {
if c.handshakeDone {
return nil
return c.handshakeErr
}
c.handshakeAccess.Lock()
defer c.handshakeAccess.Unlock()
if c.handshakeDone {
return c.handshakeErr
}
defer func() {
c.handshakeDone = true
@@ -67,7 +75,12 @@ func (c *gLazyConn) HandshakeFailure(err error) error {
if c.handshakeDone {
return os.ErrInvalid
}
c.request.Complete(err != ErrDrop)
c.handshakeAccess.Lock()
defer c.handshakeAccess.Unlock()
if c.handshakeDone {
return os.ErrInvalid
}
c.request.Complete(!errors.Is(err, ErrDrop))
c.handshakeDone = true
c.handshakeErr = err
return nil
@@ -78,25 +91,17 @@ func (c *gLazyConn) HandshakeSuccess() error {
}
func (c *gLazyConn) Read(b []byte) (n int, err error) {
if !c.handshakeDone {
err = c.HandshakeContext(context.Background())
if err != nil {
return
}
} else if c.handshakeErr != nil {
return 0, c.handshakeErr
err = c.HandshakeContext(context.Background())
if err != nil {
return
}
return c.tcpConn.Read(b)
}
func (c *gLazyConn) Write(b []byte) (n int, err error) {
if !c.handshakeDone {
err = c.HandshakeContext(context.Background())
if err != nil {
return
}
} else if c.handshakeErr != nil {
return 0, c.handshakeErr
err = c.HandshakeContext(context.Background())
if err != nil {
return
}
return c.tcpConn.Write(b)
}
@@ -110,46 +115,41 @@ func (c *gLazyConn) RemoteAddr() net.Addr {
}
func (c *gLazyConn) SetDeadline(t time.Time) error {
if !c.handshakeDone {
err := c.HandshakeContext(context.Background())
if err != nil {
return err
}
} else if c.handshakeErr != nil {
return c.handshakeErr
err := c.HandshakeContext(context.Background())
if err != nil {
return err
}
return c.tcpConn.SetDeadline(t)
}
func (c *gLazyConn) SetReadDeadline(t time.Time) error {
if !c.handshakeDone {
err := c.HandshakeContext(context.Background())
if err != nil {
return err
}
} else if c.handshakeErr != nil {
return c.handshakeErr
err := c.HandshakeContext(context.Background())
if err != nil {
return err
}
return c.tcpConn.SetReadDeadline(t)
}
func (c *gLazyConn) SetWriteDeadline(t time.Time) error {
if !c.handshakeDone {
err := c.HandshakeContext(context.Background())
if err != nil {
return err
}
} else if c.handshakeErr != nil {
return c.handshakeErr
err := c.HandshakeContext(context.Background())
if err != nil {
return err
}
return c.tcpConn.SetWriteDeadline(t)
}
func (c *gLazyConn) Close() error {
if !c.handshakeDone {
c.request.Complete(true)
c.handshakeErr = net.ErrClosed
return nil
c.handshakeAccess.Lock()
if !c.handshakeDone {
c.request.Complete(true)
c.handshakeErr = net.ErrClosed
c.handshakeDone = true
return nil
} else if c.handshakeErr != nil {
return nil
}
c.handshakeAccess.Unlock()
} else if c.handshakeErr != nil {
return nil
}
@@ -158,9 +158,16 @@ func (c *gLazyConn) Close() error {
func (c *gLazyConn) CloseRead() error {
if !c.handshakeDone {
c.request.Complete(true)
c.handshakeErr = net.ErrClosed
return nil
c.handshakeAccess.Lock()
if !c.handshakeDone {
c.request.Complete(true)
c.handshakeErr = net.ErrClosed
c.handshakeDone = true
return nil
} else if c.handshakeErr != nil {
return nil
}
c.handshakeAccess.Unlock()
} else if c.handshakeErr != nil {
return nil
}
@@ -169,9 +176,16 @@ func (c *gLazyConn) CloseRead() error {
func (c *gLazyConn) CloseWrite() error {
if !c.handshakeDone {
c.request.Complete(true)
c.handshakeErr = net.ErrClosed
return nil
c.handshakeAccess.Lock()
if !c.handshakeDone {
c.request.Complete(true)
c.handshakeErr = net.ErrClosed
c.handshakeDone = true
return nil
} else if c.handshakeErr != nil {
return nil
}
c.handshakeAccess.Unlock()
} else if c.handshakeErr != nil {
return nil
}
@@ -179,10 +193,14 @@ func (c *gLazyConn) CloseWrite() error {
}
func (c *gLazyConn) ReaderReplaceable() bool {
c.handshakeAccess.Lock()
defer c.handshakeAccess.Unlock()
return c.handshakeDone && c.handshakeErr == nil
}
func (c *gLazyConn) WriterReplaceable() bool {
c.handshakeAccess.Lock()
defer c.handshakeAccess.Unlock()
return c.handshakeDone && c.handshakeErr == nil
}

View File

@@ -4,31 +4,75 @@ package tun
import (
"context"
"errors"
"net/netip"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
"github.com/sagernet/sing/common"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type TCPForwarder struct {
ctx context.Context
stack *stack.Stack
handler Handler
forwarder *tcp.Forwarder
ctx context.Context
stack *stack.Stack
handler Handler
inet4LoopbackAddress []tcpip.Address
inet6LoopbackAddress []tcpip.Address
tun GVisorTun
forwarder *tcp.Forwarder
}
func NewTCPForwarder(ctx context.Context, stack *stack.Stack, handler Handler) *TCPForwarder {
return NewTCPForwarderWithLoopback(ctx, stack, handler, nil, nil, nil)
}
func NewTCPForwarderWithLoopback(ctx context.Context, stack *stack.Stack, handler Handler, inet4LoopbackAddress []netip.Addr, inet6LoopbackAddress []netip.Addr, tun GVisorTun) *TCPForwarder {
forwarder := &TCPForwarder{
ctx: ctx,
stack: stack,
handler: handler,
ctx: ctx,
stack: stack,
handler: handler,
inet4LoopbackAddress: common.Map(inet4LoopbackAddress, AddressFromAddr),
inet6LoopbackAddress: common.Map(inet6LoopbackAddress, AddressFromAddr),
tun: tun,
}
forwarder.forwarder = tcp.NewForwarder(stack, 0, 1024, forwarder.Forward)
return forwarder
}
func (f *TCPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
for _, inet4LoopbackAddress := range f.inet4LoopbackAddress {
if id.LocalAddress == inet4LoopbackAddress {
ipHdr := pkt.Network().(header.IPv4)
ipHdr.SetDestinationAddressWithChecksumUpdate(ipHdr.SourceAddress())
ipHdr.SetSourceAddressWithChecksumUpdate(inet4LoopbackAddress)
tcpHdr := header.TCP(pkt.TransportHeader().Slice())
tcpHdr.SetChecksum(0)
tcpHdr.SetChecksum(^checksum.Combine(pkt.Data().Checksum(), tcpHdr.CalculateChecksum(
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddress(), ipHdr.DestinationAddress(), ipHdr.PayloadLength()),
)))
f.tun.WritePacket(pkt)
return true
}
}
for _, inet6LoopbackAddress := range f.inet6LoopbackAddress {
if id.LocalAddress == inet6LoopbackAddress {
ipHdr := pkt.Network().(header.IPv6)
ipHdr.SetDestinationAddress(ipHdr.SourceAddress())
ipHdr.SetSourceAddress(inet6LoopbackAddress)
tcpHdr := header.TCP(pkt.TransportHeader().Slice())
tcpHdr.SetChecksum(0)
tcpHdr.SetChecksum(^checksum.Combine(pkt.Data().Checksum(), tcpHdr.CalculateChecksum(
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddress(), ipHdr.DestinationAddress(), ipHdr.PayloadLength()),
)))
f.tun.WritePacket(pkt)
return true
}
}
return f.forwarder.HandlePacket(id, pkt)
}
@@ -37,7 +81,7 @@ func (f *TCPForwarder) Forward(r *tcp.ForwarderRequest) {
destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort)
pErr := f.handler.PrepareConnection(N.NetworkTCP, source, destination)
if pErr != nil {
r.Complete(pErr != ErrDrop)
r.Complete(!errors.Is(pErr, ErrDrop))
return
}
conn := &gLazyConn{

View File

@@ -4,6 +4,7 @@ package tun
import (
"context"
"errors"
"math"
"net/netip"
"os"
@@ -59,7 +60,7 @@ func rangeIterate(r stack.Range, fn func(*buffer.View))
func (f *UDPForwarder) PreparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) {
pErr := f.handler.PrepareConnection(N.NetworkUDP, source, destination)
if pErr != nil {
if pErr != ErrDrop {
if !errors.Is(pErr, ErrDrop) {
gWriteUnreachable(f.stack, userData.(*stack.PacketBuffer))
}
return false, nil, nil, nil

View File

@@ -10,12 +10,13 @@ import (
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
)
type Mixed struct {
*System
tun GVisorTun
stack *stack.Stack
endpoint *channel.Endpoint
}
@@ -29,6 +30,7 @@ func NewMixed(
}
return &Mixed{
System: system.(*System),
tun: system.(*System).tun.(GVisorTun),
}, nil
}
@@ -72,10 +74,14 @@ func (m *Mixed) tunLoop() {
m.txChecksumOffload = linuxTUN.TXChecksumOffload()
batchSize := linuxTUN.BatchSize()
if batchSize > 1 {
m.batchLoop(linuxTUN, batchSize)
m.batchLoopLinux(linuxTUN, batchSize)
return
}
}
if darwinTUN, isDarwinTUN := m.tun.(DarwinTUN); isDarwinTUN && m.multiPendingPackets {
m.batchLoopDarwin(darwinTUN)
return
}
packetBuffer := make([]byte, m.mtu+PacketOffset)
for {
n, err := m.tun.Read(packetBuffer)
@@ -119,12 +125,12 @@ func (m *Mixed) wintunLoop(winTun WinTun) {
}
}
func (m *Mixed) batchLoop(linuxTUN LinuxTUN, batchSize int) {
func (m *Mixed) batchLoopLinux(linuxTUN LinuxTUN, batchSize int) {
packetBuffers := make([][]byte, batchSize)
writeBuffers := make([][]byte, batchSize)
packetSizes := make([]int, batchSize)
for i := range packetBuffers {
packetBuffers[i] = make([]byte, m.mtu+m.frontHeadroom)
packetBuffers[i] = make([]byte, m.mtu+PacketOffset+m.frontHeadroom)
}
for {
n, err := linuxTUN.BatchRead(packetBuffers, m.frontHeadroom, packetSizes)
@@ -158,6 +164,40 @@ func (m *Mixed) batchLoop(linuxTUN LinuxTUN, batchSize int) {
}
}
func (m *Mixed) batchLoopDarwin(darwinTUN DarwinTUN) {
var writeBuffers []*buf.Buffer
for {
buffers, err := darwinTUN.BatchRead()
if err != nil {
if E.IsClosed(err) {
return
}
m.logger.Error(E.Cause(err, "batch read packet"))
}
if len(buffers) == 0 {
continue
}
writeBuffers = writeBuffers[:0]
for _, buffer := range buffers {
packetSize := buffer.Len()
if packetSize < header.IPv4MinimumSize {
continue
}
if m.processPacket(buffer.Bytes()) {
writeBuffers = append(writeBuffers, buffer)
} else {
buffer.Release()
}
}
if len(writeBuffers) > 0 {
err = darwinTUN.BatchWrite(writeBuffers)
if err != nil {
m.logger.Trace(E.Cause(err, "batch write packet"))
}
}
}
}
func (m *Mixed) processPacket(packet []byte) bool {
var (
writeBack bool
@@ -226,11 +266,11 @@ func (m *Mixed) processIPv6(ipHdr header.IPv6) (writeBack bool, err error) {
func (m *Mixed) packetLoop() {
for {
packet := m.endpoint.ReadContext(m.ctx)
if packet == nil {
pkt := m.endpoint.ReadContext(m.ctx)
if pkt == nil {
break
}
bufio.WriteVectorised(m.tun, packet.AsSlices())
packet.DecRef()
m.tun.WritePacket(pkt)
pkt.DecRef()
}
}

View File

@@ -2,6 +2,7 @@ package tun
import (
"context"
"errors"
"net"
"net/netip"
"syscall"
@@ -22,30 +23,33 @@ import (
var ErrIncludeAllNetworks = E.New("`system` and `mixed` stack are not available when `includeAllNetworks` is enabled. See https://github.com/SagerNet/sing-tun/issues/25")
type System struct {
ctx context.Context
tun Tun
tunName string
mtu int
handler Handler
logger logger.Logger
inet4Prefixes []netip.Prefix
inet6Prefixes []netip.Prefix
inet4ServerAddress netip.Addr
inet4Address netip.Addr
inet6ServerAddress netip.Addr
inet6Address netip.Addr
broadcastAddr netip.Addr
udpTimeout time.Duration
tcpListener net.Listener
tcpListener6 net.Listener
tcpPort uint16
tcpPort6 uint16
tcpNat *TCPNat
udpNat *udpnat.Service
bindInterface bool
interfaceFinder control.InterfaceFinder
frontHeadroom int
txChecksumOffload bool
ctx context.Context
tun Tun
tunName string
mtu int
handler Handler
logger logger.Logger
inet4Prefixes []netip.Prefix
inet6Prefixes []netip.Prefix
inet4ServerAddress netip.Addr
inet4Address netip.Addr
inet6ServerAddress netip.Addr
inet6Address netip.Addr
broadcastAddr netip.Addr
inet4LoopbackAddress []netip.Addr
inet6LoopbackAddress []netip.Addr
udpTimeout time.Duration
tcpListener net.Listener
tcpListener6 net.Listener
tcpPort uint16
tcpPort6 uint16
tcpNat *TCPNat
udpNat *udpnat.Service
bindInterface bool
interfaceFinder control.InterfaceFinder
frontHeadroom int
txChecksumOffload bool
multiPendingPackets bool
}
type Session struct {
@@ -57,18 +61,21 @@ type Session struct {
func NewSystem(options StackOptions) (Stack, error) {
stack := &System{
ctx: options.Context,
tun: options.Tun,
tunName: options.TunOptions.Name,
mtu: int(options.TunOptions.MTU),
udpTimeout: options.UDPTimeout,
handler: options.Handler,
logger: options.Logger,
inet4Prefixes: options.TunOptions.Inet4Address,
inet6Prefixes: options.TunOptions.Inet6Address,
broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address),
bindInterface: options.ForwarderBindInterface,
interfaceFinder: options.InterfaceFinder,
ctx: options.Context,
tun: options.Tun,
tunName: options.TunOptions.Name,
mtu: int(options.TunOptions.MTU),
inet4LoopbackAddress: options.TunOptions.Inet4LoopbackAddress,
inet6LoopbackAddress: options.TunOptions.Inet6LoopbackAddress,
udpTimeout: options.UDPTimeout,
handler: options.Handler,
logger: options.Logger,
inet4Prefixes: options.TunOptions.Inet4Address,
inet6Prefixes: options.TunOptions.Inet6Address,
broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address),
bindInterface: options.ForwarderBindInterface,
interfaceFinder: options.InterfaceFinder,
multiPendingPackets: options.TunOptions.EXP_MultiPendingPackets,
}
if len(options.TunOptions.Inet4Address) > 0 {
if !HasNextAddress(options.TunOptions.Inet4Address[0], 1) {
@@ -165,10 +172,14 @@ func (s *System) tunLoop() {
s.txChecksumOffload = linuxTUN.TXChecksumOffload()
batchSize := linuxTUN.BatchSize()
if batchSize > 1 {
s.batchLoop(linuxTUN, batchSize)
s.batchLoopLinux(linuxTUN, batchSize)
return
}
}
if darwinTUN, isDarwinTUN := s.tun.(DarwinTUN); isDarwinTUN && s.multiPendingPackets {
s.batchLoopDarwin(darwinTUN)
return
}
packetBuffer := make([]byte, s.mtu+PacketOffset)
for {
n, err := s.tun.Read(packetBuffer)
@@ -212,7 +223,7 @@ func (s *System) wintunLoop(winTun WinTun) {
}
}
func (s *System) batchLoop(linuxTUN LinuxTUN, batchSize int) {
func (s *System) batchLoopLinux(linuxTUN LinuxTUN, batchSize int) {
packetBuffers := make([][]byte, batchSize)
writeBuffers := make([][]byte, batchSize)
packetSizes := make([]int, batchSize)
@@ -251,6 +262,40 @@ func (s *System) batchLoop(linuxTUN LinuxTUN, batchSize int) {
}
}
func (s *System) batchLoopDarwin(darwinTUN DarwinTUN) {
var writeBuffers []*buf.Buffer
for {
buffers, err := darwinTUN.BatchRead()
if err != nil {
if E.IsClosed(err) {
return
}
s.logger.Error(E.Cause(err, "batch read packet"))
}
if len(buffers) == 0 {
continue
}
writeBuffers = writeBuffers[:0]
for _, buffer := range buffers {
packetSize := buffer.Len()
if packetSize < header.IPv4MinimumSize {
continue
}
if s.processPacket(buffer.Bytes()) {
writeBuffers = append(writeBuffers, buffer)
} else {
buffer.Release()
}
}
if len(writeBuffers) > 0 {
err = darwinTUN.BatchWrite(writeBuffers)
if err != nil {
s.logger.Trace(E.Cause(err, "batch write packet"))
}
}
}
}
func (s *System) processPacket(packet []byte) bool {
var (
writeBack bool
@@ -352,28 +397,37 @@ func (s *System) processIPv4TCP(ipHdr header.IPv4, tcpHdr header.TCP) (bool, err
ipHdr.SetDestinationAddr(session.Source.Addr())
tcpHdr.SetDestinationPort(session.Source.Port())
} else {
natPort, err := s.tcpNat.Lookup(source, destination, s.handler)
if err != nil {
if err == ErrDrop {
return false, nil
} else {
return false, s.resetIPv4TCP(ipHdr, tcpHdr)
var loopback bool
for _, inet4LoopbackAddress := range s.inet4LoopbackAddress {
if destination.Addr() == inet4LoopbackAddress {
ipHdr.SetDestinationAddr(ipHdr.SourceAddr())
ipHdr.SetSourceAddr(inet4LoopbackAddress)
loopback = true
break
}
}
ipHdr.SetSourceAddr(s.inet4Address)
tcpHdr.SetSourcePort(natPort)
ipHdr.SetDestinationAddr(s.inet4ServerAddress)
tcpHdr.SetDestinationPort(s.tcpPort)
if !loopback {
natPort, err := s.tcpNat.Lookup(source, destination, s.handler)
if err != nil {
if errors.Is(err, ErrDrop) {
return false, nil
} else {
return false, s.resetIPv4TCP(ipHdr, tcpHdr)
}
}
ipHdr.SetSourceAddr(s.inet4Address)
tcpHdr.SetSourcePort(natPort)
ipHdr.SetDestinationAddr(s.inet4ServerAddress)
tcpHdr.SetDestinationPort(s.tcpPort)
}
}
if !s.txChecksumOffload {
tcpHdr.SetChecksum(0)
tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum(
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
)))
} else {
tcpHdr.SetChecksum(0)
}
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
return true, nil
}
@@ -414,7 +468,6 @@ func (s *System) resetIPv4TCP(origIPHdr header.IPv4, origTCPHdr header.TCP) erro
if !s.txChecksumOffload {
tcpHdr.SetChecksum(^tcpHdr.CalculateChecksum(header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), header.TCPMinimumSize)))
}
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
if PacketOffset > 0 {
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version)
@@ -439,21 +492,31 @@ func (s *System) processIPv6TCP(ipHdr header.IPv6, tcpHdr header.TCP) (bool, err
ipHdr.SetDestinationAddr(session.Source.Addr())
tcpHdr.SetDestinationPort(session.Source.Port())
} else {
natPort, err := s.tcpNat.Lookup(source, destination, s.handler)
if err != nil {
if err == ErrDrop {
return false, nil
} else {
return false, s.resetIPv6TCP(ipHdr, tcpHdr)
var loopback bool
for _, inet6LoopbackAddress := range s.inet6LoopbackAddress {
if destination.Addr() == inet6LoopbackAddress {
ipHdr.SetDestinationAddr(ipHdr.SourceAddr())
ipHdr.SetSourceAddr(inet6LoopbackAddress)
loopback = true
break
}
}
ipHdr.SetSourceAddr(s.inet6Address)
tcpHdr.SetSourcePort(natPort)
ipHdr.SetDestinationAddr(s.inet6ServerAddress)
tcpHdr.SetDestinationPort(s.tcpPort6)
if !loopback {
natPort, err := s.tcpNat.Lookup(source, destination, s.handler)
if err != nil {
if errors.Is(err, ErrDrop) {
return false, nil
} else {
return false, s.resetIPv6TCP(ipHdr, tcpHdr)
}
}
ipHdr.SetSourceAddr(s.inet6Address)
tcpHdr.SetSourcePort(natPort)
ipHdr.SetDestinationAddr(s.inet6ServerAddress)
tcpHdr.SetDestinationPort(s.tcpPort6)
}
}
if !s.txChecksumOffload {
tcpHdr.SetChecksum(0)
tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum(
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
)))
@@ -536,7 +599,7 @@ func (s *System) processIPv6UDP(ipHdr header.IPv6, udpHdr header.UDP) error {
func (s *System) preparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) {
pErr := s.handler.PrepareConnection(N.NetworkUDP, source, destination)
if pErr != nil {
if pErr != ErrDrop {
if !errors.Is(pErr, ErrDrop) {
if source.IsIPv4() {
ipHdr := userData.(header.IPv4)
s.rejectIPv4WithICMP(ipHdr, header.ICMPv4PortUnreachable)
@@ -584,8 +647,7 @@ func (s *System) processIPv4ICMP(ipHdr header.IPv4, icmpHdr header.ICMPv4) error
sourceAddress := ipHdr.SourceAddr()
ipHdr.SetSourceAddr(ipHdr.DestinationAddr())
ipHdr.SetDestinationAddr(sourceAddress)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
ipHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
return nil
}
@@ -619,7 +681,7 @@ func (s *System) rejectIPv4WithICMP(ipHdr header.IPv4, code header.ICMPv4Code) e
icmpHdr := header.ICMPv4(newIPHdr.Payload())
icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetCode(code)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(payload, 0)))
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
copy(icmpHdr.Payload(), payload)
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET
@@ -712,14 +774,12 @@ func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.S
udpHdr.SetSourcePort(destination.Port)
udpHdr.SetLength(uint16(buffer.Len() + header.UDPMinimumSize))
if !w.txChecksumOffload {
udpHdr.SetChecksum(0)
udpHdr.SetChecksum(^checksum.Checksum(udpHdr.Payload(), udpHdr.CalculateChecksum(
header.PseudoHeaderChecksum(header.UDPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
)))
} else {
udpHdr.SetChecksum(0)
}
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
if PacketOffset > 0 {
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version)
@@ -753,7 +813,6 @@ func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.S
udpHdr.SetSourcePort(destination.Port)
udpHdr.SetLength(udpLen)
if !w.txChecksumOffload {
udpHdr.SetChecksum(0)
udpHdr.SetChecksum(^checksum.Checksum(udpHdr.Payload(), udpHdr.CalculateChecksum(
header.PseudoHeaderChecksum(header.UDPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
)))

View File

@@ -11,6 +11,7 @@ import (
)
type TCPNat struct {
timeout time.Duration
portIndex uint16
portAccess sync.RWMutex
addrAccess sync.RWMutex
@@ -19,6 +20,7 @@ type TCPNat struct {
}
type TCPSession struct {
sync.Mutex
Source netip.AddrPort
Destination netip.AddrPort
LastActive time.Time
@@ -26,38 +28,41 @@ type TCPSession struct {
func NewNat(ctx context.Context, timeout time.Duration) *TCPNat {
natMap := &TCPNat{
timeout: timeout,
portIndex: 10000,
addrMap: make(map[netip.AddrPort]uint16),
portMap: make(map[uint16]*TCPSession),
}
go natMap.loopCheckTimeout(ctx, timeout)
go natMap.loopCheckTimeout(ctx)
return natMap
}
func (n *TCPNat) loopCheckTimeout(ctx context.Context, timeout time.Duration) {
ticker := time.NewTicker(timeout)
func (n *TCPNat) loopCheckTimeout(ctx context.Context) {
ticker := time.NewTicker(n.timeout)
defer ticker.Stop()
for {
select {
case <-ticker.C:
n.checkTimeout(timeout)
n.checkTimeout()
case <-ctx.Done():
return
}
}
}
func (n *TCPNat) checkTimeout(timeout time.Duration) {
func (n *TCPNat) checkTimeout() {
now := time.Now()
n.portAccess.Lock()
defer n.portAccess.Unlock()
n.addrAccess.Lock()
defer n.addrAccess.Unlock()
for natPort, session := range n.portMap {
if now.Sub(session.LastActive) > timeout {
session.Lock()
if now.Sub(session.LastActive) > n.timeout {
delete(n.addrMap, session.Source)
delete(n.portMap, natPort)
}
session.Unlock()
}
}
@@ -66,7 +71,11 @@ func (n *TCPNat) LookupBack(port uint16) *TCPSession {
session := n.portMap[port]
n.portAccess.RUnlock()
if session != nil {
session.LastActive = time.Now()
session.Lock()
if time.Since(session.LastActive) > time.Second {
session.LastActive = time.Now()
}
session.Unlock()
}
return session
}

17
tun.go
View File

@@ -8,6 +8,7 @@ import (
"strconv"
"strings"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/control"
F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/common/logger"
@@ -24,7 +25,6 @@ type Handler interface {
type Tun interface {
io.ReadWriter
N.VectorisedWriter
Name() (string, error)
Start() error
Close() error
@@ -45,6 +45,12 @@ type LinuxTUN interface {
TXChecksumOffload() bool
}
type DarwinTUN interface {
Tun
BatchRead() ([]*buf.Buffer, error)
BatchWrite(buffers []*buf.Buffer) error
}
const (
DefaultIPRoute2TableIndex = 2022
DefaultIPRoute2RuleIndex = 9000
@@ -66,6 +72,8 @@ type Options struct {
AutoRedirectMarkMode bool
AutoRedirectInputMark uint32
AutoRedirectOutputMark uint32
Inet4LoopbackAddress []netip.Addr
Inet6LoopbackAddress []netip.Addr
StrictRoute bool
Inet4RouteAddress []netip.Prefix
Inet6RouteAddress []netip.Prefix
@@ -88,6 +96,13 @@ type Options struct {
// For library usages.
EXP_DisableDNSHijack bool
// For gvisor stack, it should be enabled when MTU is less than 32768; otherwise it should be less than or equal to 8192.
// The above condition is just an estimate and not exact, calculated on M4 pro.
EXP_MultiPendingPackets bool
// Will cause the darwin network to die, do not use.
EXP_SendMsgX bool
}
func (o *Options) Inet4GatewayAddr() netip.Addr {

View File

@@ -10,26 +10,74 @@ import (
"unsafe"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
"github.com/sagernet/sing-tun/internal/rawfile_darwin"
"github.com/sagernet/sing-tun/internal/stopfd_darwin"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/shell"
"golang.org/x/net/route"
"golang.org/x/sys/unix"
)
var _ DarwinTUN = (*NativeTun)(nil)
const PacketOffset = 4
type NativeTun struct {
tunFile *os.File
tunWriter N.VectorisedWriter
options Options
inet4Address [4]byte
inet6Address [16]byte
routeSet bool
tunFd int
tunFile *os.File
batchSize int
iovecs []iovecBuffer
iovecsOutput []iovecBuffer
iovecsOutputDefault []unix.Iovec
msgHdrs []rawfile.MsgHdrX
msgHdrsOutput []rawfile.MsgHdrX
buffers []*buf.Buffer
stopFd stopfd.StopFD
options Options
inet4Address [4]byte
inet6Address [16]byte
routeSet bool
sendMsgX bool
}
type iovecBuffer struct {
mtu int
buffer *buf.Buffer
iovecs []unix.Iovec
}
func newIovecBuffer(mtu int) iovecBuffer {
return iovecBuffer{
mtu: mtu,
iovecs: make([]unix.Iovec, 2),
}
}
func (b *iovecBuffer) nextIovecs() []unix.Iovec {
if b.iovecs[0].Len == 0 {
headBuffer := make([]byte, PacketOffset)
b.iovecs[0].Base = &headBuffer[0]
b.iovecs[0].SetLen(PacketOffset)
}
if b.buffer == nil {
b.buffer = buf.NewSize(b.mtu)
b.iovecs[1] = b.buffer.Iovec(b.buffer.Cap())
}
return b.iovecs
}
func (b *iovecBuffer) nextIovecsOutput(buffer *buf.Buffer) []unix.Iovec {
switch header.IPVersion(buffer.Bytes()) {
case header.IPv4Version:
b.iovecs[0] = packetHeaderVec4
case header.IPv6Version:
b.iovecs[0] = packetHeaderVec6
}
b.iovecs[1] = buffer.Iovec(buffer.Len())
return b.iovecs
}
func (t *NativeTun) Name() (string, error) {
@@ -42,6 +90,7 @@ func (t *NativeTun) Name() (string, error) {
func New(options Options) (Tun, error) {
var tunFd int
batchSize := ((512 * 1024) / int(options.MTU)) + 1
if options.FileDescriptor == 0 {
ifIndex := -1
_, err := fmt.Sscanf(options.Name, "utun%d", &ifIndex)
@@ -54,18 +103,38 @@ func New(options Options) (Tun, error) {
return nil, err
}
err = configure(tunFd, ifIndex, options.Name, options)
err = create(tunFd, ifIndex, options.Name, options)
if err != nil {
unix.Close(tunFd)
return nil, err
}
err = configure(tunFd, options.EXP_MultiPendingPackets, batchSize)
if err != nil {
unix.Close(tunFd)
return nil, err
}
} else {
tunFd = options.FileDescriptor
err := configure(tunFd, options.EXP_MultiPendingPackets, batchSize)
if err != nil {
return nil, err
}
}
nativeTun := &NativeTun{
tunFile: os.NewFile(uintptr(tunFd), "utun"),
options: options,
tunFd: tunFd,
tunFile: os.NewFile(uintptr(tunFd), "utun"),
options: options,
batchSize: batchSize,
iovecs: make([]iovecBuffer, batchSize),
iovecsOutput: make([]iovecBuffer, batchSize),
msgHdrs: make([]rawfile.MsgHdrX, batchSize),
msgHdrsOutput: make([]rawfile.MsgHdrX, batchSize),
stopFd: common.Must1(stopfd.New()),
sendMsgX: options.EXP_SendMsgX,
}
for i := 0; i < batchSize; i++ {
nativeTun.iovecs[i] = newIovecBuffer(int(options.MTU))
nativeTun.iovecsOutput[i] = newIovecBuffer(int(options.MTU))
}
if len(options.Inet4Address) > 0 {
nativeTun.inet4Address = options.Inet4Address[0].Addr().As4()
@@ -73,21 +142,20 @@ func New(options Options) (Tun, error) {
if len(options.Inet6Address) > 0 {
nativeTun.inet6Address = options.Inet6Address[0].Addr().As16()
}
var ok bool
nativeTun.tunWriter, ok = bufio.CreateVectorisedWriter(nativeTun.tunFile)
if !ok {
panic("create vectorised writer")
}
return nativeTun, nil
}
func (t *NativeTun) Start() error {
t.options.InterfaceMonitor.RegisterMyInterface(t.options.Name)
return t.setRoutes()
}
func (t *NativeTun) Close() error {
defer flushDNSCache()
return E.Errors(t.unsetRoutes(), t.tunFile.Close())
t.stopFd.Stop()
err := E.Errors(t.unsetRoutes(), t.tunFile.Close())
t.stopFd.Close()
return err
}
func (t *NativeTun) Read(p []byte) (n int, err error) {
@@ -99,19 +167,15 @@ func (t *NativeTun) Write(p []byte) (n int, err error) {
}
var (
packetHeader4 = [4]byte{0x00, 0x00, 0x00, unix.AF_INET}
packetHeader6 = [4]byte{0x00, 0x00, 0x00, unix.AF_INET6}
packetHeader4 = []byte{0x00, 0x00, 0x00, unix.AF_INET}
packetHeader6 = []byte{0x00, 0x00, 0x00, unix.AF_INET6}
packetHeaderVec4 = unix.Iovec{Base: &packetHeader4[0]}
packetHeaderVec6 = unix.Iovec{Base: &packetHeader6[0]}
)
func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
var packetHeader []byte
switch header.IPVersion(buffers[0].Bytes()) {
case header.IPv4Version:
packetHeader = packetHeader4[:]
case header.IPv6Version:
packetHeader = packetHeader6[:]
}
return t.tunWriter.WriteVectorised(append([]*buf.Buffer{buf.As(packetHeader)}, buffers...))
func init() {
packetHeaderVec4.SetLen(4)
packetHeaderVec6.SetLen(4)
}
const utunControlName = "com.apple.net.utun_control"
@@ -146,7 +210,7 @@ type addrLifetime6 struct {
Pltime uint32
}
func configure(tunFd int, ifIndex int, name string, options Options) error {
func create(tunFd int, ifIndex int, name string, options Options) error {
ctlInfo := &unix.CtlInfo{}
copy(ctlInfo.Name[:], utunControlName)
err := unix.IoctlCtlInfo(tunFd, ctlInfo)
@@ -162,11 +226,6 @@ func configure(tunFd int, ifIndex int, name string, options Options) error {
return os.NewSyscallError("Connect", err)
}
err = unix.SetNonblock(tunFd, true)
if err != nil {
return os.NewSyscallError("SetNonblock", err)
}
err = useSocket(unix.AF_INET, unix.SOCK_DGRAM, 0, func(socketFd int) error {
var ifr unix.IfreqMTU
copy(ifr.Name[:], name)
@@ -258,6 +317,90 @@ func configure(tunFd int, ifIndex int, name string, options Options) error {
return nil
}
func configure(tunFd int, multiPendingPackets bool, batchSize int) error {
err := unix.SetNonblock(tunFd, true)
if err != nil {
return os.NewSyscallError("SetNonblock", err)
}
if multiPendingPackets {
const UTUN_OPT_MAX_PENDING_PACKETS = 16
err = unix.SetsockoptInt(tunFd, 2, UTUN_OPT_MAX_PENDING_PACKETS, batchSize)
if err != nil {
return os.NewSyscallError("SetsockoptInt UTUN_OPT_MAX_PENDING_PACKETS", err)
}
}
return nil
}
func (t *NativeTun) BatchRead() ([]*buf.Buffer, error) {
for i := 0; i < t.batchSize; i++ {
iovecs := t.iovecs[i].nextIovecs()
// Cannot clear only the length field. Older versions of the darwin kernel will check whether other data is empty.
// https://github.com/Darm64/XNU/blob/xnu-2782.40.9/bsd/kern/uipc_syscalls.c#L2026-L2048
t.msgHdrs[i] = rawfile.MsgHdrX{}
t.msgHdrs[i].Msg.Iov = &iovecs[0]
t.msgHdrs[i].Msg.Iovlen = 2
}
n, errno := rawfile.BlockingRecvMMsgUntilStopped(t.stopFd.ReadFD, t.tunFd, t.msgHdrs)
if errno != 0 {
for k := 0; k < n; k++ {
t.iovecs[k].buffer.Release()
t.iovecs[k].buffer = nil
}
t.buffers = t.buffers[:0]
return nil, errno
}
if n < 0 {
return nil, os.ErrClosed
}
if n < 1 {
return nil, nil
}
buffers := t.buffers
for k := 0; k < n; k++ {
buffer := t.iovecs[k].buffer
t.iovecs[k].buffer = nil
buffer.Truncate(int(t.msgHdrs[k].DataLen) - PacketOffset)
buffers = append(buffers, buffer)
}
t.buffers = buffers[:0]
return buffers, nil
}
func (t *NativeTun) BatchWrite(buffers []*buf.Buffer) error {
if !t.sendMsgX {
for i, buffer := range buffers {
t.iovecsOutput[i].nextIovecsOutput(buffer)
}
for i := range buffers {
errno := rawfile.NonBlockingWriteIovec(t.tunFd, t.iovecsOutput[i].iovecs)
if errno != 0 {
return errno
}
}
} else {
for i, buffer := range buffers {
iovecs := t.iovecsOutput[i].nextIovecsOutput(buffer)
t.msgHdrsOutput[i] = rawfile.MsgHdrX{}
t.msgHdrsOutput[i].Msg.Iov = &iovecs[0]
t.msgHdrsOutput[i].Msg.Iovlen = 2
}
var n int
for n != len(buffers) {
sent, errno := rawfile.NonBlockingSendMMsg(t.tunFd, t.msgHdrsOutput[n:len(buffers)])
if errno != 0 {
return errno
}
n += sent
}
}
return nil
}
func (t *NativeTun) TXChecksumOffload() bool {
return false
}
func (t *NativeTun) UpdateRouteOptions(tunOptions Options) error {
err := t.unsetRoutes()
if err != nil {

View File

@@ -3,132 +3,57 @@
package tun
import (
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/link/qdisc/fifo"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing-tun/internal/fdbased_darwin"
"github.com/sagernet/sing-tun/internal/rawfile_darwin"
"golang.org/x/sys/unix"
)
var _ GVisorTun = (*NativeTun)(nil)
func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, error) {
return &DarwinEndpoint{tun: t}, nil
}
var _ stack.LinkEndpoint = (*DarwinEndpoint)(nil)
type DarwinEndpoint struct {
tun *NativeTun
dispatcher stack.NetworkDispatcher
}
func (e *DarwinEndpoint) MTU() uint32 {
return e.tun.options.MTU
}
func (e *DarwinEndpoint) SetMTU(mtu uint32) {
}
func (e *DarwinEndpoint) MaxHeaderLength() uint16 {
return 0
}
func (e *DarwinEndpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
func (e *DarwinEndpoint) SetLinkAddress(addr tcpip.LinkAddress) {
}
func (e *DarwinEndpoint) Capabilities() stack.LinkEndpointCapabilities {
return stack.CapabilityRXChecksumOffload
}
func (e *DarwinEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
if dispatcher == nil && e.dispatcher != nil {
e.dispatcher = nil
return
func (t *NativeTun) WritePacket(pkt *stack.PacketBuffer) (int, error) {
iovecs := t.iovecsOutputDefault
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
iovecs = append(iovecs, packetHeaderVec4)
} else {
iovecs = append(iovecs, packetHeaderVec6)
}
if dispatcher != nil && e.dispatcher == nil {
e.dispatcher = dispatcher
go e.dispatchLoop()
var dataLen int
for _, packetSlice := range pkt.AsSlices() {
dataLen += len(packetSlice)
iovec := unix.Iovec{
Base: &packetSlice[0],
}
iovec.SetLen(len(packetSlice))
iovecs = append(iovecs, iovec)
}
if cap(iovecs) > cap(t.iovecsOutputDefault) {
t.iovecsOutputDefault = iovecs[:0]
}
errno := rawfile.NonBlockingWriteIovec(t.tunFd, iovecs)
if errno == 0 {
return dataLen, nil
} else {
return 0, errno
}
}
func (e *DarwinEndpoint) dispatchLoop() {
packetBuffer := make([]byte, e.tun.options.MTU+PacketOffset)
for {
n, err := e.tun.tunFile.Read(packetBuffer)
if err != nil {
break
}
packet := packetBuffer[PacketOffset:n]
var networkProtocol tcpip.NetworkProtocolNumber
switch header.IPVersion(packet) {
case header.IPv4Version:
networkProtocol = header.IPv4ProtocolNumber
if header.IPv4(packet).DestinationAddress().As4() == e.tun.inet4Address {
e.tun.tunFile.Write(packetBuffer[:n])
continue
}
case header.IPv6Version:
networkProtocol = header.IPv6ProtocolNumber
if header.IPv6(packet).DestinationAddress().As16() == e.tun.inet6Address {
e.tun.tunFile.Write(packetBuffer[:n])
continue
}
default:
e.tun.tunFile.Write(packetBuffer[:n])
continue
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packetBuffer[4:n]),
IsForwardedPacket: true,
})
pkt.NetworkProtocolNumber = networkProtocol
dispatcher := e.dispatcher
if dispatcher == nil {
pkt.DecRef()
return
}
dispatcher.DeliverNetworkPacket(networkProtocol, pkt)
pkt.DecRef()
func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, stack.NICOptions, error) {
ep, err := fdbased.New(&fdbased.Options{
FDs: []int{t.tunFd},
MTU: t.options.MTU,
RXChecksumOffload: true,
PacketDispatchMode: fdbased.RecvMMsg,
MultiPendingPackets: t.options.EXP_MultiPendingPackets,
SendMsgX: t.options.EXP_SendMsgX,
})
if err != nil {
return nil, stack.NICOptions{}, err
}
}
func (e *DarwinEndpoint) IsAttached() bool {
return e.dispatcher != nil
}
func (e *DarwinEndpoint) Wait() {
}
func (e *DarwinEndpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareNone
}
func (e *DarwinEndpoint) AddHeader(buffer *stack.PacketBuffer) {
}
func (e *DarwinEndpoint) ParseHeader(ptr *stack.PacketBuffer) bool {
return true
}
func (e *DarwinEndpoint) WritePackets(packetBufferList stack.PacketBufferList) (int, tcpip.Error) {
var n int
for _, packet := range packetBufferList.AsSlice() {
_, err := bufio.WriteVectorised(e.tun, packet.AsSlices())
if err != nil {
return n, &tcpip.ErrAborted{}
}
n++
}
return n, nil
}
func (e *DarwinEndpoint) Close() {
}
func (e *DarwinEndpoint) SetOnCloseAction(f func()) {
return ep, stack.NICOptions{
QDisc: fifo.New(ep, 1, 1000),
}, nil
}

View File

@@ -18,10 +18,8 @@ import (
"github.com/sagernet/sing-tun/internal/gtcpip/header"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/control"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/shell"
"github.com/sagernet/sing/common/x/list"
@@ -32,22 +30,22 @@ import (
var _ LinuxTUN = (*NativeTun)(nil)
type NativeTun struct {
tunFd int
tunFile *os.File
tunWriter N.VectorisedWriter
interfaceCallback *list.Element[DefaultInterfaceUpdateCallback]
options Options
ruleIndex6 []int
readAccess sync.Mutex
writeAccess sync.Mutex
vnetHdr bool
writeBuffer []byte
gsoToWrite []int
tcpGROTable *tcpGROTable
udpGroAccess sync.Mutex
udpGROTable *udpGROTable
gro groDisablementFlags
txChecksumOffload bool
tunFd int
tunFile *os.File
iovecsOutputDefault []unix.Iovec
interfaceCallback *list.Element[DefaultInterfaceUpdateCallback]
options Options
ruleIndex6 []int
readAccess sync.Mutex
writeAccess sync.Mutex
vnetHdr bool
writeBuffer []byte
gsoToWrite []int
tcpGROTable *tcpGROTable
udpGroAccess sync.Mutex
udpGROTable *udpGROTable
gro groDisablementFlags
txChecksumOffload bool
}
func New(options Options) (Tun, error) {
@@ -77,11 +75,6 @@ func New(options Options) (Tun, error) {
options: options,
}
}
var ok bool
nativeTun.tunWriter, ok = bufio.CreateVectorisedWriter(nativeTun.tunFile)
if !ok {
panic("create vectorised writer")
}
return nativeTun, nil
}
@@ -137,7 +130,7 @@ func (t *NativeTun) configure(tunLink netlink.Link) error {
for _, address := range t.options.Inet4Address {
addr4, _ := netlink.ParseAddr(address.String())
err = netlink.AddrAdd(tunLink, addr4)
if err != nil {
if err != nil && !errors.Is(err, unix.EEXIST) {
return err
}
}
@@ -146,7 +139,7 @@ func (t *NativeTun) configure(tunLink netlink.Link) error {
for _, address := range t.options.Inet6Address {
addr6, _ := netlink.ParseAddr(address.String())
err = netlink.AddrAdd(tunLink, addr6)
if err != nil {
if err != nil && !errors.Is(err, unix.EEXIST) {
return err
}
}
@@ -264,7 +257,7 @@ func (t *NativeTun) Start() error {
if t.options.FileDescriptor != 0 {
return nil
}
t.options.InterfaceMonitor.RegisterMyInterface(t.options.Name)
tunLink, err := netlink.LinkByName(t.options.Name)
if err != nil {
return err
@@ -322,6 +315,7 @@ func (t *NativeTun) Close() error {
if t.interfaceCallback != nil {
t.options.InterfaceMonitor.UnregisterCallback(t.interfaceCallback)
}
t.unsetAddresses()
return E.Errors(t.unsetRoute(), t.unsetRules(), common.Close(common.PtrOrNil(t.tunFile)))
}
@@ -402,20 +396,6 @@ func (t *NativeTun) Write(p []byte) (n int, err error) {
return t.tunFile.Write(p)
}
func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
if t.vnetHdr {
n := buf.LenMulti(buffers)
buffer := buf.NewSize(virtioNetHdrLen + n)
buffer.Truncate(virtioNetHdrLen)
buf.CopyMulti(buffer.Extend(n), buffers)
_, err := t.tunFile.Write(buffer.Bytes())
buffer.Release()
return err
} else {
return t.tunWriter.WriteVectorised(buffers)
}
}
func (t *NativeTun) FrontHeadroom() int {
if t.vnetHdr {
return virtioNetHdrLen
@@ -837,14 +817,6 @@ func (t *NativeTun) rules() []*netlink.Rule {
it.Family = unix.AF_INET
rules = append(rules, it)
}
if p4 && !t.options.StrictRoute {
it = netlink.NewRule()
it.Priority = priority
it.IPProto = syscall.IPPROTO_ICMP
it.Goto = nopPriority
it.Family = unix.AF_INET
rules = append(rules, it)
}
if p6 {
it = netlink.NewRule()
it.Priority = priority6
@@ -855,16 +827,6 @@ func (t *NativeTun) rules() []*netlink.Rule {
it.Family = unix.AF_INET6
rules = append(rules, it)
}
if p6 && !t.options.StrictRoute {
it = netlink.NewRule()
it.Priority = priority6
it.IPProto = syscall.IPPROTO_ICMPV6
it.Goto = nopPriority
it.Family = unix.AF_INET6
rules = append(rules, it)
priority6++
}
}
if p4 {
it = netlink.NewRule()
@@ -1042,6 +1004,24 @@ func (t *NativeTun) unsetRules() error {
return nil
}
func (t *NativeTun) unsetAddresses() {
if t.options.FileDescriptor > 0 {
return
}
tunLink, err := netlink.LinkByName(t.options.Name)
if err != nil {
return
}
for _, address := range t.options.Inet4Address {
addr, _ := netlink.ParseAddr(address.String())
_ = netlink.AddrDel(tunLink, addr)
}
for _, address := range t.options.Inet6Address {
addr, _ := netlink.ParseAddr(address.String())
_ = netlink.AddrDel(tunLink, addr)
}
}
func (t *NativeTun) resetRules() error {
t.unsetRules()
return t.setRules()

View File

@@ -3,15 +3,44 @@
package tun
import (
"github.com/sagernet/gvisor/pkg/rawfile"
"github.com/sagernet/gvisor/pkg/tcpip/link/fdbased"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"golang.org/x/sys/unix"
)
func init() {
fdbased.BufConfig = []int{65535}
}
var _ GVisorTun = (*NativeTun)(nil)
func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, error) {
func (t *NativeTun) WritePacket(pkt *stack.PacketBuffer) (int, error) {
iovecs := t.iovecsOutputDefault
var dataLen int
for _, packetSlice := range pkt.AsSlices() {
dataLen += len(packetSlice)
iovec := unix.Iovec{
Base: &packetSlice[0],
}
iovec.SetLen(len(packetSlice))
iovecs = append(iovecs, iovec)
}
if cap(iovecs) > cap(t.iovecsOutputDefault) {
t.iovecsOutputDefault = iovecs[:0]
}
errno := rawfile.NonBlockingWriteIovec(t.tunFd, iovecs)
if errno == 0 {
return dataLen, nil
} else {
return 0, errno
}
}
func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, stack.NICOptions, error) {
if t.vnetHdr {
return fdbased.New(&fdbased.Options{
ep, err := fdbased.New(&fdbased.Options{
FDs: []int{t.tunFd},
MTU: t.options.MTU,
GSOMaxSize: gsoMaxSize,
@@ -19,11 +48,20 @@ func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, error) {
RXChecksumOffload: true,
TXChecksumOffload: t.txChecksumOffload,
})
if err != nil {
return nil, stack.NICOptions{}, err
}
return ep, stack.NICOptions{}, nil
} else {
ep, err := fdbased.New(&fdbased.Options{
FDs: []int{t.tunFd},
MTU: t.options.MTU,
RXChecksumOffload: true,
TXChecksumOffload: t.txChecksumOffload,
})
if err != nil {
return nil, stack.NICOptions{}, err
}
return ep, stack.NICOptions{}, nil
}
return fdbased.New(&fdbased.Options{
FDs: []int{t.tunFd},
MTU: t.options.MTU,
RXChecksumOffload: true,
TXChecksumOffload: t.txChecksumOffload,
})
}

View File

@@ -9,6 +9,7 @@ import (
"net/netip"
"os"
"sync"
"sync/atomic"
"time"
"unsafe"
@@ -16,8 +17,6 @@ import (
"github.com/sagernet/sing-tun/internal/winsys"
"github.com/sagernet/sing-tun/internal/wintun"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/windnsapi"
@@ -163,6 +162,7 @@ func (t *NativeTun) Name() (string, error) {
}
func (t *NativeTun) Start() error {
t.options.InterfaceMonitor.RegisterMyInterface(t.options.Name)
if !t.options.AutoRoute {
return nil
}
@@ -181,6 +181,13 @@ func (t *NativeTun) Start() error {
return err
}
if t.options.StrictRoute {
major, _, _ := windows.RtlGetNtVersionNumbers()
if major < 10 {
if t.options.Logger != nil {
t.options.Logger.Warn("strict routing is not supported on Windows versions below 10")
}
return nil
}
var engine uintptr
session := &winsys.FWPM_SESSION0{Flags: winsys.FWPM_SESSION_FLAG_DYNAMIC}
err := winsys.FwpmEngineOpen0(nil, winsys.RPC_C_AUTHN_DEFAULT, nil, session, unsafe.Pointer(&engine))
@@ -395,15 +402,16 @@ retry:
func (t *NativeTun) ReadPacket() ([]byte, func(), error) {
t.running.Add(1)
defer t.running.Done()
retry:
if t.close.Load() == 1 {
t.running.Done()
return nil, nil, os.ErrClosed
}
start := nanotime()
shouldSpin := t.rate.current.Load() >= spinloopRateThreshold && uint64(start-t.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
for {
if t.close.Load() == 1 {
t.running.Done()
return nil, nil, os.ErrClosed
}
packet, err := t.session.ReceivePacket()
@@ -411,7 +419,10 @@ retry:
case nil:
packetSize := len(packet)
t.rate.update(uint64(packetSize))
return packet, func() { t.session.ReleaseReceivePacket(packet) }, nil
return packet, func() {
t.session.ReleaseReceivePacket(packet)
t.running.Done()
}, nil
case windows.ERROR_NO_MORE_ITEMS:
if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
windows.WaitForSingleObject(t.readWait, windows.INFINITE)
@@ -420,10 +431,13 @@ retry:
procyield(1)
continue
case windows.ERROR_HANDLE_EOF:
t.running.Done()
return nil, nil, os.ErrClosed
case windows.ERROR_INVALID_DATA:
t.running.Done()
return nil, nil, errors.New("send ring corrupt")
}
t.running.Done()
return nil, nil, fmt.Errorf("read failed: %w", err)
}
}
@@ -516,11 +530,6 @@ func (t *NativeTun) write(packetElementList [][]byte) (n int, err error) {
return 0, fmt.Errorf("write failed: %w", err)
}
func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
defer buf.ReleaseMulti(buffers)
return common.Error(t.write(buf.ToSliceMulti(buffers)))
}
func (t *NativeTun) Close() error {
var err error
t.closeOnce.Do(func() {

View File

@@ -11,8 +11,12 @@ import (
var _ GVisorTun = (*NativeTun)(nil)
func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, error) {
return &WintunEndpoint{tun: t}, nil
func (t *NativeTun) WritePacket(pkt *stack.PacketBuffer) (int, error) {
return t.write(pkt.AsSlices())
}
func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, stack.NICOptions, error) {
return &WintunEndpoint{tun: t}, stack.NICOptions{}, nil
}
var _ stack.LinkEndpoint = (*WintunEndpoint)(nil)