Compare commits
46 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
42a84746a9 | ||
|
|
7f3af59109 | ||
|
|
8ccd51404e | ||
|
|
7bd004f141 | ||
|
|
d44e0c68d4 | ||
|
|
bc23daa800 | ||
|
|
e6c219a61e | ||
|
|
ddc824fb9c | ||
|
|
e229d7041e | ||
|
|
e2503223dc | ||
|
|
92529635cb | ||
|
|
37e2523a36 | ||
|
|
2646115abb | ||
|
|
a68ba22714 | ||
|
|
a385766b3f | ||
|
|
4a56d47035 | ||
|
|
07e21b9170 | ||
|
|
ebbe32588c | ||
|
|
0310956cc0 | ||
|
|
3af7305b85 | ||
|
|
aa1fd4d994 | ||
|
|
7812930a48 | ||
|
|
4c81c8a62a | ||
|
|
a0881ada32 | ||
|
|
8763c24e49 | ||
|
|
5e343c4b66 | ||
|
|
f57754918d | ||
|
|
2121bc3f01 | ||
|
|
bea26198e7 | ||
|
|
3df19f464e | ||
|
|
494b0ef858 | ||
|
|
f13cd94aa0 | ||
|
|
51ac6b34f1 | ||
|
|
31e29f93cc | ||
|
|
c410f7050c | ||
|
|
d89ab3f207 | ||
|
|
219c612399 | ||
|
|
a8ce3838bc | ||
|
|
35b5747b44 | ||
|
|
5cb6d27288 | ||
|
|
9105485a50 | ||
|
|
57aba1a5c4 | ||
|
|
7f3343169a | ||
|
|
22b811f938 | ||
|
|
618be14c7b | ||
|
|
c8c2984261 |
1
Makefile
1
Makefile
@@ -29,4 +29,5 @@ lint_install:
|
||||
|
||||
test:
|
||||
go build -v .
|
||||
go test -bench=. ./internal/checksum_test
|
||||
#go test -v .
|
||||
|
||||
2
go.mod
2
go.mod
@@ -9,7 +9,7 @@ require (
|
||||
github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff
|
||||
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a
|
||||
github.com/sagernet/nftables v0.3.0-beta.4
|
||||
github.com/sagernet/sing v0.6.0-beta.2
|
||||
github.com/sagernet/sing v0.7.6
|
||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
|
||||
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8
|
||||
golang.org/x/net v0.26.0
|
||||
|
||||
4
go.sum
4
go.sum
@@ -22,8 +22,8 @@ github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZN
|
||||
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM=
|
||||
github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I=
|
||||
github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8=
|
||||
github.com/sagernet/sing v0.6.0-beta.2 h1:Dcutp3kxrsZes9q3oTiHQhYYjQvDn5rwp1OI9fDLYwQ=
|
||||
github.com/sagernet/sing v0.6.0-beta.2/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
|
||||
github.com/sagernet/sing v0.7.6 h1:6LBfDH+aI/26J3r9UHlaxTNjJeMhBpU/wrk0JKDZYI4=
|
||||
github.com/sagernet/sing v0.7.6/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
||||
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
|
||||
@@ -28,6 +28,6 @@ func BenchmarkGChecksum(b *testing.B) {
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
checksum.Checksum(packet[i%1000], 0)
|
||||
checksum.ChecksumDefault(packet[i%1000], 0)
|
||||
}
|
||||
}
|
||||
|
||||
668
internal/fdbased_darwin/endpoint.go
Normal file
668
internal/fdbased_darwin/endpoint.go
Normal file
@@ -0,0 +1,668 @@
|
||||
// Copyright 2018 The gVisor Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package fdbased provides the implementation of data-link layer endpoints
|
||||
// backed by boundary-preserving file descriptors (e.g., TUN devices,
|
||||
// seqpacket/datagram sockets).
|
||||
//
|
||||
// FD based endpoints can be used in the networking stack by calling New() to
|
||||
// create a new endpoint, and then passing it as an argument to
|
||||
// Stack.CreateNIC().
|
||||
//
|
||||
// FD based endpoints can use more than one file descriptor to read incoming
|
||||
// packets. If more than one FD is specified and the underlying FD is an
|
||||
// AF_PACKET then the endpoint will enable FANOUT mode on the socket so that the
|
||||
// host kernel will consistently hash the packets to the sockets. This ensures
|
||||
// that packets for the same TCP streams are not reordered.
|
||||
//
|
||||
// Similarly, if more than one FD is specified where the underlying FD is not
|
||||
// AF_PACKET then it's the caller's responsibility to ensure that all inbound
|
||||
// packets on the descriptors are consistently 5 tuple hashed to one of the
|
||||
// descriptors to prevent TCP reordering.
|
||||
//
|
||||
// Since netstack today does not compute 5 tuple hashes for outgoing packets we
|
||||
// only use the first FD to write outbound packets. Once 5 tuple hashes for
|
||||
// all outbound packets are available we will make use of all underlying FD's to
|
||||
// write outbound packets.
|
||||
package fdbased
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
|
||||
"github.com/sagernet/gvisor/pkg/buffer"
|
||||
"github.com/sagernet/gvisor/pkg/sync"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/header"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||
"github.com/sagernet/sing-tun/internal/rawfile_darwin"
|
||||
"github.com/sagernet/sing/common"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// linkDispatcher reads packets from the link FD and dispatches them to the
|
||||
// NetworkDispatcher.
|
||||
type linkDispatcher interface {
|
||||
Stop()
|
||||
dispatch() (bool, tcpip.Error)
|
||||
release()
|
||||
}
|
||||
|
||||
// PacketDispatchMode enumerates the supported ways of receiving packets from
// the underlying FD and dispatching them.
type PacketDispatchMode int

// BatchSize is the number of packets to write in each syscall. It is 47
// because when GVisorGSO is in use then a single 65KB TCP segment can get
// split into 46 segments of 1420 bytes and a single 216 byte segment.
const BatchSize = 47

const (
	// Readv is the default dispatch mode. It is the least performant option
	// but the only one supported by every underlying FD type.
	Readv PacketDispatchMode = iota
	// RecvMMsg reads inbound packets with recvmmsg() instead of readv(),
	// which lowers the number of syscalls needed to process packets.
	//
	// NOTE: recvmmsg() is only supported for sockets, so if the underlying
	// FD is not a socket then the code will still fall back to the readv()
	// path.
	RecvMMsg
	// PacketMMap receives packets from the NIC through PACKET_RX_RING and
	// requires the underlying FD to be an AF_PACKET. The primary use-case
	// for this is runsc which uses an AF_PACKET FD to receive packets from
	// the veth device.
	PacketMMap
)

// String returns the name of the dispatch mode, or a descriptive placeholder
// carrying the numeric value when p is not one of the declared modes.
func (p PacketDispatchMode) String() string {
	names := [...]string{
		Readv:      "Readv",
		RecvMMsg:   "RecvMMsg",
		PacketMMap: "PacketMMap",
	}
	if p < 0 || int(p) >= len(names) {
		return fmt.Sprintf("unknown packet dispatch mode '%d'", p)
	}
	return names[p]
}
|
||||
|
||||
var (
|
||||
_ stack.LinkEndpoint = (*endpoint)(nil)
|
||||
_ stack.GSOEndpoint = (*endpoint)(nil)
|
||||
)
|
||||
|
||||
// fdInfo describes one file descriptor owned by an endpoint. Values are
// compared with == / != when outbound packets are grouped into batches, so the
// type must stay comparable.
//
// +stateify savable
type fdInfo struct {
	// fd is the raw file descriptor number.
	fd int
	// isSocket reports whether fd is treated as a socket; sendBatch only uses
	// its batched send path when this is true.
	isSocket bool
}
|
||||
|
||||
// +stateify savable
|
||||
type endpoint struct {
|
||||
// fds is the set of file descriptors each identifying one inbound/outbound
|
||||
// channel. The endpoint will dispatch from all inbound channels as well as
|
||||
// hash outbound packets to specific channels based on the packet hash.
|
||||
fds []fdInfo
|
||||
|
||||
// hdrSize specifies the link-layer header size. If set to 0, no header
|
||||
// is added/removed; otherwise an ethernet header is used.
|
||||
hdrSize int
|
||||
|
||||
// caps holds the endpoint capabilities.
|
||||
caps stack.LinkEndpointCapabilities
|
||||
|
||||
// closed is a function to be called when the FD's peer (if any) closes
|
||||
// its end of the communication pipe.
|
||||
closed func(tcpip.Error) `state:"nosave"`
|
||||
|
||||
inboundDispatchers []linkDispatcher
|
||||
|
||||
mu endpointRWMutex `state:"nosave"`
|
||||
// +checklocks:mu
|
||||
dispatcher stack.NetworkDispatcher
|
||||
|
||||
// packetDispatchMode controls the packet dispatcher used by this
|
||||
// endpoint.
|
||||
packetDispatchMode PacketDispatchMode
|
||||
|
||||
// wg keeps track of running goroutines.
|
||||
wg sync.WaitGroup `state:"nosave"`
|
||||
|
||||
// maxSyscallHeaderBytes has the same meaning as
|
||||
// Options.MaxSyscallHeaderBytes.
|
||||
maxSyscallHeaderBytes uintptr
|
||||
|
||||
// writevMaxIovs is the maximum number of iovecs that may be passed to
|
||||
// rawfile.NonBlockingWriteIovec, as possibly limited by
|
||||
// maxSyscallHeaderBytes. (No analogous limit is defined for
|
||||
// rawfile.NonBlockingSendMMsg, since in that case the maximum number of
|
||||
// iovecs also depends on the number of mmsghdrs. Instead, if sendBatch
|
||||
// encounters a packet whose iovec count is limited by
|
||||
// maxSyscallHeaderBytes, it falls back to writing the packet using writev
|
||||
// via WritePacket.)
|
||||
writevMaxIovs int
|
||||
|
||||
// addr is the address of the endpoint.
|
||||
//
|
||||
// +checklocks:mu
|
||||
addr tcpip.LinkAddress
|
||||
|
||||
// mtu (maximum transmission unit) is the maximum size of a packet.
|
||||
// +checklocks:mu
|
||||
mtu uint32
|
||||
|
||||
batchSize int
|
||||
sendMsgX bool
|
||||
}
|
||||
|
||||
// Options specify the details about the fd-based endpoint to be created.
|
||||
//
|
||||
// +stateify savable
|
||||
type Options struct {
|
||||
// FDs is a set of FDs used to read/write packets.
|
||||
FDs []int
|
||||
|
||||
// MTU is the mtu to use for this endpoint.
|
||||
MTU uint32
|
||||
|
||||
// EthernetHeader if true, indicates that the endpoint should read/write
|
||||
// ethernet frames instead of IP packets.
|
||||
EthernetHeader bool
|
||||
|
||||
// ClosedFunc is a function to be called when an endpoint's peer (if
|
||||
// any) closes its end of the communication pipe.
|
||||
ClosedFunc func(tcpip.Error)
|
||||
|
||||
// Address is the link address for this endpoint. Only used if
|
||||
// EthernetHeader is true.
|
||||
Address tcpip.LinkAddress
|
||||
|
||||
// SaveRestore if true, indicates that this NIC capability set should
|
||||
// include CapabilitySaveRestore
|
||||
SaveRestore bool
|
||||
|
||||
// DisconnectOk if true, indicates that this NIC capability set should
|
||||
// include CapabilityDisconnectOk.
|
||||
DisconnectOk bool
|
||||
|
||||
// PacketDispatchMode specifies the type of inbound dispatcher to be
|
||||
// used for this endpoint.
|
||||
PacketDispatchMode PacketDispatchMode
|
||||
|
||||
// TXChecksumOffload if true, indicates that this endpoints capability
|
||||
// set should include CapabilityTXChecksumOffload.
|
||||
TXChecksumOffload bool
|
||||
|
||||
// RXChecksumOffload if true, indicates that this endpoints capability
|
||||
// set should include CapabilityRXChecksumOffload.
|
||||
RXChecksumOffload bool
|
||||
|
||||
// If MaxSyscallHeaderBytes is non-zero, it is the maximum number of bytes
|
||||
// of struct iovec, msghdr, and mmsghdr that may be passed by each host
|
||||
// system call.
|
||||
MaxSyscallHeaderBytes int
|
||||
|
||||
// InterfaceIndex is the interface index of the underlying device.
|
||||
InterfaceIndex int
|
||||
|
||||
// ProcessorsPerChannel is the number of goroutines used to handle packets
|
||||
// from each FD.
|
||||
ProcessorsPerChannel int
|
||||
|
||||
MultiPendingPackets bool
|
||||
SendMsgX bool
|
||||
}
|
||||
|
||||
// New creates a new fd-based endpoint.
|
||||
//
|
||||
// Makes fd non-blocking, but does not take ownership of fd, which must remain
|
||||
// open for the lifetime of the returned endpoint (until after the endpoint has
|
||||
// stopped being using and Wait returns).
|
||||
func New(opts *Options) (stack.LinkEndpoint, error) {
|
||||
caps := stack.LinkEndpointCapabilities(0)
|
||||
if opts.RXChecksumOffload {
|
||||
caps |= stack.CapabilityRXChecksumOffload
|
||||
}
|
||||
|
||||
if opts.TXChecksumOffload {
|
||||
caps |= stack.CapabilityTXChecksumOffload
|
||||
}
|
||||
|
||||
hdrSize := 0
|
||||
if opts.EthernetHeader {
|
||||
hdrSize = header.EthernetMinimumSize
|
||||
caps |= stack.CapabilityResolutionRequired
|
||||
}
|
||||
|
||||
if opts.SaveRestore {
|
||||
caps |= stack.CapabilitySaveRestore
|
||||
}
|
||||
|
||||
if opts.DisconnectOk {
|
||||
caps |= stack.CapabilityDisconnectOk
|
||||
}
|
||||
|
||||
if len(opts.FDs) == 0 {
|
||||
return nil, fmt.Errorf("opts.FD is empty, at least one FD must be specified")
|
||||
}
|
||||
|
||||
if opts.MaxSyscallHeaderBytes < 0 {
|
||||
return nil, fmt.Errorf("opts.MaxSyscallHeaderBytes is negative")
|
||||
}
|
||||
var batchSize int
|
||||
if opts.MultiPendingPackets {
|
||||
batchSize = int((512*1024)/(opts.MTU)) + 1
|
||||
} else {
|
||||
batchSize = 1
|
||||
}
|
||||
|
||||
e := &endpoint{
|
||||
mtu: opts.MTU,
|
||||
caps: caps,
|
||||
closed: opts.ClosedFunc,
|
||||
addr: opts.Address,
|
||||
hdrSize: hdrSize,
|
||||
packetDispatchMode: opts.PacketDispatchMode,
|
||||
maxSyscallHeaderBytes: uintptr(opts.MaxSyscallHeaderBytes),
|
||||
writevMaxIovs: rawfile.MaxIovs,
|
||||
batchSize: batchSize,
|
||||
sendMsgX: opts.SendMsgX,
|
||||
}
|
||||
if e.maxSyscallHeaderBytes != 0 {
|
||||
if max := int(e.maxSyscallHeaderBytes / rawfile.SizeofIovec); max < e.writevMaxIovs {
|
||||
e.writevMaxIovs = max
|
||||
}
|
||||
}
|
||||
|
||||
// Create per channel dispatchers.
|
||||
for _, fd := range opts.FDs {
|
||||
if err := unix.SetNonblock(fd, true); err != nil {
|
||||
return nil, fmt.Errorf("unix.SetNonblock(%v) failed: %v", fd, err)
|
||||
}
|
||||
|
||||
e.fds = append(e.fds, fdInfo{fd: fd, isSocket: true})
|
||||
if opts.ProcessorsPerChannel == 0 {
|
||||
opts.ProcessorsPerChannel = common.Max(1, runtime.GOMAXPROCS(0)/len(opts.FDs))
|
||||
}
|
||||
|
||||
inboundDispatcher, err := newRecvMMsgDispatcher(fd, e, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("createInboundDispatcher(...) = %v", err)
|
||||
}
|
||||
e.inboundDispatchers = append(e.inboundDispatchers, inboundDispatcher)
|
||||
}
|
||||
|
||||
return e, nil
|
||||
}
|
||||
|
||||
// Attach launches the goroutine that reads packets from the file descriptor and
|
||||
// dispatches them via the provided dispatcher. If one is already attached,
|
||||
// then nothing happens.
|
||||
//
|
||||
// Attach implements stack.LinkEndpoint.Attach.
|
||||
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
|
||||
e.mu.Lock()
|
||||
|
||||
// nil means the NIC is being removed.
|
||||
if dispatcher == nil && e.dispatcher != nil {
|
||||
for _, dispatcher := range e.inboundDispatchers {
|
||||
dispatcher.Stop()
|
||||
}
|
||||
e.dispatcher = nil
|
||||
// NOTE(gvisor.dev/issue/11456): Unlock e.mu before e.Wait().
|
||||
e.mu.Unlock()
|
||||
e.Wait()
|
||||
return
|
||||
}
|
||||
defer e.mu.Unlock()
|
||||
if dispatcher != nil && e.dispatcher == nil {
|
||||
e.dispatcher = dispatcher
|
||||
// Link endpoints are not savable. When transportation endpoints are
|
||||
// saved, they stop sending outgoing packets and all incoming packets
|
||||
// are rejected.
|
||||
for i := range e.inboundDispatchers {
|
||||
e.wg.Add(1)
|
||||
go func(i int) { // S/R-SAFE: See above.
|
||||
e.dispatchLoop(e.inboundDispatchers[i])
|
||||
e.wg.Done()
|
||||
}(i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsAttached implements stack.LinkEndpoint.IsAttached.
|
||||
func (e *endpoint) IsAttached() bool {
|
||||
e.mu.RLock()
|
||||
defer e.mu.RUnlock()
|
||||
return e.dispatcher != nil
|
||||
}
|
||||
|
||||
// MTU implements stack.LinkEndpoint.MTU.
|
||||
func (e *endpoint) MTU() uint32 {
|
||||
e.mu.RLock()
|
||||
defer e.mu.RUnlock()
|
||||
return e.mtu
|
||||
}
|
||||
|
||||
// SetMTU implements stack.LinkEndpoint.SetMTU.
|
||||
func (e *endpoint) SetMTU(mtu uint32) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
e.mtu = mtu
|
||||
}
|
||||
|
||||
// Capabilities implements stack.LinkEndpoint.Capabilities.
|
||||
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
|
||||
return e.caps
|
||||
}
|
||||
|
||||
// MaxHeaderLength returns the maximum size of the link-layer header.
|
||||
func (e *endpoint) MaxHeaderLength() uint16 {
|
||||
return uint16(e.hdrSize)
|
||||
}
|
||||
|
||||
// LinkAddress returns the link address of this endpoint.
|
||||
func (e *endpoint) LinkAddress() tcpip.LinkAddress {
|
||||
e.mu.RLock()
|
||||
defer e.mu.RUnlock()
|
||||
return e.addr
|
||||
}
|
||||
|
||||
// SetLinkAddress implements stack.LinkEndpoint.SetLinkAddress.
|
||||
func (e *endpoint) SetLinkAddress(addr tcpip.LinkAddress) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
e.addr = addr
|
||||
}
|
||||
|
||||
// Wait implements stack.LinkEndpoint.Wait. It waits for the endpoint to stop
|
||||
// reading from its FD.
|
||||
func (e *endpoint) Wait() {
|
||||
e.wg.Wait()
|
||||
}
|
||||
|
||||
// AddHeader implements stack.LinkEndpoint.AddHeader.
|
||||
func (e *endpoint) AddHeader(pkt *stack.PacketBuffer) {
|
||||
if e.hdrSize > 0 {
|
||||
// Add ethernet header if needed.
|
||||
eth := header.Ethernet(pkt.LinkHeader().Push(header.EthernetMinimumSize))
|
||||
eth.Encode(&header.EthernetFields{
|
||||
SrcAddr: pkt.EgressRoute.LocalLinkAddress,
|
||||
DstAddr: pkt.EgressRoute.RemoteLinkAddress,
|
||||
Type: pkt.NetworkProtocolNumber,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (e *endpoint) parseHeader(pkt *stack.PacketBuffer) (header.Ethernet, bool) {
|
||||
if e.hdrSize <= 0 {
|
||||
return nil, true
|
||||
}
|
||||
hdrBytes, ok := pkt.LinkHeader().Consume(e.hdrSize)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
hdr := header.Ethernet(hdrBytes)
|
||||
pkt.NetworkProtocolNumber = hdr.Type()
|
||||
return hdr, true
|
||||
}
|
||||
|
||||
// parseInboundHeader parses the link header of pkt and returns true if the
|
||||
// header is well-formed and sent to this endpoint's MAC or the broadcast
|
||||
// address.
|
||||
func (e *endpoint) parseInboundHeader(pkt *stack.PacketBuffer, wantAddr tcpip.LinkAddress) bool {
|
||||
hdr, ok := e.parseHeader(pkt)
|
||||
if !ok || e.hdrSize <= 0 {
|
||||
return ok
|
||||
}
|
||||
dstAddr := hdr.DestinationAddress()
|
||||
// Per RFC 9542 2.1 on the least significant bit of the first octet of
|
||||
// a MAC address: "If it is zero, the MAC address is unicast. If it is
|
||||
// a one, the address is groupcast (multicast or broadcast)." Multicast
|
||||
// and broadcast are the same thing to ethernet; they are both sent to
|
||||
// everyone.
|
||||
return dstAddr == wantAddr || byte(dstAddr[0])&0x01 == 1
|
||||
}
|
||||
|
||||
// ParseHeader implements stack.LinkEndpoint.ParseHeader.
|
||||
func (e *endpoint) ParseHeader(pkt *stack.PacketBuffer) bool {
|
||||
_, ok := e.parseHeader(pkt)
|
||||
return ok
|
||||
}
|
||||
|
||||
var (
|
||||
packetHeader4 = []byte{0x00, 0x00, 0x00, unix.AF_INET}
|
||||
packetHeader6 = []byte{0x00, 0x00, 0x00, unix.AF_INET6}
|
||||
)
|
||||
|
||||
// writePacket writes outbound packets to the file descriptor. If it is not
|
||||
// currently writable, the packet is dropped.
|
||||
func (e *endpoint) writePacket(pkt *stack.PacketBuffer) tcpip.Error {
|
||||
fdInfo := e.fds[pkt.Hash%uint32(len(e.fds))]
|
||||
fd := fdInfo.fd
|
||||
var vnetHdrBuf []byte
|
||||
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
|
||||
vnetHdrBuf = packetHeader4
|
||||
} else {
|
||||
vnetHdrBuf = packetHeader6
|
||||
}
|
||||
views := pkt.AsSlices()
|
||||
numIovecs := len(views)
|
||||
if len(vnetHdrBuf) != 0 {
|
||||
numIovecs++
|
||||
}
|
||||
if numIovecs > e.writevMaxIovs {
|
||||
numIovecs = e.writevMaxIovs
|
||||
}
|
||||
|
||||
// Allocate small iovec arrays on the stack.
|
||||
var iovecsArr [8]unix.Iovec
|
||||
iovecs := iovecsArr[:0]
|
||||
if numIovecs > len(iovecsArr) {
|
||||
iovecs = make([]unix.Iovec, 0, numIovecs)
|
||||
}
|
||||
iovecs = rawfile.AppendIovecFromBytes(iovecs, vnetHdrBuf, numIovecs)
|
||||
for _, v := range views {
|
||||
iovecs = rawfile.AppendIovecFromBytes(iovecs, v, numIovecs)
|
||||
}
|
||||
if errno := rawfile.NonBlockingWriteIovec(fd, iovecs); errno != 0 {
|
||||
return TranslateErrno(errno)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *endpoint) sendBatch(batchFDInfo fdInfo, pkts []*stack.PacketBuffer) (int, tcpip.Error) {
|
||||
// Degrade to writePacket if underlying fd is not a socket.
|
||||
if !batchFDInfo.isSocket || !e.sendMsgX {
|
||||
var written int
|
||||
var err tcpip.Error
|
||||
for written < len(pkts) {
|
||||
if err = e.writePacket(pkts[written]); err != nil {
|
||||
break
|
||||
}
|
||||
written++
|
||||
}
|
||||
return written, err
|
||||
}
|
||||
|
||||
// Send a batch of packets through batchFD.
|
||||
batchFD := batchFDInfo.fd
|
||||
mmsgHdrsStorage := make([]rawfile.MsgHdrX, 0, len(pkts))
|
||||
packets := 0
|
||||
for packets < len(pkts) {
|
||||
mmsgHdrs := mmsgHdrsStorage
|
||||
batch := pkts[packets:]
|
||||
syscallHeaderBytes := uintptr(0)
|
||||
for _, pkt := range batch {
|
||||
var vnetHdrBuf []byte
|
||||
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
|
||||
vnetHdrBuf = packetHeader4
|
||||
} else {
|
||||
vnetHdrBuf = packetHeader6
|
||||
}
|
||||
views, offset := pkt.AsViewList()
|
||||
var skipped int
|
||||
var view *buffer.View
|
||||
for view = views.Front(); view != nil && offset >= view.Size(); view = view.Next() {
|
||||
offset -= view.Size()
|
||||
skipped++
|
||||
}
|
||||
|
||||
// We've made it to the usable views.
|
||||
numIovecs := views.Len() - skipped
|
||||
if len(vnetHdrBuf) != 0 {
|
||||
numIovecs++
|
||||
}
|
||||
if numIovecs > rawfile.MaxIovs {
|
||||
numIovecs = rawfile.MaxIovs
|
||||
}
|
||||
if e.maxSyscallHeaderBytes != 0 {
|
||||
syscallHeaderBytes += rawfile.SizeofMsgHdrX + uintptr(numIovecs)*rawfile.SizeofIovec
|
||||
if syscallHeaderBytes > e.maxSyscallHeaderBytes {
|
||||
// We can't fit this packet into this call to sendmmsg().
|
||||
// We could potentially do so if we reduced numIovecs
|
||||
// further, but this might incur considerable extra
|
||||
// copying. Leave it to the next batch instead.
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// We can't easily allocate iovec arrays on the stack here since
|
||||
// they will escape this loop iteration via mmsgHdrs.
|
||||
iovecs := make([]unix.Iovec, 0, numIovecs)
|
||||
iovecs = rawfile.AppendIovecFromBytes(iovecs, vnetHdrBuf, numIovecs)
|
||||
// At most one slice has a non-zero offset.
|
||||
iovecs = rawfile.AppendIovecFromBytes(iovecs, view.AsSlice()[offset:], numIovecs)
|
||||
for view = view.Next(); view != nil; view = view.Next() {
|
||||
iovecs = rawfile.AppendIovecFromBytes(iovecs, view.AsSlice(), numIovecs)
|
||||
}
|
||||
|
||||
var mmsgHdr rawfile.MsgHdrX
|
||||
mmsgHdr.Msg.Iov = &iovecs[0]
|
||||
mmsgHdr.Msg.SetIovlen(len(iovecs))
|
||||
// mmsgHdr.DataLen = uint32(len(iovecs))
|
||||
mmsgHdrs = append(mmsgHdrs, mmsgHdr)
|
||||
}
|
||||
|
||||
if len(mmsgHdrs) == 0 {
|
||||
// We can't fit batch[0] into a mmsghdr while staying under
|
||||
// e.maxSyscallHeaderBytes. Use WritePacket, which will avoid the
|
||||
// mmsghdr (by using writev) and re-buffer iovecs more aggressively
|
||||
// if necessary (by using e.writevMaxIovs instead of
|
||||
// rawfile.MaxIovs).
|
||||
pkt := batch[0]
|
||||
if err := e.writePacket(pkt); err != nil {
|
||||
return packets, err
|
||||
}
|
||||
packets++
|
||||
} else {
|
||||
for len(mmsgHdrs) > 0 {
|
||||
sent, errno := rawfile.NonBlockingSendMMsg(batchFD, mmsgHdrs)
|
||||
if errno != 0 {
|
||||
return packets, TranslateErrno(errno)
|
||||
}
|
||||
packets += sent
|
||||
mmsgHdrs = mmsgHdrs[sent:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return packets, nil
|
||||
}
|
||||
|
||||
// WritePackets writes outbound packets to the underlying file descriptors. If
|
||||
// one is not currently writable, the packet is dropped.
|
||||
//
|
||||
// Being a batch API, each packet in pkts should have the following
|
||||
// fields populated:
|
||||
// - pkt.EgressRoute
|
||||
// - pkt.GSOOptions
|
||||
// - pkt.NetworkProtocolNumber
|
||||
func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
|
||||
// Preallocate to avoid repeated reallocation as we append to batch.
|
||||
batch := make([]*stack.PacketBuffer, 0, e.batchSize)
|
||||
batchFDInfo := fdInfo{fd: -1, isSocket: false}
|
||||
sentPackets := 0
|
||||
for _, pkt := range pkts.AsSlice() {
|
||||
if len(batch) == 0 {
|
||||
batchFDInfo = e.fds[pkt.Hash%uint32(len(e.fds))]
|
||||
}
|
||||
pktFDInfo := e.fds[pkt.Hash%uint32(len(e.fds))]
|
||||
if sendNow := pktFDInfo != batchFDInfo; !sendNow {
|
||||
batch = append(batch, pkt)
|
||||
continue
|
||||
}
|
||||
n, err := e.sendBatch(batchFDInfo, batch)
|
||||
sentPackets += n
|
||||
if err != nil {
|
||||
return sentPackets, err
|
||||
}
|
||||
batch = batch[:0]
|
||||
batch = append(batch, pkt)
|
||||
batchFDInfo = pktFDInfo
|
||||
}
|
||||
|
||||
if len(batch) != 0 {
|
||||
n, err := e.sendBatch(batchFDInfo, batch)
|
||||
sentPackets += n
|
||||
if err != nil {
|
||||
return sentPackets, err
|
||||
}
|
||||
}
|
||||
return sentPackets, nil
|
||||
}
|
||||
|
||||
// dispatchLoop reads packets from the file descriptor in a loop and dispatches
|
||||
// them to the network stack.
|
||||
func (e *endpoint) dispatchLoop(inboundDispatcher linkDispatcher) tcpip.Error {
|
||||
for {
|
||||
cont, err := inboundDispatcher.dispatch()
|
||||
if err != nil || !cont {
|
||||
if e.closed != nil {
|
||||
e.closed(err)
|
||||
}
|
||||
inboundDispatcher.release()
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GSOMaxSize implements stack.GSOEndpoint.
|
||||
func (e *endpoint) GSOMaxSize() uint32 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// SupportedGSO implements stack.GSOEndpoint.
|
||||
func (e *endpoint) SupportedGSO() stack.SupportedGSO {
|
||||
return stack.GSONotSupported
|
||||
}
|
||||
|
||||
// ARPHardwareType implements stack.LinkEndpoint.ARPHardwareType.
|
||||
func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
|
||||
if e.hdrSize > 0 {
|
||||
return header.ARPHardwareEther
|
||||
}
|
||||
return header.ARPHardwareNone
|
||||
}
|
||||
|
||||
// Close implements stack.LinkEndpoint.
|
||||
func (e *endpoint) Close() {}
|
||||
|
||||
// SetOnCloseAction implements stack.LinkEndpoint.
|
||||
func (*endpoint) SetOnCloseAction(func()) {}
|
||||
96
internal/fdbased_darwin/endpoint_mutex.go
Normal file
96
internal/fdbased_darwin/endpoint_mutex.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package fdbased
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"github.com/sagernet/gvisor/pkg/sync"
|
||||
"github.com/sagernet/gvisor/pkg/sync/locking"
|
||||
)
|
||||
|
||||
// RWMutex is sync.RWMutex with the correctness validator.
|
||||
type endpointRWMutex struct {
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// lockNames is a list of user-friendly lock names.
|
||||
// Populated in init.
|
||||
var endpointlockNames []string
|
||||
|
||||
// lockNameIndex is used as an index passed to NestedLock and NestedUnlock,
|
||||
// referring to an index within lockNames.
|
||||
// Values are specified using the "consts" field of go_template_instance.
|
||||
type endpointlockNameIndex int
|
||||
|
||||
// DO NOT REMOVE: The following function automatically replaced with lock index constants.
|
||||
// LOCK_NAME_INDEX_CONSTANTS
|
||||
const ()
|
||||
|
||||
// Lock locks m.
|
||||
// +checklocksignore
|
||||
func (m *endpointRWMutex) Lock() {
|
||||
locking.AddGLock(endpointprefixIndex, -1)
|
||||
m.mu.Lock()
|
||||
}
|
||||
|
||||
// NestedLock locks m knowing that another lock of the same type is held.
|
||||
// +checklocksignore
|
||||
func (m *endpointRWMutex) NestedLock(i endpointlockNameIndex) {
|
||||
locking.AddGLock(endpointprefixIndex, int(i))
|
||||
m.mu.Lock()
|
||||
}
|
||||
|
||||
// Unlock unlocks m.
|
||||
// +checklocksignore
|
||||
func (m *endpointRWMutex) Unlock() {
|
||||
m.mu.Unlock()
|
||||
locking.DelGLock(endpointprefixIndex, -1)
|
||||
}
|
||||
|
||||
// NestedUnlock unlocks m knowing that another lock of the same type is held.
|
||||
// +checklocksignore
|
||||
func (m *endpointRWMutex) NestedUnlock(i endpointlockNameIndex) {
|
||||
m.mu.Unlock()
|
||||
locking.DelGLock(endpointprefixIndex, int(i))
|
||||
}
|
||||
|
||||
// RLock locks m for reading.
|
||||
// +checklocksignore
|
||||
func (m *endpointRWMutex) RLock() {
|
||||
locking.AddGLock(endpointprefixIndex, -1)
|
||||
m.mu.RLock()
|
||||
}
|
||||
|
||||
// RUnlock undoes a single RLock call.
|
||||
// +checklocksignore
|
||||
func (m *endpointRWMutex) RUnlock() {
|
||||
m.mu.RUnlock()
|
||||
locking.DelGLock(endpointprefixIndex, -1)
|
||||
}
|
||||
|
||||
// RLockBypass locks m for reading without executing the validator.
|
||||
// +checklocksignore
|
||||
func (m *endpointRWMutex) RLockBypass() {
|
||||
m.mu.RLock()
|
||||
}
|
||||
|
||||
// RUnlockBypass undoes a single RLockBypass call.
|
||||
// +checklocksignore
|
||||
func (m *endpointRWMutex) RUnlockBypass() {
|
||||
m.mu.RUnlock()
|
||||
}
|
||||
|
||||
// DowngradeLock atomically unlocks rw for writing and locks it for reading.
|
||||
// +checklocksignore
|
||||
func (m *endpointRWMutex) DowngradeLock() {
|
||||
m.mu.DowngradeLock()
|
||||
}
|
||||
|
||||
var endpointprefixIndex *locking.MutexClass
|
||||
|
||||
// DO NOT REMOVE: The following function is automatically replaced.
|
||||
func endpointinitLockNames() {}
|
||||
|
||||
func init() {
|
||||
endpointinitLockNames()
|
||||
endpointprefixIndex = locking.NewMutexClass(reflect.TypeOf(endpointRWMutex{}), endpointlockNames)
|
||||
}
|
||||
54
internal/fdbased_darwin/errno.go
Normal file
54
internal/fdbased_darwin/errno.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package fdbased
|
||||
|
||||
import (
|
||||
"github.com/sagernet/gvisor/pkg/tcpip"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func TranslateErrno(e unix.Errno) tcpip.Error {
|
||||
switch e {
|
||||
case unix.EEXIST:
|
||||
return &tcpip.ErrDuplicateAddress{}
|
||||
case unix.ENETUNREACH:
|
||||
return &tcpip.ErrHostUnreachable{}
|
||||
case unix.EINVAL:
|
||||
return &tcpip.ErrInvalidEndpointState{}
|
||||
case unix.EALREADY:
|
||||
return &tcpip.ErrAlreadyConnecting{}
|
||||
case unix.EISCONN:
|
||||
return &tcpip.ErrAlreadyConnected{}
|
||||
case unix.EADDRINUSE:
|
||||
return &tcpip.ErrPortInUse{}
|
||||
case unix.EADDRNOTAVAIL:
|
||||
return &tcpip.ErrBadLocalAddress{}
|
||||
case unix.EPIPE:
|
||||
return &tcpip.ErrClosedForSend{}
|
||||
case unix.EWOULDBLOCK:
|
||||
return &tcpip.ErrWouldBlock{}
|
||||
case unix.ECONNREFUSED:
|
||||
return &tcpip.ErrConnectionRefused{}
|
||||
case unix.ETIMEDOUT:
|
||||
return &tcpip.ErrTimeout{}
|
||||
case unix.EINPROGRESS:
|
||||
return &tcpip.ErrConnectStarted{}
|
||||
case unix.EDESTADDRREQ:
|
||||
return &tcpip.ErrDestinationRequired{}
|
||||
case unix.ENOTSUP:
|
||||
return &tcpip.ErrNotSupported{}
|
||||
case unix.ENOTTY:
|
||||
return &tcpip.ErrQueueSizeNotSupported{}
|
||||
case unix.ENOTCONN:
|
||||
return &tcpip.ErrNotConnected{}
|
||||
case unix.ECONNRESET:
|
||||
return &tcpip.ErrConnectionReset{}
|
||||
case unix.ECONNABORTED:
|
||||
return &tcpip.ErrConnectionAborted{}
|
||||
case unix.EMSGSIZE:
|
||||
return &tcpip.ErrMessageTooLong{}
|
||||
case unix.ENOBUFS:
|
||||
return &tcpip.ErrNoBufferSpace{}
|
||||
default:
|
||||
return &tcpip.ErrInvalidEndpointState{}
|
||||
}
|
||||
}
|
||||
199
internal/fdbased_darwin/packet_dispatchers.go
Normal file
199
internal/fdbased_darwin/packet_dispatchers.go
Normal file
@@ -0,0 +1,199 @@
|
||||
// Copyright 2018 The gVisor Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package fdbased
|
||||
|
||||
import (
|
||||
"github.com/sagernet/gvisor/pkg/buffer"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/stack/gro"
|
||||
"github.com/sagernet/sing-tun/internal/rawfile_darwin"
|
||||
"github.com/sagernet/sing-tun/internal/stopfd_darwin"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type iovecBuffer struct {
|
||||
mtu int
|
||||
views []*buffer.View
|
||||
iovecs []unix.Iovec `state:"nosave"`
|
||||
}
|
||||
|
||||
func newIovecBuffer(mtu uint32) *iovecBuffer {
|
||||
b := &iovecBuffer{
|
||||
mtu: int(mtu),
|
||||
views: make([]*buffer.View, 2),
|
||||
iovecs: make([]unix.Iovec, 2),
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *iovecBuffer) nextIovecs() []unix.Iovec {
|
||||
if b.views[0] == nil {
|
||||
b.views[0] = buffer.NewViewSize(4)
|
||||
b.iovecs[0] = unix.Iovec{Base: b.views[0].BasePtr()}
|
||||
b.iovecs[0].SetLen(4)
|
||||
}
|
||||
if b.views[1] == nil {
|
||||
b.views[1] = buffer.NewViewSize(b.mtu)
|
||||
b.iovecs[1] = unix.Iovec{Base: b.views[1].BasePtr()}
|
||||
b.iovecs[1].SetLen(b.mtu)
|
||||
}
|
||||
return b.iovecs
|
||||
}
|
||||
|
||||
// pullBuffer extracts the enough underlying storage from b.buffer to hold n
|
||||
// bytes. It removes this storage from b.buffer, returns a new buffer
|
||||
// that holds the storage, and updates pulledIndex to indicate which part
|
||||
// of b.buffer's storage must be reallocated during the next call to
|
||||
// nextIovecs.
|
||||
func (b *iovecBuffer) pullBuffer(n int) buffer.Buffer {
|
||||
pulled := buffer.Buffer{}
|
||||
pulled.Append(b.views[0])
|
||||
pulled.Append(b.views[1])
|
||||
pulled.Truncate(int64(n))
|
||||
pulled.TrimFront(4)
|
||||
b.views[0] = nil
|
||||
b.views[1] = nil
|
||||
return pulled
|
||||
}
|
||||
|
||||
func (b *iovecBuffer) release() {
|
||||
for _, v := range b.views {
|
||||
if v != nil {
|
||||
v.Release()
|
||||
v = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// recvMMsgDispatcher uses the recvmmsg system call to read inbound packets and
|
||||
// dispatches them.
|
||||
//
|
||||
// +stateify savable
|
||||
type recvMMsgDispatcher struct {
|
||||
stopfd.StopFD
|
||||
// fd is the file descriptor used to send and receive packets.
|
||||
fd int
|
||||
|
||||
// e is the endpoint this dispatcher is attached to.
|
||||
e *endpoint
|
||||
|
||||
// bufs is an array of iovec buffers that contain packet contents.
|
||||
bufs []*iovecBuffer
|
||||
|
||||
// msgHdrs is an array of MMsgHdr objects where each MMsghdr is used to
|
||||
// reference an array of iovecs in the iovecs field defined above. This
|
||||
// array is passed as the parameter to recvmmsg call to retrieve
|
||||
// potentially more than 1 packet per unix.
|
||||
msgHdrs []rawfile.MsgHdrX `state:"nosave"`
|
||||
|
||||
// pkts is reused to avoid allocations.
|
||||
pkts stack.PacketBufferList
|
||||
|
||||
// gro coalesces incoming packets to increase throughput.
|
||||
gro gro.GRO
|
||||
|
||||
// mgr is the processor goroutine manager.
|
||||
mgr *processorManager
|
||||
}
|
||||
|
||||
func newRecvMMsgDispatcher(fd int, e *endpoint, opts *Options) (linkDispatcher, error) {
|
||||
stopFD, err := stopfd.New()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var batchSize int
|
||||
if opts.MultiPendingPackets {
|
||||
batchSize = int((512*1024)/(opts.MTU)) + 1
|
||||
} else {
|
||||
batchSize = 1
|
||||
}
|
||||
d := &recvMMsgDispatcher{
|
||||
StopFD: stopFD,
|
||||
fd: fd,
|
||||
e: e,
|
||||
bufs: make([]*iovecBuffer, batchSize),
|
||||
msgHdrs: make([]rawfile.MsgHdrX, batchSize),
|
||||
}
|
||||
for i := range d.bufs {
|
||||
d.bufs[i] = newIovecBuffer(opts.MTU)
|
||||
}
|
||||
d.gro.Init(false)
|
||||
d.mgr = newProcessorManager(opts, e)
|
||||
d.mgr.start()
|
||||
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func (d *recvMMsgDispatcher) release() {
|
||||
for _, iov := range d.bufs {
|
||||
iov.release()
|
||||
}
|
||||
d.mgr.close()
|
||||
}
|
||||
|
||||
// recvMMsgDispatch reads more than one packet at a time from the file
|
||||
// descriptor and dispatches it.
|
||||
func (d *recvMMsgDispatcher) dispatch() (bool, tcpip.Error) {
|
||||
// Fill message headers.
|
||||
for k := range d.msgHdrs {
|
||||
iovecs := d.bufs[k].nextIovecs()
|
||||
iovLen := len(iovecs)
|
||||
// Cannot clear only the length field. Older versions of the darwin kernel will check whether other data is empty.
|
||||
// https://github.com/Darm64/XNU/blob/xnu-2782.40.9/bsd/kern/uipc_syscalls.c#L2026-L2048
|
||||
d.msgHdrs[k] = rawfile.MsgHdrX{}
|
||||
d.msgHdrs[k].Msg.Iov = &iovecs[0]
|
||||
d.msgHdrs[k].Msg.SetIovlen(iovLen)
|
||||
}
|
||||
|
||||
nMsgs, errno := rawfile.BlockingRecvMMsgUntilStopped(d.ReadFD, d.fd, d.msgHdrs)
|
||||
if errno != 0 {
|
||||
return false, TranslateErrno(errno)
|
||||
}
|
||||
if nMsgs == -1 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Process each of received packets.
|
||||
|
||||
d.e.mu.RLock()
|
||||
addr := d.e.addr
|
||||
dsp := d.e.dispatcher
|
||||
d.e.mu.RUnlock()
|
||||
|
||||
d.gro.Dispatcher = dsp
|
||||
defer d.pkts.Reset()
|
||||
|
||||
for k := 0; k < nMsgs; k++ {
|
||||
n := int(d.msgHdrs[k].DataLen)
|
||||
payload := d.bufs[k].pullBuffer(n)
|
||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: payload,
|
||||
})
|
||||
d.pkts.PushBack(pkt)
|
||||
|
||||
// Mark that this iovec has been processed.
|
||||
d.msgHdrs[k].Msg.Iovlen = 0
|
||||
|
||||
if d.e.parseInboundHeader(pkt, addr) {
|
||||
pkt.RXChecksumValidated = d.e.caps&stack.CapabilityRXChecksumOffload != 0
|
||||
d.mgr.queuePacket(pkt, d.e.hdrSize > 0)
|
||||
}
|
||||
}
|
||||
d.mgr.wakeReady()
|
||||
|
||||
return true, nil
|
||||
}
|
||||
64
internal/fdbased_darwin/processor_mutex.go
Normal file
64
internal/fdbased_darwin/processor_mutex.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package fdbased
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"github.com/sagernet/gvisor/pkg/sync"
|
||||
"github.com/sagernet/gvisor/pkg/sync/locking"
|
||||
)
|
||||
|
||||
// Mutex is sync.Mutex with the correctness validator.
|
||||
type processorMutex struct {
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
var processorprefixIndex *locking.MutexClass
|
||||
|
||||
// lockNames is a list of user-friendly lock names.
|
||||
// Populated in init.
|
||||
var processorlockNames []string
|
||||
|
||||
// lockNameIndex is used as an index passed to NestedLock and NestedUnlock,
|
||||
// referring to an index within lockNames.
|
||||
// Values are specified using the "consts" field of go_template_instance.
|
||||
type processorlockNameIndex int
|
||||
|
||||
// DO NOT REMOVE: The following function automatically replaced with lock index constants.
|
||||
// LOCK_NAME_INDEX_CONSTANTS
|
||||
const ()
|
||||
|
||||
// Lock locks m.
|
||||
// +checklocksignore
|
||||
func (m *processorMutex) Lock() {
|
||||
locking.AddGLock(processorprefixIndex, -1)
|
||||
m.mu.Lock()
|
||||
}
|
||||
|
||||
// NestedLock locks m knowing that another lock of the same type is held.
|
||||
// +checklocksignore
|
||||
func (m *processorMutex) NestedLock(i processorlockNameIndex) {
|
||||
locking.AddGLock(processorprefixIndex, int(i))
|
||||
m.mu.Lock()
|
||||
}
|
||||
|
||||
// Unlock unlocks m.
|
||||
// +checklocksignore
|
||||
func (m *processorMutex) Unlock() {
|
||||
locking.DelGLock(processorprefixIndex, -1)
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// NestedUnlock unlocks m knowing that another lock of the same type is held.
|
||||
// +checklocksignore
|
||||
func (m *processorMutex) NestedUnlock(i processorlockNameIndex) {
|
||||
locking.DelGLock(processorprefixIndex, int(i))
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// DO NOT REMOVE: The following function is automatically replaced.
|
||||
func processorinitLockNames() {}
|
||||
|
||||
func init() {
|
||||
processorinitLockNames()
|
||||
processorprefixIndex = locking.NewMutexClass(reflect.TypeOf(processorMutex{}), processorlockNames)
|
||||
}
|
||||
275
internal/fdbased_darwin/processors.go
Normal file
275
internal/fdbased_darwin/processors.go
Normal file
@@ -0,0 +1,275 @@
|
||||
// Copyright 2024 The gVisor Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package fdbased
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/sagernet/gvisor/pkg/rand"
|
||||
"github.com/sagernet/gvisor/pkg/sleep"
|
||||
"github.com/sagernet/gvisor/pkg/sync"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/hash/jenkins"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/header"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/stack/gro"
|
||||
)
|
||||
|
||||
// +stateify savable
|
||||
type processor struct {
|
||||
mu processorMutex `state:"nosave"`
|
||||
// +checklocks:mu
|
||||
pkts stack.PacketBufferList
|
||||
|
||||
e *endpoint
|
||||
gro gro.GRO
|
||||
sleeper sleep.Sleeper
|
||||
packetWaker sleep.Waker
|
||||
closeWaker sleep.Waker
|
||||
}
|
||||
|
||||
func (p *processor) start(wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
defer p.sleeper.Done()
|
||||
for {
|
||||
switch w := p.sleeper.Fetch(true); {
|
||||
case w == &p.packetWaker:
|
||||
p.deliverPackets()
|
||||
case w == &p.closeWaker:
|
||||
p.mu.Lock()
|
||||
p.pkts.Reset()
|
||||
p.mu.Unlock()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *processor) deliverPackets() {
|
||||
p.e.mu.RLock()
|
||||
p.gro.Dispatcher = p.e.dispatcher
|
||||
p.e.mu.RUnlock()
|
||||
if p.gro.Dispatcher == nil {
|
||||
p.mu.Lock()
|
||||
p.pkts.Reset()
|
||||
p.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
for p.pkts.Len() > 0 {
|
||||
pkt := p.pkts.PopFront()
|
||||
p.mu.Unlock()
|
||||
p.gro.Enqueue(pkt)
|
||||
pkt.DecRef()
|
||||
p.mu.Lock()
|
||||
}
|
||||
p.mu.Unlock()
|
||||
p.gro.Flush()
|
||||
}
|
||||
|
||||
// processorManager handles starting, closing, and queuing packets on processor
|
||||
// goroutines.
|
||||
//
|
||||
// +stateify savable
|
||||
type processorManager struct {
|
||||
processors []processor
|
||||
seed uint32
|
||||
wg sync.WaitGroup `state:"nosave"`
|
||||
e *endpoint
|
||||
ready []bool
|
||||
}
|
||||
|
||||
// newProcessorManager creates a new processor manager.
|
||||
func newProcessorManager(opts *Options, e *endpoint) *processorManager {
|
||||
m := &processorManager{}
|
||||
m.seed = rand.Uint32()
|
||||
m.ready = make([]bool, opts.ProcessorsPerChannel)
|
||||
m.processors = make([]processor, opts.ProcessorsPerChannel)
|
||||
m.e = e
|
||||
m.wg.Add(opts.ProcessorsPerChannel)
|
||||
|
||||
for i := range m.processors {
|
||||
p := &m.processors[i]
|
||||
p.sleeper.AddWaker(&p.packetWaker)
|
||||
p.sleeper.AddWaker(&p.closeWaker)
|
||||
p.gro.Init(false)
|
||||
p.e = e
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// start starts the processor goroutines if the processor manager is configured
|
||||
// with more than one processor.
|
||||
func (m *processorManager) start() {
|
||||
for i := range m.processors {
|
||||
p := &m.processors[i]
|
||||
// Only start processor in a separate goroutine if we have multiple of them.
|
||||
if len(m.processors) > 1 {
|
||||
go p.start(&m.wg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// afterLoad is invoked by stateify.
|
||||
func (m *processorManager) afterLoad(context.Context) {
|
||||
m.wg.Add(len(m.processors))
|
||||
m.start()
|
||||
}
|
||||
|
||||
func (m *processorManager) connectionHash(cid *connectionID) uint32 {
|
||||
var payload [4]byte
|
||||
binary.LittleEndian.PutUint16(payload[0:], cid.srcPort)
|
||||
binary.LittleEndian.PutUint16(payload[2:], cid.dstPort)
|
||||
|
||||
h := jenkins.Sum32(m.seed)
|
||||
h.Write(payload[:])
|
||||
h.Write(cid.srcAddr)
|
||||
h.Write(cid.dstAddr)
|
||||
return h.Sum32()
|
||||
}
|
||||
|
||||
// queuePacket queues a packet to be delivered to the appropriate processor.
|
||||
func (m *processorManager) queuePacket(pkt *stack.PacketBuffer, hasEthHeader bool) {
|
||||
var pIdx uint32
|
||||
cid, nonConnectionPkt := tcpipConnectionID(pkt)
|
||||
if !hasEthHeader {
|
||||
if nonConnectionPkt {
|
||||
// If there's no eth header this should be a standard tcpip packet. If
|
||||
// it isn't the packet is invalid so drop it.
|
||||
return
|
||||
}
|
||||
pkt.NetworkProtocolNumber = cid.proto
|
||||
}
|
||||
if len(m.processors) == 1 || nonConnectionPkt {
|
||||
// If the packet is not associated with an active connection, use the
|
||||
// first processor.
|
||||
pIdx = 0
|
||||
} else {
|
||||
pIdx = m.connectionHash(&cid) % uint32(len(m.processors))
|
||||
}
|
||||
p := &m.processors[pIdx]
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.pkts.PushBack(pkt.IncRef())
|
||||
m.ready[pIdx] = true
|
||||
}
|
||||
|
||||
type connectionID struct {
|
||||
srcAddr, dstAddr []byte
|
||||
srcPort, dstPort uint16
|
||||
proto tcpip.NetworkProtocolNumber
|
||||
}
|
||||
|
||||
// tcpipConnectionID returns a tcpip connection id tuple based on the data found
|
||||
// in the packet. It returns true if the packet is not associated with an active
|
||||
// connection (e.g ARP, NDP, etc). The method assumes link headers have already
|
||||
// been processed if they were present.
|
||||
func tcpipConnectionID(pkt *stack.PacketBuffer) (connectionID, bool) {
|
||||
var cid connectionID
|
||||
h, ok := pkt.Data().PullUp(1)
|
||||
if !ok {
|
||||
// Skip this packet.
|
||||
return cid, true
|
||||
}
|
||||
|
||||
const tcpSrcDstPortLen = 4
|
||||
switch header.IPVersion(h) {
|
||||
case header.IPv4Version:
|
||||
hdrLen := header.IPv4(h).HeaderLength()
|
||||
h, ok = pkt.Data().PullUp(int(hdrLen) + tcpSrcDstPortLen)
|
||||
if !ok {
|
||||
return cid, true
|
||||
}
|
||||
ipHdr := header.IPv4(h[:hdrLen])
|
||||
tcpHdr := header.TCP(h[hdrLen:][:tcpSrcDstPortLen])
|
||||
|
||||
cid.srcAddr = ipHdr.SourceAddressSlice()
|
||||
cid.dstAddr = ipHdr.DestinationAddressSlice()
|
||||
// All fragment packets need to be processed by the same goroutine, so
|
||||
// only record the TCP ports if this is not a fragment packet.
|
||||
if ipHdr.IsValid(pkt.Data().Size()) && !ipHdr.More() && ipHdr.FragmentOffset() == 0 {
|
||||
cid.srcPort = tcpHdr.SourcePort()
|
||||
cid.dstPort = tcpHdr.DestinationPort()
|
||||
}
|
||||
cid.proto = header.IPv4ProtocolNumber
|
||||
case header.IPv6Version:
|
||||
h, ok = pkt.Data().PullUp(header.IPv6FixedHeaderSize + tcpSrcDstPortLen)
|
||||
if !ok {
|
||||
return cid, true
|
||||
}
|
||||
ipHdr := header.IPv6(h)
|
||||
|
||||
var tcpHdr header.TCP
|
||||
if tcpip.TransportProtocolNumber(ipHdr.NextHeader()) == header.TCPProtocolNumber {
|
||||
tcpHdr = header.TCP(h[header.IPv6FixedHeaderSize:][:tcpSrcDstPortLen])
|
||||
} else {
|
||||
// Slow path for IPv6 extension headers :(.
|
||||
dataBuf := pkt.Data().ToBuffer()
|
||||
dataBuf.TrimFront(header.IPv6MinimumSize)
|
||||
it := header.MakeIPv6PayloadIterator(header.IPv6ExtensionHeaderIdentifier(ipHdr.NextHeader()), dataBuf)
|
||||
defer it.Release()
|
||||
for {
|
||||
hdr, done, err := it.Next()
|
||||
if done || err != nil {
|
||||
break
|
||||
}
|
||||
hdr.Release()
|
||||
}
|
||||
h, ok = pkt.Data().PullUp(int(it.HeaderOffset()) + tcpSrcDstPortLen)
|
||||
if !ok {
|
||||
return cid, true
|
||||
}
|
||||
tcpHdr = header.TCP(h[it.HeaderOffset():][:tcpSrcDstPortLen])
|
||||
}
|
||||
cid.srcAddr = ipHdr.SourceAddressSlice()
|
||||
cid.dstAddr = ipHdr.DestinationAddressSlice()
|
||||
cid.srcPort = tcpHdr.SourcePort()
|
||||
cid.dstPort = tcpHdr.DestinationPort()
|
||||
cid.proto = header.IPv6ProtocolNumber
|
||||
default:
|
||||
return cid, true
|
||||
}
|
||||
return cid, false
|
||||
}
|
||||
|
||||
func (m *processorManager) close() {
|
||||
if len(m.processors) < 2 {
|
||||
return
|
||||
}
|
||||
for i := range m.processors {
|
||||
p := &m.processors[i]
|
||||
p.closeWaker.Assert()
|
||||
}
|
||||
}
|
||||
|
||||
// wakeReady wakes up all processors that have a packet queued. If there is only
|
||||
// one processor, the method delivers the packet inline without waking a
|
||||
// goroutine.
|
||||
func (m *processorManager) wakeReady() {
|
||||
for i, ready := range m.ready {
|
||||
if !ready {
|
||||
continue
|
||||
}
|
||||
p := &m.processors[i]
|
||||
if len(m.processors) > 1 {
|
||||
p.packetWaker.Assert()
|
||||
} else {
|
||||
p.deliverPackets()
|
||||
}
|
||||
m.ready[i] = false
|
||||
}
|
||||
}
|
||||
@@ -38,3 +38,8 @@ func Combine(a, b uint16) uint16 {
|
||||
v := uint32(a) + uint32(b)
|
||||
return uint16(v + v>>16)
|
||||
}
|
||||
|
||||
func ChecksumDefault(buf []byte, initial uint16) uint16 {
|
||||
s, _ := calculateChecksum(buf, false, initial)
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -8,6 +8,5 @@ package checksum
|
||||
//
|
||||
// The initial checksum must have been computed on an even number of bytes.
|
||||
func Checksum(buf []byte, initial uint16) uint16 {
|
||||
s, _ := calculateChecksum(buf, false, initial)
|
||||
return s
|
||||
return ChecksumDefault(buf, initial)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build !amd64
|
||||
|
||||
// Copyright 2023 The gVisor Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
@@ -458,7 +458,10 @@ func (b IPv4) SetDestinationAddress(addr tcpip.Address) {
|
||||
|
||||
// CalculateChecksum calculates the checksum of the IPv4 header.
|
||||
func (b IPv4) CalculateChecksum() uint16 {
|
||||
return checksum.Checksum(b[:b.HeaderLength()], 0)
|
||||
// return checksum.Checksum(b[:b.HeaderLength()], 0)
|
||||
xsum0 := checksum.Checksum(b[:xsum], 0)
|
||||
xsum0 = checksum.Checksum(b[xsum+2:b.HeaderLength()], xsum0)
|
||||
return xsum0
|
||||
}
|
||||
|
||||
// Encode encodes all the fields of the IPv4 header.
|
||||
@@ -550,7 +553,8 @@ func (b IPv4) IsChecksumValid() bool {
|
||||
// same set of octets, including the checksum field. If the result
|
||||
// is all 1 bits (-0 in 1's complement arithmetic), the check
|
||||
// succeeds.
|
||||
return b.CalculateChecksum() == 0xffff
|
||||
//return b.CalculateChecksum() == 0xffff
|
||||
return checksum.Checksum(b[:b.HeaderLength()], 0) == 0xffff
|
||||
}
|
||||
|
||||
// IsV4MulticastAddress determines if the provided address is an IPv4 multicast
|
||||
|
||||
@@ -18,10 +18,8 @@ import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
|
||||
"github.com/sagernet/gvisor/pkg/buffer"
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip"
|
||||
"github.com/sagernet/sing/common"
|
||||
)
|
||||
@@ -145,79 +143,6 @@ func ipv6OptionsAlignmentPadding(headerOffset int, align int, alignOffset int) i
|
||||
return ((padLen + align - 1) & ^(align - 1)) - padLen
|
||||
}
|
||||
|
||||
// IPv6PayloadHeader is implemented by the various headers that can be found
|
||||
// in an IPv6 payload.
|
||||
//
|
||||
// These headers include IPv6 extension headers or upper layer data.
|
||||
type IPv6PayloadHeader interface {
|
||||
isIPv6PayloadHeader()
|
||||
|
||||
// Release frees all resources held by the header.
|
||||
Release()
|
||||
}
|
||||
|
||||
// IPv6RawPayloadHeader the remainder of an IPv6 payload after an iterator
|
||||
// encounters a Next Header field it does not recognize as an IPv6 extension
|
||||
// header. The caller is responsible for releasing the underlying buffer after
|
||||
// it's no longer needed.
|
||||
type IPv6RawPayloadHeader struct {
|
||||
Identifier IPv6ExtensionHeaderIdentifier
|
||||
Buf buffer.Buffer
|
||||
}
|
||||
|
||||
// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
|
||||
func (IPv6RawPayloadHeader) isIPv6PayloadHeader() {}
|
||||
|
||||
// Release implements IPv6PayloadHeader.Release.
|
||||
func (i IPv6RawPayloadHeader) Release() {
|
||||
i.Buf.Release()
|
||||
}
|
||||
|
||||
// ipv6OptionsExtHdr is an IPv6 extension header that holds options.
|
||||
type ipv6OptionsExtHdr struct {
|
||||
buf *buffer.View
|
||||
}
|
||||
|
||||
// Release implements IPv6PayloadHeader.Release.
|
||||
func (i ipv6OptionsExtHdr) Release() {
|
||||
if i.buf != nil {
|
||||
i.buf.Release()
|
||||
}
|
||||
}
|
||||
|
||||
// Iter returns an iterator over the IPv6 extension header options held in b.
|
||||
func (i ipv6OptionsExtHdr) Iter() IPv6OptionsExtHdrOptionsIterator {
|
||||
it := IPv6OptionsExtHdrOptionsIterator{}
|
||||
it.reader = i.buf
|
||||
return it
|
||||
}
|
||||
|
||||
// IPv6OptionsExtHdrOptionsIterator is an iterator over IPv6 extension header
|
||||
// options.
|
||||
//
|
||||
// Note, between when an IPv6OptionsExtHdrOptionsIterator is obtained and last
|
||||
// used, no changes to the underlying buffer may happen. Doing so may cause
|
||||
// undefined and unexpected behaviour. It is fine to obtain an
|
||||
// IPv6OptionsExtHdrOptionsIterator, iterate over the first few options then
|
||||
// modify the backing payload so long as the IPv6OptionsExtHdrOptionsIterator
|
||||
// obtained before modification is no longer used.
|
||||
type IPv6OptionsExtHdrOptionsIterator struct {
|
||||
reader *buffer.View
|
||||
|
||||
// optionOffset is the number of bytes from the first byte of the
|
||||
// options field to the beginning of the current option.
|
||||
optionOffset uint32
|
||||
|
||||
// nextOptionOffset is the offset of the next option.
|
||||
nextOptionOffset uint32
|
||||
}
|
||||
|
||||
// OptionOffset returns the number of bytes parsed while processing the
|
||||
// option field of the current Extension Header.
|
||||
func (i *IPv6OptionsExtHdrOptionsIterator) OptionOffset() uint32 {
|
||||
return i.optionOffset
|
||||
}
|
||||
|
||||
// IPv6OptionUnknownAction is the action that must be taken if the processing
|
||||
// IPv6 node does not recognize the option, as outlined in RFC 8200 section 4.2.
|
||||
type IPv6OptionUnknownAction int
|
||||
@@ -294,143 +219,6 @@ func ipv6UnknownActionFromIdentifier(id IPv6ExtHdrOptionIdentifier) IPv6OptionUn
|
||||
// is malformed.
|
||||
var ErrMalformedIPv6ExtHdrOption = errors.New("malformed IPv6 extension header option")
|
||||
|
||||
// IPv6UnknownExtHdrOption holds the identifier and data for an IPv6 extension
|
||||
// header option that is unknown by the parsing utilities.
|
||||
type IPv6UnknownExtHdrOption struct {
|
||||
Identifier IPv6ExtHdrOptionIdentifier
|
||||
Data *buffer.View
|
||||
}
|
||||
|
||||
// UnknownAction implements IPv6OptionUnknownAction.UnknownAction.
|
||||
func (o *IPv6UnknownExtHdrOption) UnknownAction() IPv6OptionUnknownAction {
|
||||
return ipv6UnknownActionFromIdentifier(o.Identifier)
|
||||
}
|
||||
|
||||
// isIPv6ExtHdrOption implements IPv6ExtHdrOption.isIPv6ExtHdrOption.
|
||||
func (*IPv6UnknownExtHdrOption) isIPv6ExtHdrOption() {}
|
||||
|
||||
// Next returns the next option in the options data.
|
||||
//
|
||||
// If the next item is not a known extension header option,
|
||||
// IPv6UnknownExtHdrOption will be returned with the option identifier and data.
|
||||
//
|
||||
// The return is of the format (option, done, error). done will be true when
|
||||
// Next is unable to return anything because the iterator has reached the end of
|
||||
// the options data, or an error occurred.
|
||||
func (i *IPv6OptionsExtHdrOptionsIterator) Next() (IPv6ExtHdrOption, bool, error) {
|
||||
for {
|
||||
i.optionOffset = i.nextOptionOffset
|
||||
temp, err := i.reader.ReadByte()
|
||||
if err != nil {
|
||||
// If we can't read the first byte of a new option, then we know the
|
||||
// options buffer has been exhausted and we are done iterating.
|
||||
return nil, true, nil
|
||||
}
|
||||
id := IPv6ExtHdrOptionIdentifier(temp)
|
||||
|
||||
// If the option identifier indicates the option is a Pad1 option, then we
|
||||
// know the option does not have Length and Data fields. End processing of
|
||||
// the Pad1 option and continue processing the buffer as a new option.
|
||||
if id == ipv6Pad1ExtHdrOptionIdentifier {
|
||||
i.nextOptionOffset = i.optionOffset + 1
|
||||
continue
|
||||
}
|
||||
|
||||
length, err := i.reader.ReadByte()
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
// ReadByte should only ever return nil or io.EOF.
|
||||
panic(fmt.Sprintf("unexpected error when reading the option's Length field for option with id = %d: %s", id, err))
|
||||
}
|
||||
|
||||
// We use io.ErrUnexpectedEOF as exhausting the buffer is unexpected once
|
||||
// we start parsing an option; we expect the reader to contain enough
|
||||
// bytes for the whole option.
|
||||
return nil, true, fmt.Errorf("error when reading the option's Length field for option with id = %d: %w", id, io.ErrUnexpectedEOF)
|
||||
}
|
||||
|
||||
// Do we have enough bytes in the reader for the next option?
|
||||
if n := i.reader.Size(); n < int(length) {
|
||||
// Consume the remaining buffer.
|
||||
i.reader.TrimFront(i.reader.Size())
|
||||
|
||||
// We return the same error as if we failed to read a non-padding option
|
||||
// so consumers of this iterator don't need to differentiate between
|
||||
// padding and non-padding options.
|
||||
return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, io.ErrUnexpectedEOF)
|
||||
}
|
||||
|
||||
i.nextOptionOffset = i.optionOffset + uint32(length) + 1 /* option ID */ + 1 /* length byte */
|
||||
|
||||
switch id {
|
||||
case ipv6PadNExtHdrOptionIdentifier:
|
||||
// Special-case the variable length padding option to avoid a copy.
|
||||
i.reader.TrimFront(int(length))
|
||||
continue
|
||||
case ipv6RouterAlertHopByHopOptionIdentifier:
|
||||
var routerAlertValue [ipv6RouterAlertPayloadLength]byte
|
||||
if n, err := io.ReadFull(i.reader, routerAlertValue[:]); err != nil {
|
||||
switch err {
|
||||
case io.EOF, io.ErrUnexpectedEOF:
|
||||
return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption)
|
||||
default:
|
||||
return nil, true, fmt.Errorf("read %d out of %d option data bytes for router alert option: %w", n, ipv6RouterAlertPayloadLength, err)
|
||||
}
|
||||
} else if n != int(length) {
|
||||
return nil, true, fmt.Errorf("got invalid length (%d) for router alert option (want = %d): %w", length, ipv6RouterAlertPayloadLength, ErrMalformedIPv6ExtHdrOption)
|
||||
}
|
||||
return &IPv6RouterAlertOption{Value: IPv6RouterAlertValue(binary.BigEndian.Uint16(routerAlertValue[:]))}, false, nil
|
||||
default:
|
||||
bytes := buffer.NewView(int(length))
|
||||
if n, err := io.CopyN(bytes, i.reader, int64(length)); err != nil {
|
||||
if err == io.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
|
||||
return nil, true, fmt.Errorf("read %d out of %d option data bytes for option with id = %d: %w", n, length, id, err)
|
||||
}
|
||||
return &IPv6UnknownExtHdrOption{Identifier: id, Data: bytes}, false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IPv6HopByHopOptionsExtHdr is a buffer holding the Hop By Hop Options
|
||||
// extension header.
|
||||
type IPv6HopByHopOptionsExtHdr struct {
|
||||
ipv6OptionsExtHdr
|
||||
}
|
||||
|
||||
// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
|
||||
func (IPv6HopByHopOptionsExtHdr) isIPv6PayloadHeader() {}
|
||||
|
||||
// IPv6DestinationOptionsExtHdr is a buffer holding the Destination Options
|
||||
// extension header.
|
||||
type IPv6DestinationOptionsExtHdr struct {
|
||||
ipv6OptionsExtHdr
|
||||
}
|
||||
|
||||
// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
|
||||
func (IPv6DestinationOptionsExtHdr) isIPv6PayloadHeader() {}
|
||||
|
||||
// IPv6RoutingExtHdr is a buffer holding the Routing extension header specific
|
||||
// data as outlined in RFC 8200 section 4.4.
|
||||
type IPv6RoutingExtHdr struct {
|
||||
Buf *buffer.View
|
||||
}
|
||||
|
||||
// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
|
||||
func (IPv6RoutingExtHdr) isIPv6PayloadHeader() {}
|
||||
|
||||
// Release implements IPv6PayloadHeader.Release.
|
||||
func (b IPv6RoutingExtHdr) Release() {
|
||||
b.Buf.Release()
|
||||
}
|
||||
|
||||
// SegmentsLeft returns the Segments Left field.
|
||||
func (b IPv6RoutingExtHdr) SegmentsLeft() uint8 {
|
||||
return b.Buf.AsSlice()[ipv6RoutingExtHdrSegmentsLeftIdx]
|
||||
}
|
||||
|
||||
// IPv6FragmentExtHdr is a buffer holding the Fragment extension header specific
|
||||
// data as outlined in RFC 8200 section 4.5.
|
||||
//
|
||||
@@ -473,242 +261,6 @@ func (b IPv6FragmentExtHdr) IsAtomic() bool {
|
||||
return !b.More() && b.FragmentOffset() == 0
|
||||
}
|
||||
|
||||
// IPv6PayloadIterator is an iterator over the contents of an IPv6 payload.
|
||||
//
|
||||
// The IPv6 payload may contain IPv6 extension headers before any upper layer
|
||||
// data.
|
||||
//
|
||||
// Note, between when an IPv6PayloadIterator is obtained and last used, no
|
||||
// changes to the payload may happen. Doing so may cause undefined and
|
||||
// unexpected behaviour. It is fine to obtain an IPv6PayloadIterator, iterate
|
||||
// over the first few headers then modify the backing payload so long as the
|
||||
// IPv6PayloadIterator obtained before modification is no longer used.
|
||||
type IPv6PayloadIterator struct {
|
||||
// The identifier of the next header to parse.
|
||||
nextHdrIdentifier IPv6ExtensionHeaderIdentifier
|
||||
|
||||
payload buffer.Buffer
|
||||
|
||||
// Indicates to the iterator that it should return the remaining payload as a
|
||||
// raw payload on the next call to Next.
|
||||
forceRaw bool
|
||||
|
||||
// headerOffset is the offset of the beginning of the current extension
|
||||
// header starting from the beginning of the fixed header.
|
||||
headerOffset uint32
|
||||
|
||||
// parseOffset is the byte offset into the current extension header of the
|
||||
// field we are currently examining. It can be added to the header offset
|
||||
// if the absolute offset within the packet is required.
|
||||
parseOffset uint32
|
||||
|
||||
// nextOffset is the offset of the next header.
|
||||
nextOffset uint32
|
||||
}
|
||||
|
||||
// HeaderOffset returns the offset to the start of the extension
|
||||
// header most recently processed.
|
||||
func (i IPv6PayloadIterator) HeaderOffset() uint32 {
|
||||
return i.headerOffset
|
||||
}
|
||||
|
||||
// ParseOffset returns the number of bytes successfully parsed.
|
||||
func (i IPv6PayloadIterator) ParseOffset() uint32 {
|
||||
return i.headerOffset + i.parseOffset
|
||||
}
|
||||
|
||||
// MakeIPv6PayloadIterator returns an iterator over the IPv6 payload containing
|
||||
// extension headers, or a raw payload if the payload cannot be parsed. The
|
||||
// iterator takes ownership of the payload.
|
||||
func MakeIPv6PayloadIterator(nextHdrIdentifier IPv6ExtensionHeaderIdentifier, payload buffer.Buffer) IPv6PayloadIterator {
|
||||
return IPv6PayloadIterator{
|
||||
nextHdrIdentifier: nextHdrIdentifier,
|
||||
payload: payload,
|
||||
nextOffset: IPv6FixedHeaderSize,
|
||||
}
|
||||
}
|
||||
|
||||
// Release frees the resources owned by the iterator.
|
||||
func (i *IPv6PayloadIterator) Release() {
|
||||
i.payload.Release()
|
||||
}
|
||||
|
||||
// AsRawHeader returns the remaining payload of i as a raw header and
|
||||
// optionally consumes the iterator.
|
||||
//
|
||||
// If consume is true, calls to Next after calling AsRawHeader on i will
|
||||
// indicate that the iterator is done. The returned header takes ownership of
|
||||
// its payload.
|
||||
func (i *IPv6PayloadIterator) AsRawHeader(consume bool) IPv6RawPayloadHeader {
|
||||
identifier := i.nextHdrIdentifier
|
||||
|
||||
var buf buffer.Buffer
|
||||
if consume {
|
||||
// Since we consume the iterator, we return the payload as is.
|
||||
buf = i.payload
|
||||
|
||||
// Mark i as done, but keep track of where we were for error reporting.
|
||||
*i = IPv6PayloadIterator{
|
||||
nextHdrIdentifier: IPv6NoNextHeaderIdentifier,
|
||||
headerOffset: i.headerOffset,
|
||||
nextOffset: i.nextOffset,
|
||||
}
|
||||
} else {
|
||||
buf = i.payload.Clone()
|
||||
}
|
||||
|
||||
return IPv6RawPayloadHeader{Identifier: identifier, Buf: buf}
|
||||
}
|
||||
|
||||
// Next returns the next item in the payload.
|
||||
//
|
||||
// If the next item is not a known IPv6 extension header, IPv6RawPayloadHeader
|
||||
// will be returned with the remaining bytes and next header identifier.
|
||||
//
|
||||
// The return is of the format (header, done, error). done will be true when
|
||||
// Next is unable to return anything because the iterator has reached the end of
|
||||
// the payload, or an error occurred.
|
||||
func (i *IPv6PayloadIterator) Next() (IPv6PayloadHeader, bool, error) {
|
||||
i.headerOffset = i.nextOffset
|
||||
i.parseOffset = 0
|
||||
// We could be forced to return i as a raw header when the previous header was
|
||||
// a fragment extension header as the data following the fragment extension
|
||||
// header may not be complete.
|
||||
if i.forceRaw {
|
||||
return i.AsRawHeader(true /* consume */), false, nil
|
||||
}
|
||||
|
||||
// Is the header we are parsing a known extension header?
|
||||
switch i.nextHdrIdentifier {
|
||||
case IPv6HopByHopOptionsExtHdrIdentifier:
|
||||
nextHdrIdentifier, view, err := i.nextHeaderData(false /* fragmentHdr */, nil)
|
||||
if err != nil {
|
||||
return nil, true, err
|
||||
}
|
||||
|
||||
i.nextHdrIdentifier = nextHdrIdentifier
|
||||
return IPv6HopByHopOptionsExtHdr{ipv6OptionsExtHdr{view}}, false, nil
|
||||
case IPv6RoutingExtHdrIdentifier:
|
||||
nextHdrIdentifier, view, err := i.nextHeaderData(false /* fragmentHdr */, nil)
|
||||
if err != nil {
|
||||
return nil, true, err
|
||||
}
|
||||
|
||||
i.nextHdrIdentifier = nextHdrIdentifier
|
||||
return IPv6RoutingExtHdr{view}, false, nil
|
||||
case IPv6FragmentExtHdrIdentifier:
|
||||
var data [6]byte
|
||||
// We ignore the returned bytes because we know the fragment extension
|
||||
// header specific data will fit in data.
|
||||
nextHdrIdentifier, _, err := i.nextHeaderData(true /* fragmentHdr */, data[:])
|
||||
if err != nil {
|
||||
return nil, true, err
|
||||
}
|
||||
|
||||
fragmentExtHdr := IPv6FragmentExtHdr(data)
|
||||
|
||||
// If the packet is not the first fragment, do not attempt to parse anything
|
||||
// after the fragment extension header as the payload following the fragment
|
||||
// extension header should not contain any headers; the first fragment must
|
||||
// hold all the headers up to and including any upper layer headers, as per
|
||||
// RFC 8200 section 4.5.
|
||||
if fragmentExtHdr.FragmentOffset() != 0 {
|
||||
i.forceRaw = true
|
||||
}
|
||||
|
||||
i.nextHdrIdentifier = nextHdrIdentifier
|
||||
return fragmentExtHdr, false, nil
|
||||
case IPv6DestinationOptionsExtHdrIdentifier:
|
||||
nextHdrIdentifier, view, err := i.nextHeaderData(false /* fragmentHdr */, nil)
|
||||
if err != nil {
|
||||
return nil, true, err
|
||||
}
|
||||
|
||||
i.nextHdrIdentifier = nextHdrIdentifier
|
||||
return IPv6DestinationOptionsExtHdr{ipv6OptionsExtHdr{view}}, false, nil
|
||||
case IPv6NoNextHeaderIdentifier:
|
||||
// This indicates the end of the IPv6 payload.
|
||||
return nil, true, nil
|
||||
|
||||
default:
|
||||
// The header we are parsing is not a known extension header. Return the
|
||||
// raw payload.
|
||||
return i.AsRawHeader(true /* consume */), false, nil
|
||||
}
|
||||
}
|
||||
|
||||
// NextHeaderIdentifier returns the identifier of the header next returned by
|
||||
// it.Next().
|
||||
func (i *IPv6PayloadIterator) NextHeaderIdentifier() IPv6ExtensionHeaderIdentifier {
|
||||
return i.nextHdrIdentifier
|
||||
}
|
||||
|
||||
// nextHeaderData returns the extension header's Next Header field and raw data.
|
||||
//
|
||||
// fragmentHdr indicates that the extension header being parsed is the Fragment
|
||||
// extension header so the Length field should be ignored as it is Reserved
|
||||
// for the Fragment extension header.
|
||||
//
|
||||
// If bytes is not nil, extension header specific data will be read into bytes
|
||||
// if it has enough capacity. If bytes is provided but does not have enough
|
||||
// capacity for the data, nextHeaderData will panic.
|
||||
func (i *IPv6PayloadIterator) nextHeaderData(fragmentHdr bool, bytes []byte) (IPv6ExtensionHeaderIdentifier, *buffer.View, error) {
|
||||
// We ignore the number of bytes read because we know we will only ever read
|
||||
// at max 1 bytes since rune has a length of 1. If we read 0 bytes, the Read
|
||||
// would return io.EOF to indicate that io.Reader has reached the end of the
|
||||
// payload.
|
||||
rdr := i.payload.AsBufferReader()
|
||||
nextHdrIdentifier, err := rdr.ReadByte()
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("error when reading the Next Header field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
|
||||
}
|
||||
i.parseOffset++
|
||||
|
||||
var length uint8
|
||||
length, err = rdr.ReadByte()
|
||||
if err != nil {
|
||||
if fragmentHdr {
|
||||
return 0, nil, fmt.Errorf("error when reading the Length field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
|
||||
}
|
||||
|
||||
return 0, nil, fmt.Errorf("error when reading the Reserved field for extension header with id = %d: %w", i.nextHdrIdentifier, err)
|
||||
}
|
||||
if fragmentHdr {
|
||||
length = 0
|
||||
}
|
||||
|
||||
// Make parseOffset point to the first byte of the Extension Header
|
||||
// specific data.
|
||||
i.parseOffset++
|
||||
|
||||
// length is in 8 byte chunks but doesn't include the first one.
|
||||
// See RFC 8200 for each header type, sections 4.3-4.6 and the requirement
|
||||
// in section 4.8 for new extension headers at the top of page 24.
|
||||
// [ Hdr Ext Len ] ... Length of the Destination Options header in 8-octet
|
||||
// units, not including the first 8 octets.
|
||||
i.nextOffset += uint32((length + 1) * ipv6ExtHdrLenBytesPerUnit)
|
||||
|
||||
bytesLen := int(length)*ipv6ExtHdrLenBytesPerUnit + ipv6ExtHdrLenBytesExcluded
|
||||
if fragmentHdr {
|
||||
if n := len(bytes); n < bytesLen {
|
||||
panic(fmt.Sprintf("bytes only has space for %d bytes but need space for %d bytes (length = %d) for extension header with id = %d", n, bytesLen, length, i.nextHdrIdentifier))
|
||||
}
|
||||
if n, err := io.ReadFull(&rdr, bytes); err != nil {
|
||||
return 0, nil, fmt.Errorf("read %d out of %d extension header data bytes (length = %d) for header with id = %d: %w", n, bytesLen, length, i.nextHdrIdentifier, err)
|
||||
}
|
||||
return IPv6ExtensionHeaderIdentifier(nextHdrIdentifier), nil, nil
|
||||
}
|
||||
v := buffer.NewView(bytesLen)
|
||||
if n, err := io.CopyN(v, &rdr, int64(bytesLen)); err != nil {
|
||||
if err == io.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
v.Release()
|
||||
return 0, nil, fmt.Errorf("read %d out of %d extension header data bytes (length = %d) for header with id = %d: %w", n, bytesLen, length, i.nextHdrIdentifier, err)
|
||||
}
|
||||
return IPv6ExtensionHeaderIdentifier(nextHdrIdentifier), v, nil
|
||||
}
|
||||
|
||||
// IPv6SerializableExtHdr provides serialization for IPv6 extension
|
||||
// headers.
|
||||
type IPv6SerializableExtHdr interface {
|
||||
|
||||
@@ -351,14 +351,18 @@ func (b TCP) SetUrgentPointer(urgentPointer uint16) {
|
||||
// and the checksum of the segment data.
|
||||
func (b TCP) CalculateChecksum(partialChecksum uint16) uint16 {
|
||||
// Calculate the rest of the checksum.
|
||||
return checksum.Checksum(b[:b.DataOffset()], partialChecksum)
|
||||
// return checksum.Checksum(b[:b.DataOffset()], partialChecksum)
|
||||
xsum := checksum.Checksum(b[:TCPChecksumOffset], partialChecksum)
|
||||
xsum = checksum.Checksum(b[TCPChecksumOffset+2:b.DataOffset()], xsum)
|
||||
return xsum
|
||||
}
|
||||
|
||||
// IsChecksumValid returns true iff the TCP header's checksum is valid.
|
||||
func (b TCP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum, payloadLength uint16) bool {
|
||||
xsum := PseudoHeaderChecksum(TCPProtocolNumber, src.AsSlice(), dst.AsSlice(), uint16(b.DataOffset())+payloadLength)
|
||||
xsum = checksum.Combine(xsum, payloadChecksum)
|
||||
return b.CalculateChecksum(xsum) == 0xffff
|
||||
// return b.CalculateChecksum(xsum) == 0xffff
|
||||
return checksum.Checksum(b[:b.DataOffset()], xsum) == 0xffff
|
||||
}
|
||||
|
||||
// Options returns a slice that holds the unparsed TCP options in the segment.
|
||||
|
||||
@@ -113,15 +113,18 @@ func (b UDP) SetLength(length uint16) {
|
||||
// CalculateChecksum calculates the checksum of the UDP packet, given the
|
||||
// checksum of the network-layer pseudo-header and the checksum of the payload.
|
||||
func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 {
|
||||
// Calculate the rest of the checksum.
|
||||
return checksum.Checksum(b[:UDPMinimumSize], partialChecksum)
|
||||
// Calculate the rest of the checksum.\
|
||||
// return checksum.Checksum(b[:UDPMinimumSize], partialChecksum)
|
||||
xsum := checksum.Checksum(b[:udpChecksum], partialChecksum)
|
||||
xsum = checksum.Checksum(b[udpChecksum+2:UDPMinimumSize], xsum)
|
||||
return xsum
|
||||
}
|
||||
|
||||
// IsChecksumValid returns true iff the UDP header's checksum is valid.
|
||||
func (b UDP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum uint16) bool {
|
||||
xsum := PseudoHeaderChecksum(UDPProtocolNumber, dst.AsSlice(), src.AsSlice(), b.Length())
|
||||
xsum = checksum.Combine(xsum, payloadChecksum)
|
||||
return b.CalculateChecksum(xsum) == 0xffff
|
||||
return checksum.Checksum(b[:UDPMinimumSize], xsum) == 0xffff
|
||||
}
|
||||
|
||||
// Encode encodes all the fields of the UDP header.
|
||||
|
||||
188
internal/rawfile_darwin/rawfile.go
Normal file
188
internal/rawfile_darwin/rawfile.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package rawfile
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// SizeofIovec is the size of a unix.Iovec in bytes.
|
||||
const SizeofIovec = unsafe.Sizeof(unix.Iovec{})
|
||||
|
||||
// MaxIovs is UIO_MAXIOV, the maximum number of iovecs that may be passed to a
|
||||
// host system call in a single array.
|
||||
const MaxIovs = 1024
|
||||
|
||||
// IovecFromBytes returns a unix.Iovec representing bs.
|
||||
//
|
||||
// Preconditions: len(bs) > 0.
|
||||
func IovecFromBytes(bs []byte) unix.Iovec {
|
||||
iov := unix.Iovec{
|
||||
Base: &bs[0],
|
||||
}
|
||||
iov.SetLen(len(bs))
|
||||
return iov
|
||||
}
|
||||
|
||||
func bytesFromIovec(iov unix.Iovec) (bs []byte) {
|
||||
sh := (*reflect.SliceHeader)(unsafe.Pointer(&bs))
|
||||
sh.Data = uintptr(unsafe.Pointer(iov.Base))
|
||||
sh.Len = int(iov.Len)
|
||||
sh.Cap = int(iov.Len)
|
||||
return
|
||||
}
|
||||
|
||||
// AppendIovecFromBytes returns append(iovs, IovecFromBytes(bs)). If len(bs) ==
|
||||
// 0, AppendIovecFromBytes returns iovs without modification. If len(iovs) >=
|
||||
// max, AppendIovecFromBytes replaces the final iovec in iovs with one that
|
||||
// also includes the contents of bs. Note that this implies that
|
||||
// AppendIovecFromBytes is only usable when the returned iovec slice is used as
|
||||
// the source of a write.
|
||||
func AppendIovecFromBytes(iovs []unix.Iovec, bs []byte, max int) []unix.Iovec {
|
||||
if len(bs) == 0 {
|
||||
return iovs
|
||||
}
|
||||
if len(iovs) < max {
|
||||
return append(iovs, IovecFromBytes(bs))
|
||||
}
|
||||
iovs[len(iovs)-1] = IovecFromBytes(append(bytesFromIovec(iovs[len(iovs)-1]), bs...))
|
||||
return iovs
|
||||
}
|
||||
|
||||
// MsgHdrX is the per-message element of the arrays passed to the
// SYS_SENDMSG_X and SYS_RECVMSG_X system calls by NonBlockingSendMMsg and
// BlockingRecvMMsgUntilStopped.
//
// NOTE(review): presumably this must match the kernel's layout for those
// calls — confirm against the Darwin headers before changing any field.
type MsgHdrX struct {
	// Msg is the regular message header.
	Msg unix.Msghdr
	// DataLen is the extra length field that follows the message header.
	DataLen uint32
}
|
||||
|
||||
func NonBlockingSendMMsg(fd int, msgHdrs []MsgHdrX) (int, unix.Errno) {
|
||||
n, _, e := unix.RawSyscall6(unix.SYS_SENDMSG_X, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), unix.MSG_DONTWAIT, 0, 0)
|
||||
return int(n), e
|
||||
}
|
||||
|
||||
const SizeofMsgHdrX = unsafe.Sizeof(MsgHdrX{})
|
||||
|
||||
// NonBlockingWriteIovec writes iovec to a file descriptor in a single unix.
|
||||
// It fails if partial data is written.
|
||||
func NonBlockingWriteIovec(fd int, iovec []unix.Iovec) unix.Errno {
|
||||
iovecLen := uintptr(len(iovec))
|
||||
_, _, e := unix.RawSyscall(unix.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), iovecLen)
|
||||
return e
|
||||
}
|
||||
|
||||
func BlockingReadvUntilStopped(efd int, fd int, iovecs []unix.Iovec) (int, unix.Errno) {
|
||||
for {
|
||||
n, _, e := unix.RawSyscall(unix.SYS_READV, uintptr(fd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(len(iovecs)))
|
||||
if e == 0 {
|
||||
return int(n), 0
|
||||
}
|
||||
if e != 0 && e != unix.EWOULDBLOCK {
|
||||
return 0, e
|
||||
}
|
||||
stopped, e := BlockingPollUntilStopped(efd, fd, unix.POLLIN)
|
||||
if stopped {
|
||||
return -1, e
|
||||
}
|
||||
if e != 0 && e != unix.EINTR {
|
||||
return 0, e
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BlockingRecvMMsgUntilStopped(efd int, fd int, msgHdrs []MsgHdrX) (int, unix.Errno) {
|
||||
for {
|
||||
n, _, e := unix.RawSyscall6(unix.SYS_RECVMSG_X, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), unix.MSG_DONTWAIT, 0, 0)
|
||||
if e == 0 {
|
||||
return int(n), e
|
||||
}
|
||||
|
||||
if e != 0 && e != unix.EWOULDBLOCK {
|
||||
return 0, e
|
||||
}
|
||||
|
||||
stopped, e := BlockingPollUntilStopped(efd, fd, unix.POLLIN)
|
||||
if stopped {
|
||||
return -1, e
|
||||
}
|
||||
if e != 0 && e != unix.EINTR {
|
||||
return 0, e
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BlockingPollUntilStopped(efd int, fd int, events int16) (bool, unix.Errno) {
|
||||
// Create kqueue
|
||||
kq, err := unix.Kqueue()
|
||||
if err != nil {
|
||||
return false, unix.Errno(err.(unix.Errno))
|
||||
}
|
||||
defer unix.Close(kq)
|
||||
|
||||
// Prepare kevents for registration
|
||||
var kevents []unix.Kevent_t
|
||||
|
||||
// Always monitor efd for read events
|
||||
kevents = append(kevents, unix.Kevent_t{
|
||||
Ident: uint64(efd),
|
||||
Filter: unix.EVFILT_READ,
|
||||
Flags: unix.EV_ADD | unix.EV_ENABLE,
|
||||
})
|
||||
|
||||
// Monitor fd based on requested events
|
||||
// Convert poll events to kqueue filters
|
||||
if events&unix.POLLIN != 0 {
|
||||
kevents = append(kevents, unix.Kevent_t{
|
||||
Ident: uint64(fd),
|
||||
Filter: unix.EVFILT_READ,
|
||||
Flags: unix.EV_ADD | unix.EV_ENABLE,
|
||||
})
|
||||
}
|
||||
if events&unix.POLLOUT != 0 {
|
||||
kevents = append(kevents, unix.Kevent_t{
|
||||
Ident: uint64(fd),
|
||||
Filter: unix.EVFILT_WRITE,
|
||||
Flags: unix.EV_ADD | unix.EV_ENABLE,
|
||||
})
|
||||
}
|
||||
|
||||
// Register events
|
||||
_, err = unix.Kevent(kq, kevents, nil, nil)
|
||||
if err != nil {
|
||||
return false, unix.Errno(err.(unix.Errno))
|
||||
}
|
||||
|
||||
// Wait for events (blocking)
|
||||
revents := make([]unix.Kevent_t, len(kevents))
|
||||
n, err := unix.Kevent(kq, nil, revents, nil)
|
||||
if err != nil {
|
||||
return false, unix.Errno(err.(unix.Errno))
|
||||
}
|
||||
|
||||
// Check results
|
||||
var efdHasData bool
|
||||
var errno unix.Errno
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
ev := &revents[i]
|
||||
|
||||
if int(ev.Ident) == efd && ev.Filter == unix.EVFILT_READ {
|
||||
efdHasData = true
|
||||
}
|
||||
|
||||
if int(ev.Ident) == fd {
|
||||
// Check for errors or EOF
|
||||
if ev.Flags&unix.EV_EOF != 0 {
|
||||
errno = unix.ECONNRESET
|
||||
} else if ev.Flags&unix.EV_ERROR != 0 {
|
||||
// Extract error from Data field
|
||||
if ev.Data != 0 {
|
||||
errno = unix.Errno(ev.Data)
|
||||
} else {
|
||||
errno = unix.ECONNRESET
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return efdHasData, errno
|
||||
}
|
||||
61
internal/stopfd_darwin/stopfd.go
Normal file
61
internal/stopfd_darwin/stopfd.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package stopfd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// StopFD holds the two ends of the pipe created by New. Stop writes to the
// write end and EFD exposes the read end.
type StopFD struct {
	// ReadFD is the read end of the pipe, or -1 when not open.
	ReadFD int
	// WriteFD is the write end of the pipe, or -1 when not open.
	WriteFD int
}
|
||||
|
||||
func New() (StopFD, error) {
|
||||
fds := make([]int, 2)
|
||||
err := unix.Pipe(fds)
|
||||
if err != nil {
|
||||
return StopFD{ReadFD: -1, WriteFD: -1}, fmt.Errorf("failed to create pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := unix.SetNonblock(fds[0], true); err != nil {
|
||||
unix.Close(fds[0])
|
||||
unix.Close(fds[1])
|
||||
return StopFD{ReadFD: -1, WriteFD: -1}, fmt.Errorf("failed to set read end non-blocking: %w", err)
|
||||
}
|
||||
|
||||
if err := unix.SetNonblock(fds[1], true); err != nil {
|
||||
unix.Close(fds[0])
|
||||
unix.Close(fds[1])
|
||||
return StopFD{ReadFD: -1, WriteFD: -1}, fmt.Errorf("failed to set write end non-blocking: %w", err)
|
||||
}
|
||||
|
||||
return StopFD{ReadFD: fds[0], WriteFD: fds[1]}, nil
|
||||
}
|
||||
|
||||
func (sf *StopFD) Stop() {
|
||||
signal := []byte{1}
|
||||
if n, err := unix.Write(sf.WriteFD, signal); n != len(signal) || err != nil {
|
||||
panic(fmt.Sprintf("write(WriteFD) = (%d, %s), want (%d, nil)", n, err, len(signal)))
|
||||
}
|
||||
}
|
||||
|
||||
func (sf *StopFD) Close() error {
|
||||
var err1, err2 error
|
||||
if sf.ReadFD != -1 {
|
||||
err1 = unix.Close(sf.ReadFD)
|
||||
sf.ReadFD = -1
|
||||
}
|
||||
if sf.WriteFD != -1 {
|
||||
err2 = unix.Close(sf.WriteFD)
|
||||
sf.WriteFD = -1
|
||||
}
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
return err2
|
||||
}
|
||||
|
||||
func (sf *StopFD) EFD() int {
|
||||
return sf.ReadFD
|
||||
}
|
||||
@@ -39,6 +39,10 @@ func closeAdapter(wintun *Adapter) {
|
||||
// deterministically. If it is set to nil, the GUID is chosen by the system at random,
|
||||
// and hence a new NLA entry is created for each new adapter.
|
||||
func CreateAdapter(name string, tunnelType string, requestedGUID *windows.GUID) (wintun *Adapter, err error) {
|
||||
err = procWintunCloseAdapter.Find()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var name16 *uint16
|
||||
name16, err = windows.UTF16PtrFromString(name)
|
||||
if err != nil {
|
||||
|
||||
@@ -30,6 +30,8 @@ type DefaultInterfaceMonitor interface {
|
||||
AndroidVPNEnabled() bool
|
||||
RegisterCallback(callback DefaultInterfaceUpdateCallback) *list.Element[DefaultInterfaceUpdateCallback]
|
||||
UnregisterCallback(element *list.Element[DefaultInterfaceUpdateCallback])
|
||||
RegisterMyInterface(interfaceName string)
|
||||
MyInterface() string
|
||||
}
|
||||
|
||||
type DefaultInterfaceMonitorOptions struct {
|
||||
|
||||
@@ -51,12 +51,11 @@ func (m *defaultInterfaceMonitor) checkUpdate() error {
|
||||
return err
|
||||
}
|
||||
|
||||
oldInterface := m.defaultInterface.Load()
|
||||
newInterface, err := m.interfaceFinder.ByIndex(link.Attrs().Index)
|
||||
if err != nil {
|
||||
return E.Cause(err, "find updated interface: ", link.Attrs().Name)
|
||||
}
|
||||
m.defaultInterface.Store(newInterface)
|
||||
oldInterface := m.defaultInterface.Swap(newInterface)
|
||||
if oldInterface != nil && oldInterface.Equals(*newInterface) && oldVPNEnabled == m.androidVPNEnabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -165,12 +165,11 @@ func (m *defaultInterfaceMonitor) checkUpdate() error {
|
||||
if defaultInterface == nil {
|
||||
return ErrNoRoute
|
||||
}
|
||||
oldInterface := m.defaultInterface.Load()
|
||||
newInterface, err := m.interfaceFinder.ByIndex(defaultInterface.Index)
|
||||
if err != nil {
|
||||
return E.Cause(err, "find updated interface: ", defaultInterface.Name)
|
||||
}
|
||||
m.defaultInterface.Store(newInterface)
|
||||
oldInterface := m.defaultInterface.Swap(newInterface)
|
||||
if oldInterface != nil && oldInterface.Equals(*newInterface) {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ type networkUpdateMonitor struct {
|
||||
var ErrNetlinkBanned = E.New(
|
||||
"netlink socket in Android is banned by Google, " +
|
||||
"use the root or system (ADB) user to run sing-box, " +
|
||||
"or switch to the sing-box Adnroid graphical interface client",
|
||||
"or switch to the sing-box Android graphical interface client",
|
||||
)
|
||||
|
||||
func NewNetworkUpdateMonitor(logger logger.Logger) (NetworkUpdateMonitor, error) {
|
||||
|
||||
@@ -25,12 +25,11 @@ func (m *defaultInterfaceMonitor) checkUpdate() error {
|
||||
return err
|
||||
}
|
||||
|
||||
oldInterface := m.defaultInterface.Load()
|
||||
newInterface, err := m.interfaceFinder.ByIndex(link.Attrs().Index)
|
||||
if err != nil {
|
||||
return E.Cause(err, "find updated interface: ", link.Attrs().Name)
|
||||
}
|
||||
m.defaultInterface.Store(newInterface)
|
||||
oldInterface := m.defaultInterface.Swap(newInterface)
|
||||
if oldInterface != nil && oldInterface.Equals(*newInterface) {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -5,9 +5,9 @@ package tun
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
"github.com/sagernet/sing/common/control"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
@@ -42,11 +42,12 @@ type defaultInterfaceMonitor struct {
|
||||
androidVPNEnabled bool
|
||||
noRoute bool
|
||||
networkMonitor NetworkUpdateMonitor
|
||||
logger logger.Logger
|
||||
checkUpdateTimer *time.Timer
|
||||
element *list.Element[NetworkUpdateCallback]
|
||||
access sync.Mutex
|
||||
callbacks list.List[DefaultInterfaceUpdateCallback]
|
||||
logger logger.Logger
|
||||
myInterface string
|
||||
}
|
||||
|
||||
func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor, logger logger.Logger, options DefaultInterfaceMonitorOptions) (DefaultInterfaceMonitor, error) {
|
||||
@@ -132,3 +133,15 @@ func (m *defaultInterfaceMonitor) emit(defaultInterface *control.Interface, flag
|
||||
callback(defaultInterface, flags)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *defaultInterfaceMonitor) RegisterMyInterface(interfaceName string) {
|
||||
m.access.Lock()
|
||||
defer m.access.Unlock()
|
||||
m.myInterface = interfaceName
|
||||
}
|
||||
|
||||
func (m *defaultInterfaceMonitor) MyInterface() string {
|
||||
m.access.Lock()
|
||||
defer m.access.Unlock()
|
||||
return m.myInterface
|
||||
}
|
||||
|
||||
@@ -102,12 +102,11 @@ func (m *defaultInterfaceMonitor) checkUpdate() error {
|
||||
return ErrNoRoute
|
||||
}
|
||||
|
||||
oldInterface := m.defaultInterface.Load()
|
||||
newInterface, err := m.interfaceFinder.ByIndex(index)
|
||||
if err != nil {
|
||||
return E.Cause(err, "find updated interface: ", alias)
|
||||
}
|
||||
m.defaultInterface.Store(newInterface)
|
||||
oldInterface := m.defaultInterface.Swap(newInterface)
|
||||
if oldInterface != nil && oldInterface.Equals(*newInterface) {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,12 +3,9 @@
|
||||
package tun
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
F "github.com/sagernet/sing/common/format"
|
||||
)
|
||||
@@ -30,10 +27,7 @@ func (r *autoRedirect) setupIPTables() error {
|
||||
}
|
||||
|
||||
func (r *autoRedirect) setupIPTablesForFamily(iptablesPath string) error {
|
||||
tableNameInput := r.tableName + "-input"
|
||||
tableNameForward := r.tableName + "-forward"
|
||||
tableNameOutput := r.tableName + "-output"
|
||||
tableNamePreRouteing := r.tableName + "-prerouting"
|
||||
redirectPort := r.redirectPort()
|
||||
// OUTPUT
|
||||
err := r.runShell(iptablesPath, "-t nat -N", tableNameOutput)
|
||||
@@ -50,184 +44,6 @@ func (r *autoRedirect) setupIPTablesForFamily(iptablesPath string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if runtime.GOOS == "android" {
|
||||
return nil
|
||||
}
|
||||
// INPUT
|
||||
err = r.runShell(iptablesPath, "-N", tableNameInput)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.runShell(iptablesPath, "-A", tableNameInput,
|
||||
"-i", r.tunOptions.Name, "-j", "ACCEPT")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.runShell(iptablesPath, "-A", tableNameInput,
|
||||
"-o", r.tunOptions.Name, "-j", "ACCEPT")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.runShell(iptablesPath, "-I INPUT -j", tableNameInput)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// FORWARD
|
||||
err = r.runShell(iptablesPath, "-N", tableNameForward)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.runShell(iptablesPath, "-A", tableNameForward,
|
||||
"-i", r.tunOptions.Name, "-j", "ACCEPT")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.runShell(iptablesPath, "-A", tableNameForward,
|
||||
"-o", r.tunOptions.Name, "-j", "ACCEPT")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.runShell(iptablesPath, "-I FORWARD -j", tableNameForward)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// PREROUTING
|
||||
err = r.runShell(iptablesPath, "-t nat -N", tableNamePreRouteing)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var (
|
||||
routeAddress []netip.Prefix
|
||||
routeExcludeAddress []netip.Prefix
|
||||
)
|
||||
if iptablesPath == r.iptablesPath {
|
||||
routeAddress = r.tunOptions.Inet4RouteAddress
|
||||
routeExcludeAddress = r.tunOptions.Inet4RouteExcludeAddress
|
||||
} else {
|
||||
routeAddress = r.tunOptions.Inet6RouteAddress
|
||||
routeExcludeAddress = r.tunOptions.Inet6RouteExcludeAddress
|
||||
}
|
||||
if len(routeAddress) > 0 && (len(r.tunOptions.IncludeInterface) > 0 || len(r.tunOptions.IncludeUID) > 0) {
|
||||
return E.New("`*_route_address` is conflict with `include_interface` or `include_uid`")
|
||||
}
|
||||
err = r.runShell(iptablesPath, "-t nat -A", tableNamePreRouteing,
|
||||
"-i", r.tunOptions.Name, "-j RETURN")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, address := range routeExcludeAddress {
|
||||
err = r.runShell(iptablesPath, "-t nat -A", tableNamePreRouteing,
|
||||
"-d", address.String(), "-j RETURN")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for _, name := range r.tunOptions.ExcludeInterface {
|
||||
err = r.runShell(iptablesPath, "-t nat -A", tableNamePreRouteing,
|
||||
"-i", name, "-j RETURN")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for _, uid := range r.tunOptions.ExcludeUID {
|
||||
err = r.runShell(iptablesPath, "-t nat -A", tableNamePreRouteing,
|
||||
"-m owner --uid-owner", uid, "-j RETURN")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if !r.tunOptions.EXP_DisableDNSHijack {
|
||||
dnsServer := common.Find(r.tunOptions.DNSServers, func(it netip.Addr) bool {
|
||||
return it.Is4() == (iptablesPath == r.iptablesPath)
|
||||
})
|
||||
if !dnsServer.IsValid() {
|
||||
if iptablesPath == r.iptablesPath {
|
||||
if HasNextAddress(r.tunOptions.Inet4Address[0], 1) {
|
||||
dnsServer = r.tunOptions.Inet4Address[0].Addr().Next()
|
||||
}
|
||||
} else {
|
||||
if HasNextAddress(r.tunOptions.Inet6Address[0], 1) {
|
||||
dnsServer = r.tunOptions.Inet6Address[0].Addr().Next()
|
||||
}
|
||||
}
|
||||
}
|
||||
if dnsServer.IsValid() {
|
||||
if len(routeAddress) > 0 {
|
||||
for _, address := range routeAddress {
|
||||
err = r.runShell(iptablesPath, "-t nat -A", tableNamePreRouteing,
|
||||
"-d", address.String(), "-p udp --dport 53 -j DNAT --to", dnsServer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else if len(r.tunOptions.IncludeInterface) > 0 || len(r.tunOptions.IncludeUID) > 0 {
|
||||
for _, name := range r.tunOptions.IncludeInterface {
|
||||
err = r.runShell(iptablesPath, "-t nat -A", tableNamePreRouteing,
|
||||
"-i", name, "-p udp --dport 53 -j DNAT --to", dnsServer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for _, uidRange := range r.tunOptions.IncludeUID {
|
||||
for uid := uidRange.Start; uid <= uidRange.End; uid++ {
|
||||
err = r.runShell(iptablesPath, "-t nat -A", tableNamePreRouteing,
|
||||
"-m owner --uid-owner", uid, "-p udp --dport 53 -j DNAT --to", dnsServer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
err = r.runShell(iptablesPath, "-t nat -A", tableNamePreRouteing,
|
||||
"-p udp --dport 53 -j DNAT --to", dnsServer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = r.runShell(iptablesPath, "-t nat -A", tableNamePreRouteing, "-m addrtype --dst-type LOCAL -j RETURN")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(routeAddress) > 0 {
|
||||
for _, address := range routeAddress {
|
||||
err = r.runShell(iptablesPath, "-t nat -A", tableNamePreRouteing,
|
||||
"-d", address.String(), "-p tcp -j REDIRECT --to-ports", redirectPort)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else if len(r.tunOptions.IncludeInterface) > 0 || len(r.tunOptions.IncludeUID) > 0 {
|
||||
for _, name := range r.tunOptions.IncludeInterface {
|
||||
err = r.runShell(iptablesPath, "-t nat -A", tableNamePreRouteing,
|
||||
"-i", name, "-p tcp -j REDIRECT --to-ports", redirectPort)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for _, uidRange := range r.tunOptions.IncludeUID {
|
||||
for uid := uidRange.Start; uid <= uidRange.End; uid++ {
|
||||
err = r.runShell(iptablesPath, "-t nat -A", tableNamePreRouteing,
|
||||
"-m owner --uid-owner", uid, "-p tcp -j REDIRECT --to-ports", redirectPort)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
err = r.runShell(iptablesPath, "-t nat -A", tableNamePreRouteing,
|
||||
"-p tcp -j REDIRECT --to-ports", redirectPort)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
err = r.runShell(iptablesPath, "-t nat -I PREROUTING -j", tableNamePreRouteing)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -241,29 +57,11 @@ func (r *autoRedirect) cleanupIPTables() {
|
||||
}
|
||||
|
||||
func (r *autoRedirect) cleanupIPTablesForFamily(iptablesPath string) {
|
||||
tableNameInput := r.tableName + "-input"
|
||||
tableNameOutput := r.tableName + "-output"
|
||||
tableNameForward := r.tableName + "-forward"
|
||||
tableNamePreRouteing := r.tableName + "-prerouting"
|
||||
|
||||
_ = r.runShell(iptablesPath, "-t nat -D OUTPUT -j", tableNameOutput)
|
||||
_ = r.runShell(iptablesPath, "-t nat -F", tableNameOutput)
|
||||
_ = r.runShell(iptablesPath, "-t nat -X", tableNameOutput)
|
||||
if runtime.GOOS == "android" {
|
||||
return
|
||||
}
|
||||
|
||||
_ = r.runShell(iptablesPath, "-D INPUT -j", tableNameInput)
|
||||
_ = r.runShell(iptablesPath, "-F", tableNameInput)
|
||||
_ = r.runShell(iptablesPath, "-X", tableNameInput)
|
||||
|
||||
_ = r.runShell(iptablesPath, "-D FORWARD -j", tableNameForward)
|
||||
_ = r.runShell(iptablesPath, "-F", tableNameForward)
|
||||
_ = r.runShell(iptablesPath, "-X", tableNameForward)
|
||||
|
||||
_ = r.runShell(iptablesPath, "-t nat -D PREROUTING -j", tableNamePreRouteing)
|
||||
_ = r.runShell(iptablesPath, "-t nat -F", tableNamePreRouteing)
|
||||
_ = r.runShell(iptablesPath, "-t nat -X", tableNamePreRouteing)
|
||||
}
|
||||
|
||||
func (r *autoRedirect) runShell(commands ...any) error {
|
||||
|
||||
@@ -69,6 +69,7 @@ func (r *autoRedirect) Start() error {
|
||||
r.androidSu = true
|
||||
for _, suPath := range []string{
|
||||
"su",
|
||||
"/product/bin/su",
|
||||
"/system/bin/su",
|
||||
} {
|
||||
r.suPath, err = exec.LookPath(suPath)
|
||||
@@ -83,9 +84,8 @@ func (r *autoRedirect) Start() error {
|
||||
} else {
|
||||
if r.useNFTables {
|
||||
err = r.initializeNFTables()
|
||||
if err != nil && err != os.ErrInvalid {
|
||||
r.useNFTables = false
|
||||
r.logger.Debug("missing nftables support: ", err)
|
||||
if err != nil {
|
||||
return E.Cause(err, "missing nftables support")
|
||||
}
|
||||
}
|
||||
if len(r.tunOptions.Inet4Address) > 0 {
|
||||
|
||||
@@ -46,6 +46,11 @@ func (r *autoRedirect) setupNFTables() error {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.nftablesCreateLoopbackAddressSets(nft, table)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
skipOutput := len(r.tunOptions.IncludeInterface) > 0 && !common.Contains(r.tunOptions.IncludeInterface, "lo") || common.Contains(r.tunOptions.ExcludeInterface, "lo")
|
||||
if !skipOutput {
|
||||
chainOutput := nft.AddChain(&nftables.Chain{
|
||||
@@ -61,10 +66,25 @@ func (r *autoRedirect) setupNFTables() error {
|
||||
return err
|
||||
}
|
||||
r.nftablesCreateUnreachable(nft, table, chainOutput)
|
||||
r.nftablesCreateRedirect(nft, table, chainOutput)
|
||||
|
||||
err = r.nftablesCreateRedirect(nft, table, chainOutput)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(r.tunOptions.Inet4LoopbackAddress) > 0 || len(r.tunOptions.Inet6LoopbackAddress) > 0 {
|
||||
chainOutputRoute := nft.AddChain(&nftables.Chain{
|
||||
Name: "output_route",
|
||||
Table: table,
|
||||
Hooknum: nftables.ChainHookOutput,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
Type: nftables.ChainTypeRoute,
|
||||
})
|
||||
err = r.nftablesCreateLoopbackReroute(nft, table, chainOutputRoute)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
chainOutputUDP := nft.AddChain(&nftables.Chain{
|
||||
Name: "output_udp",
|
||||
Name: "output_udp_icmp",
|
||||
Table: table,
|
||||
Hooknum: nftables.ChainHookOutput,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
@@ -77,7 +97,7 @@ func (r *autoRedirect) setupNFTables() error {
|
||||
r.nftablesCreateUnreachable(nft, table, chainOutputUDP)
|
||||
r.nftablesCreateMark(nft, table, chainOutputUDP)
|
||||
} else {
|
||||
r.nftablesCreateRedirect(nft, table, chainOutput, &expr.Meta{
|
||||
err = r.nftablesCreateRedirect(nft, table, chainOutput, &expr.Meta{
|
||||
Key: expr.MetaKeyOIFNAME,
|
||||
Register: 1,
|
||||
}, &expr.Cmp{
|
||||
@@ -85,6 +105,9 @@ func (r *autoRedirect) setupNFTables() error {
|
||||
Register: 1,
|
||||
Data: nftablesIfname(r.tunOptions.Name),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -100,22 +123,45 @@ func (r *autoRedirect) setupNFTables() error {
|
||||
return err
|
||||
}
|
||||
r.nftablesCreateUnreachable(nft, table, chainPreRouting)
|
||||
r.nftablesCreateRedirect(nft, table, chainPreRouting)
|
||||
r.nftablesCreateMark(nft, table, chainPreRouting)
|
||||
|
||||
err = r.nftablesCreateRedirect(nft, table, chainPreRouting)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r.tunOptions.AutoRedirectMarkMode {
|
||||
r.nftablesCreateMark(nft, table, chainPreRouting)
|
||||
if len(r.tunOptions.Inet4LoopbackAddress) > 0 || len(r.tunOptions.Inet6LoopbackAddress) > 0 {
|
||||
chainPreRoutingFilter := nft.AddChain(&nftables.Chain{
|
||||
Name: "prerouting_filter",
|
||||
Table: table,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityRef(*nftables.ChainPriorityNATDest + 1),
|
||||
Type: nftables.ChainTypeFilter,
|
||||
})
|
||||
err = r.nftablesCreateLoopbackReroute(nft, table, chainPreRoutingFilter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
chainPreRoutingUDP := nft.AddChain(&nftables.Chain{
|
||||
Name: "prerouting_udp",
|
||||
Name: "prerouting_udp_icmp",
|
||||
Table: table,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityRef(*nftables.ChainPriorityNATDest + 2),
|
||||
Type: nftables.ChainTypeFilter,
|
||||
})
|
||||
if r.enableIPv4 {
|
||||
nftablesCreateExcludeDestinationIPSet(nft, table, chainPreRoutingUDP, 5, "inet4_local_address_set", nftables.TableFamilyIPv4, false)
|
||||
ipProto := &nftables.Set{
|
||||
Table: table,
|
||||
Anonymous: true,
|
||||
Constant: true,
|
||||
KeyType: nftables.TypeInetProto,
|
||||
}
|
||||
if r.enableIPv6 {
|
||||
nftablesCreateExcludeDestinationIPSet(nft, table, chainPreRoutingUDP, 6, "inet6_local_address_set", nftables.TableFamilyIPv6, false)
|
||||
err = nft.AddSet(ipProto, []nftables.SetElement{
|
||||
{Key: []byte{unix.IPPROTO_UDP}},
|
||||
{Key: []byte{unix.IPPROTO_ICMP}},
|
||||
{Key: []byte{unix.IPPROTO_ICMPV6}},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
@@ -125,10 +171,48 @@ func (r *autoRedirect) setupNFTables() error {
|
||||
Key: expr.MetaKeyL4PROTO,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetID: ipProto.ID,
|
||||
SetName: ipProto.Name,
|
||||
Invert: true,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictReturn,
|
||||
},
|
||||
},
|
||||
})
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chainPreRoutingUDP,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{unix.IPPROTO_UDP},
|
||||
Data: nftablesIfname(r.tunOptions.Name),
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictReturn,
|
||||
},
|
||||
},
|
||||
})
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chainPreRoutingUDP,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: nftablesIfname(r.tunOptions.Name),
|
||||
},
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeyMARK,
|
||||
@@ -147,6 +231,40 @@ func (r *autoRedirect) setupNFTables() error {
|
||||
&expr.Counter{},
|
||||
},
|
||||
})
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chainPreRoutingUDP,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectInputMark),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectOutputMark),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
&expr.Counter{},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
err = r.configureOpenWRTFirewall4(nft, false)
|
||||
@@ -224,7 +342,7 @@ func (r *autoRedirect) cleanupNFTables() {
|
||||
Name: r.tableName,
|
||||
Family: nftables.TableFamilyINet,
|
||||
})
|
||||
common.Must(r.configureOpenWRTFirewall4(nft, true))
|
||||
_ = r.configureOpenWRTFirewall4(nft, true)
|
||||
_ = nft.Flush()
|
||||
_ = nft.CloseLasting()
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/sagernet/nftables"
|
||||
"github.com/sagernet/nftables/expr"
|
||||
"github.com/sagernet/sing/common"
|
||||
|
||||
"go4.org/netipx"
|
||||
)
|
||||
@@ -21,6 +22,20 @@ func nftablesCreateExcludeDestinationIPSet(
|
||||
nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain,
|
||||
id uint32, name string, family nftables.TableFamily, invert bool,
|
||||
) {
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: append(
|
||||
nftablesCreateDestinationIPSetExprs(id, name, family, invert),
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictReturn,
|
||||
},
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
func nftablesCreateDestinationIPSetExprs(id uint32, name string, family nftables.TableFamily, invert bool) []expr.Any {
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyNFPROTO,
|
||||
@@ -53,22 +68,63 @@ func nftablesCreateExcludeDestinationIPSet(
|
||||
},
|
||||
)
|
||||
}
|
||||
exprs = append(exprs,
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetID: id,
|
||||
SetName: name,
|
||||
Invert: invert,
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictReturn,
|
||||
})
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: exprs,
|
||||
exprs = append(exprs, &expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetID: id,
|
||||
SetName: name,
|
||||
Invert: invert,
|
||||
})
|
||||
return exprs
|
||||
}
|
||||
|
||||
func nftablesCreateIPConst(
|
||||
nft *nftables.Conn, table *nftables.Table, id uint32, name string, family nftables.TableFamily, addressList []netip.Addr,
|
||||
) (*nftables.Set, error) {
|
||||
var keyType nftables.SetDatatype
|
||||
if family == nftables.TableFamilyIPv4 {
|
||||
keyType = nftables.TypeIPAddr
|
||||
} else {
|
||||
keyType = nftables.TypeIP6Addr
|
||||
}
|
||||
mySet := &nftables.Set{
|
||||
Table: table,
|
||||
ID: id,
|
||||
Name: name,
|
||||
KeyType: keyType,
|
||||
Constant: true,
|
||||
}
|
||||
if id == 0 {
|
||||
mySet.Anonymous = true
|
||||
}
|
||||
setElements := common.Map(addressList, func(addr netip.Addr) nftables.SetElement { return nftables.SetElement{Key: addr.AsSlice()} })
|
||||
if id == 0 {
|
||||
err := nft.AddSet(mySet, setElements)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return mySet, nil
|
||||
} else {
|
||||
err := nft.AddSet(mySet, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
for len(setElements) > 0 {
|
||||
toAdd := setElements
|
||||
if len(toAdd) > 1000 {
|
||||
toAdd = toAdd[:1000]
|
||||
}
|
||||
setElements = setElements[len(toAdd):]
|
||||
err := nft.SetAddElements(mySet, toAdd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = nft.Flush()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return mySet, nil
|
||||
}
|
||||
|
||||
func nftablesCreateIPSet(
|
||||
|
||||
@@ -103,10 +103,6 @@ func (r *autoRedirect) nftablesCreateLocalAddressSets(
|
||||
update = true
|
||||
}
|
||||
}
|
||||
localAddresses6 = common.Filter(localAddresses6, func(it netip.Prefix) bool {
|
||||
address := it.Addr()
|
||||
return address.IsLoopback() || address.IsGlobalUnicast() && !address.IsPrivate()
|
||||
})
|
||||
if len(lastAddresses) == 0 || update {
|
||||
_, err := nftablesCreateIPSet(nft, table, 6, "inet6_local_address_set", nftables.TableFamilyIPv6, nil, localAddresses6, false, update)
|
||||
if err != nil {
|
||||
@@ -117,8 +113,61 @@ func (r *autoRedirect) nftablesCreateLocalAddressSets(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *autoRedirect) nftablesCreateLoopbackAddressSets(
|
||||
nft *nftables.Conn, table *nftables.Table,
|
||||
) error {
|
||||
if r.enableIPv4 && len(r.tunOptions.Inet4LoopbackAddress) > 0 {
|
||||
_, err := nftablesCreateIPConst(nft, table, 7, "inet4_local_redirect_address_set", nftables.TableFamilyIPv4, r.tunOptions.Inet4LoopbackAddress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if r.enableIPv6 && len(r.tunOptions.Inet6LoopbackAddress) > 0 {
|
||||
_, err := nftablesCreateIPConst(nft, table, 8, "inet6_local_redirect_address_set", nftables.TableFamilyIPv6, r.tunOptions.Inet6LoopbackAddress)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *autoRedirect) nftablesCreateExcludeRules(nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain) error {
|
||||
if r.tunOptions.AutoRedirectMarkMode && chain.Hooknum == nftables.ChainHookOutput {
|
||||
if chain.Type == nftables.ChainTypeRoute {
|
||||
ipProto := &nftables.Set{
|
||||
Table: table,
|
||||
Anonymous: true,
|
||||
Constant: true,
|
||||
KeyType: nftables.TypeInetProto,
|
||||
}
|
||||
err := nft.AddSet(ipProto, []nftables.SetElement{
|
||||
{Key: []byte{unix.IPPROTO_UDP}},
|
||||
{Key: []byte{unix.IPPROTO_ICMP}},
|
||||
{Key: []byte{unix.IPPROTO_ICMPV6}},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyL4PROTO,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetID: ipProto.ID,
|
||||
SetName: ipProto.Name,
|
||||
Invert: true,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictReturn,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
@@ -138,8 +187,48 @@ func (r *autoRedirect) nftablesCreateExcludeRules(nft *nftables.Conn, table *nft
|
||||
},
|
||||
},
|
||||
})
|
||||
if chain.Type == nftables.ChainTypeRoute {
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectOutputMark),
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictReturn,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
if chain.Hooknum == nftables.ChainHookPrerouting {
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: nftablesIfname(r.tunOptions.Name),
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictReturn,
|
||||
},
|
||||
},
|
||||
})
|
||||
if len(r.tunOptions.IncludeInterface) > 0 {
|
||||
if len(r.tunOptions.IncludeInterface) > 1 {
|
||||
includeInterface := &nftables.Set{
|
||||
@@ -402,42 +491,19 @@ func (r *autoRedirect) nftablesCreateExcludeRules(nft *nftables.Conn, table *nft
|
||||
if !r.tunOptions.EXP_DisableDNSHijack && ((chain.Hooknum == nftables.ChainHookPrerouting && chain.Type == nftables.ChainTypeNAT) ||
|
||||
(r.tunOptions.AutoRedirectMarkMode && chain.Hooknum == nftables.ChainHookOutput && chain.Type == nftables.ChainTypeNAT)) {
|
||||
if r.enableIPv4 {
|
||||
err := r.nftablesCreateDNSHijackRulesForFamily(nft, table, chain, nftables.TableFamilyIPv4)
|
||||
err := r.nftablesCreateDNSHijackRulesForFamily(nft, table, chain, nftables.TableFamilyIPv4, 5, "inet4_local_address_set")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if r.enableIPv6 {
|
||||
err := r.nftablesCreateDNSHijackRulesForFamily(nft, table, chain, nftables.TableFamilyIPv6)
|
||||
err := r.nftablesCreateDNSHijackRulesForFamily(nft, table, chain, nftables.TableFamilyIPv6, 6, "inet6_local_address_set")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if r.tunOptions.AutoRedirectMarkMode &&
|
||||
((chain.Hooknum == nftables.ChainHookOutput && chain.Type == nftables.ChainTypeRoute) ||
|
||||
(chain.Hooknum == nftables.ChainHookPrerouting && chain.Type == nftables.ChainTypeFilter)) {
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyL4PROTO,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{unix.IPPROTO_UDP},
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictReturn,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if r.enableIPv4 {
|
||||
nftablesCreateExcludeDestinationIPSet(nft, table, chain, 5, "inet4_local_address_set", nftables.TableFamilyIPv4, false)
|
||||
}
|
||||
@@ -491,6 +557,9 @@ func (r *autoRedirect) nftablesCreateMark(nft *nftables.Conn, table *nftables.Ta
|
||||
SourceRegister: true,
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictReturn,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -498,62 +567,198 @@ func (r *autoRedirect) nftablesCreateMark(nft *nftables.Conn, table *nftables.Ta
|
||||
func (r *autoRedirect) nftablesCreateRedirect(
|
||||
nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain,
|
||||
exprs ...expr.Any,
|
||||
) {
|
||||
if r.enableIPv4 && !r.enableIPv6 {
|
||||
exprs = append(exprs,
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyNFPROTO,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{uint8(nftables.TableFamilyIPv4)},
|
||||
})
|
||||
} else if !r.enableIPv4 && r.enableIPv6 {
|
||||
exprs = append(exprs,
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyNFPROTO,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{uint8(nftables.TableFamilyIPv6)},
|
||||
})
|
||||
) error {
|
||||
exprsRedirect := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyL4PROTO,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{unix.IPPROTO_TCP},
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(r.redirectPort()),
|
||||
},
|
||||
&expr.Redir{
|
||||
RegisterProtoMin: 1,
|
||||
Flags: unix.NF_NAT_RANGE_PROTO_SPECIFIED,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictReturn,
|
||||
},
|
||||
}
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: append(exprs,
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyL4PROTO,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{unix.IPPROTO_TCP},
|
||||
},
|
||||
&expr.Counter{},
|
||||
if len(r.tunOptions.Inet4LoopbackAddress) == 0 && len(r.tunOptions.Inet6LoopbackAddress) == 0 {
|
||||
if r.enableIPv4 && !r.enableIPv6 {
|
||||
exprs = append(exprs,
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyNFPROTO,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{uint8(nftables.TableFamilyIPv4)},
|
||||
})
|
||||
} else if !r.enableIPv4 && r.enableIPv6 {
|
||||
exprs = append(exprs,
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyNFPROTO,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{uint8(nftables.TableFamilyIPv6)},
|
||||
})
|
||||
}
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: append(exprs, exprsRedirect...),
|
||||
})
|
||||
} else {
|
||||
if r.enableIPv4 {
|
||||
exprs4 := exprs
|
||||
if len(r.tunOptions.Inet4LoopbackAddress) > 0 {
|
||||
exprs4 = append(exprs4, nftablesCreateDestinationIPSetExprs(7, "inet4_local_redirect_address_set", nftables.TableFamilyIPv4, true)...)
|
||||
} else {
|
||||
exprs4 = append(exprs4, &expr.Meta{
|
||||
Key: expr.MetaKeyNFPROTO,
|
||||
Register: 1,
|
||||
}, &expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{uint8(nftables.TableFamilyIPv4)},
|
||||
})
|
||||
}
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: append(exprs4, exprsRedirect...),
|
||||
})
|
||||
}
|
||||
if r.enableIPv6 {
|
||||
exprs6 := exprs
|
||||
if len(r.tunOptions.Inet6LoopbackAddress) > 0 {
|
||||
exprs6 = append(exprs6, nftablesCreateDestinationIPSetExprs(8, "inet6_local_redirect_address_set", nftables.TableFamilyIPv6, true)...)
|
||||
} else {
|
||||
exprs6 = append(exprs6, &expr.Meta{
|
||||
Key: expr.MetaKeyNFPROTO,
|
||||
Register: 1,
|
||||
}, &expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{uint8(nftables.TableFamilyIPv6)},
|
||||
})
|
||||
}
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: append(exprs6, exprsRedirect...),
|
||||
})
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *autoRedirect) nftablesCreateLoopbackReroute(
|
||||
nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain,
|
||||
) error {
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyL4PROTO,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{unix.IPPROTO_TCP},
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectInputMark),
|
||||
},
|
||||
}
|
||||
var exprs4 []expr.Any
|
||||
if r.enableIPv4 && len(r.tunOptions.Inet4LoopbackAddress) > 0 {
|
||||
exprs4 = append(exprs, nftablesCreateDestinationIPSetExprs(7, "inet4_local_redirect_address_set", nftables.TableFamilyIPv4, false)...)
|
||||
}
|
||||
var exprs6 []expr.Any
|
||||
if r.enableIPv6 && len(r.tunOptions.Inet6LoopbackAddress) > 0 {
|
||||
exprs6 = append(exprs, nftablesCreateDestinationIPSetExprs(8, "inet6_local_redirect_address_set", nftables.TableFamilyIPv6, false)...)
|
||||
}
|
||||
var exprsCreateMark []expr.Any
|
||||
if chain.Hooknum == nftables.ChainHookPrerouting {
|
||||
exprsCreateMark = []expr.Any{
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(r.redirectPort()),
|
||||
Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectInputMark),
|
||||
},
|
||||
&expr.Redir{
|
||||
RegisterProtoMin: 1,
|
||||
Flags: unix.NF_NAT_RANGE_PROTO_SPECIFIED,
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictReturn,
|
||||
&expr.Counter{},
|
||||
}
|
||||
} else {
|
||||
exprsCreateMark = []expr.Any{
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectInputMark),
|
||||
},
|
||||
),
|
||||
})
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
&expr.Counter{},
|
||||
}
|
||||
}
|
||||
if len(exprs4) > 0 {
|
||||
exprs4 = append(exprs4, exprsCreateMark...)
|
||||
}
|
||||
if len(exprs6) > 0 {
|
||||
exprs6 = append(exprs6, exprsCreateMark...)
|
||||
}
|
||||
if len(exprs4) > 0 {
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: exprs4,
|
||||
})
|
||||
}
|
||||
if len(exprs6) > 0 {
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: exprs6,
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *autoRedirect) nftablesCreateDNSHijackRulesForFamily(
|
||||
nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain,
|
||||
family nftables.TableFamily,
|
||||
family nftables.TableFamily, setID uint32, setName string,
|
||||
) error {
|
||||
ipProto := &nftables.Set{
|
||||
Table: table,
|
||||
@@ -611,6 +816,33 @@ func (r *autoRedirect) nftablesCreateDNSHijackRulesForFamily(
|
||||
Data: nftablesIfname("lo"),
|
||||
},
|
||||
)
|
||||
} else {
|
||||
if family == nftables.TableFamilyIPv4 {
|
||||
exprs = append(exprs,
|
||||
&expr.Payload{
|
||||
OperationType: expr.PayloadLoad,
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
)
|
||||
} else {
|
||||
exprs = append(exprs,
|
||||
&expr.Payload{
|
||||
OperationType: expr.PayloadLoad,
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 8,
|
||||
Len: 16,
|
||||
},
|
||||
)
|
||||
}
|
||||
exprs = append(exprs, &expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetID: setID,
|
||||
SetName: setName,
|
||||
})
|
||||
}
|
||||
exprs = append(exprs,
|
||||
&expr.Meta{
|
||||
@@ -634,6 +866,7 @@ func (r *autoRedirect) nftablesCreateDNSHijackRulesForFamily(
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(53),
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: dnsServer.AsSlice(),
|
||||
@@ -643,7 +876,6 @@ func (r *autoRedirect) nftablesCreateDNSHijackRulesForFamily(
|
||||
Family: uint32(family),
|
||||
RegAddrMin: 1,
|
||||
},
|
||||
&expr.Counter{},
|
||||
)
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
@@ -679,9 +911,7 @@ func (r *autoRedirect) nftablesCreateUnreachable(
|
||||
Data: []byte{uint8(nfProto)},
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictDrop,
|
||||
},
|
||||
&expr.Reject{},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,101 +3,50 @@
|
||||
package tun
|
||||
|
||||
import (
|
||||
"github.com/sagernet/nftables"
|
||||
"github.com/sagernet/nftables/expr"
|
||||
"os"
|
||||
"os/exec"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
"github.com/sagernet/nftables"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/shell"
|
||||
)
|
||||
|
||||
func (r *autoRedirect) configureOpenWRTFirewall4(nft *nftables.Conn, cleanup bool) error {
|
||||
tableFW4, err := nft.ListTableOfFamily("fw4", nftables.TableFamilyINet)
|
||||
_, err := nft.ListTableOfFamily("fw4", nftables.TableFamilyINet)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if !cleanup {
|
||||
ruleIif := &nftables.Rule{
|
||||
Table: tableFW4,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: nftablesIfname(r.tunOptions.Name),
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
ruleOif := &nftables.Rule{
|
||||
Table: tableFW4,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyOIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: nftablesIfname(r.tunOptions.Name),
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
chainForward := &nftables.Chain{
|
||||
Name: "forward",
|
||||
}
|
||||
ruleIif.Chain = chainForward
|
||||
ruleOif.Chain = chainForward
|
||||
nft.InsertRule(ruleOif)
|
||||
nft.InsertRule(ruleIif)
|
||||
chainInput := &nftables.Chain{
|
||||
Name: "input",
|
||||
}
|
||||
ruleIif.Chain = chainInput
|
||||
ruleOif.Chain = chainInput
|
||||
nft.InsertRule(ruleOif)
|
||||
nft.InsertRule(ruleIif)
|
||||
fw4Path, err := exec.LookPath("fw4")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
for _, chainName := range []string{"input", "forward"} {
|
||||
var rules []*nftables.Rule
|
||||
rules, err = nft.GetRules(tableFW4, &nftables.Chain{
|
||||
Name: chainName,
|
||||
})
|
||||
rulePath := "/etc/nftables.d/0-" + r.tableName + "-auto-redirect.nft"
|
||||
if !cleanup {
|
||||
err = os.WriteFile(rulePath, []byte(`chain input {
|
||||
type filter hook input priority filter; policy accept;
|
||||
iifname "`+r.tunOptions.Name+`" counter accept comment "!`+r.tableName+`: Accept traffic from tun"
|
||||
oifname "`+r.tunOptions.Name+`" counter accept comment "!`+r.tableName+`: Accept traffic from tun"
|
||||
}
|
||||
chain forward {
|
||||
type filter hook forward priority filter; policy accept;
|
||||
iifname "`+r.tunOptions.Name+`" counter accept comment "!`+r.tableName+`: Accept traffic from tun"
|
||||
oifname "`+r.tunOptions.Name+`" counter accept comment "!`+r.tableName+`: Accept traffic from tun"
|
||||
}
|
||||
`), 0o644)
|
||||
if err != nil {
|
||||
return err
|
||||
return E.Cause(err, "write fw4 rules")
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if len(rule.Exprs) != 4 {
|
||||
continue
|
||||
}
|
||||
exprMeta, isMeta := rule.Exprs[0].(*expr.Meta)
|
||||
if !isMeta {
|
||||
continue
|
||||
}
|
||||
if exprMeta.Key != expr.MetaKeyIIFNAME && exprMeta.Key != expr.MetaKeyOIFNAME {
|
||||
continue
|
||||
}
|
||||
exprCmp, isCmp := rule.Exprs[1].(*expr.Cmp)
|
||||
if !isCmp {
|
||||
continue
|
||||
}
|
||||
if !slices.Equal(exprCmp.Data, nftablesIfname(r.tunOptions.Name)) {
|
||||
continue
|
||||
}
|
||||
err = nft.DelRule(rule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if _, err = os.Stat(rulePath); os.IsNotExist(err) {
|
||||
return nil
|
||||
} else {
|
||||
err = os.Remove(rulePath)
|
||||
if err != nil {
|
||||
return E.Cause(err, "clean fw4 rules")
|
||||
}
|
||||
}
|
||||
output, err := shell.Exec(fw4Path, "reload").Read()
|
||||
if err != nil {
|
||||
return E.Extend(E.Cause(err, "reload fw4 rules"), output)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -7,9 +7,9 @@ import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
"github.com/sagernet/sing/common/control"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
|
||||
@@ -26,19 +26,22 @@ const WithGVisor = true
|
||||
const DefaultNIC tcpip.NICID = 1
|
||||
|
||||
type GVisor struct {
|
||||
ctx context.Context
|
||||
tun GVisorTun
|
||||
udpTimeout time.Duration
|
||||
broadcastAddr netip.Addr
|
||||
handler Handler
|
||||
logger logger.Logger
|
||||
stack *stack.Stack
|
||||
endpoint stack.LinkEndpoint
|
||||
ctx context.Context
|
||||
tun GVisorTun
|
||||
inet4LoopbackAddress []netip.Addr
|
||||
inet6LoopbackAddress []netip.Addr
|
||||
udpTimeout time.Duration
|
||||
broadcastAddr netip.Addr
|
||||
handler Handler
|
||||
logger logger.Logger
|
||||
stack *stack.Stack
|
||||
endpoint stack.LinkEndpoint
|
||||
}
|
||||
|
||||
type GVisorTun interface {
|
||||
Tun
|
||||
NewEndpoint() (stack.LinkEndpoint, error)
|
||||
WritePacket(pkt *stack.PacketBuffer) (int, error)
|
||||
NewEndpoint() (stack.LinkEndpoint, stack.NICOptions, error)
|
||||
}
|
||||
|
||||
func NewGVisor(
|
||||
@@ -50,27 +53,29 @@ func NewGVisor(
|
||||
}
|
||||
|
||||
gStack := &GVisor{
|
||||
ctx: options.Context,
|
||||
tun: gTun,
|
||||
udpTimeout: options.UDPTimeout,
|
||||
broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address),
|
||||
handler: options.Handler,
|
||||
logger: options.Logger,
|
||||
ctx: options.Context,
|
||||
tun: gTun,
|
||||
inet4LoopbackAddress: options.TunOptions.Inet4LoopbackAddress,
|
||||
inet6LoopbackAddress: options.TunOptions.Inet6LoopbackAddress,
|
||||
udpTimeout: options.UDPTimeout,
|
||||
broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address),
|
||||
handler: options.Handler,
|
||||
logger: options.Logger,
|
||||
}
|
||||
return gStack, nil
|
||||
}
|
||||
|
||||
func (t *GVisor) Start() error {
|
||||
linkEndpoint, err := t.tun.NewEndpoint()
|
||||
linkEndpoint, nicOptions, err := t.tun.NewEndpoint()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, t.tun}
|
||||
ipStack, err := NewGVisorStack(linkEndpoint)
|
||||
ipStack, err := NewGVisorStackWithOptions(linkEndpoint, nicOptions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, NewTCPForwarder(t.ctx, ipStack, t.handler).HandlePacket)
|
||||
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, NewTCPForwarderWithLoopback(t.ctx, ipStack, t.handler, t.inet4LoopbackAddress, t.inet6LoopbackAddress, t.tun).HandlePacket)
|
||||
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket)
|
||||
t.stack = ipStack
|
||||
t.endpoint = linkEndpoint
|
||||
@@ -106,6 +111,10 @@ func AddrFromAddress(address tcpip.Address) netip.Addr {
|
||||
}
|
||||
|
||||
func NewGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
|
||||
return NewGVisorStackWithOptions(ep, stack.NICOptions{})
|
||||
}
|
||||
|
||||
func NewGVisorStackWithOptions(ep stack.LinkEndpoint, opts stack.NICOptions) (*stack.Stack, error) {
|
||||
ipStack := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{
|
||||
ipv4.NewProtocol,
|
||||
@@ -118,7 +127,7 @@ func NewGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
|
||||
icmp.NewProtocol6,
|
||||
},
|
||||
})
|
||||
err := ipStack.CreateNIC(DefaultNIC, ep)
|
||||
err := ipStack.CreateNICWithOptions(DefaultNIC, ep, opts)
|
||||
if err != nil {
|
||||
return nil, gonet.TranslateNetstackError(err)
|
||||
}
|
||||
|
||||
@@ -8,8 +8,6 @@ import (
|
||||
"github.com/sagernet/gvisor/pkg/tcpip"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/header"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
var _ stack.LinkEndpoint = (*LinkEndpointFilter)(nil)
|
||||
@@ -17,7 +15,7 @@ var _ stack.LinkEndpoint = (*LinkEndpointFilter)(nil)
|
||||
type LinkEndpointFilter struct {
|
||||
stack.LinkEndpoint
|
||||
BroadcastAddress netip.Addr
|
||||
Writer N.VectorisedWriter
|
||||
Writer GVisorTun
|
||||
}
|
||||
|
||||
func (w *LinkEndpointFilter) Attach(dispatcher stack.NetworkDispatcher) {
|
||||
@@ -29,7 +27,7 @@ var _ stack.NetworkDispatcher = (*networkDispatcherFilter)(nil)
|
||||
type networkDispatcherFilter struct {
|
||||
stack.NetworkDispatcher
|
||||
broadcastAddress netip.Addr
|
||||
writer N.VectorisedWriter
|
||||
writer GVisorTun
|
||||
}
|
||||
|
||||
func (w *networkDispatcherFilter) DeliverNetworkPacket(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
|
||||
@@ -49,7 +47,7 @@ func (w *networkDispatcherFilter) DeliverNetworkPacket(protocol tcpip.NetworkPro
|
||||
}
|
||||
destination := AddrFromAddress(network.DestinationAddress())
|
||||
if destination == w.broadcastAddress || !destination.IsGlobalUnicast() {
|
||||
_, _ = bufio.WriteVectorised(w.writer, pkt.AsSlices())
|
||||
w.writer.WritePacket(pkt)
|
||||
return
|
||||
}
|
||||
w.NetworkDispatcher.DeliverNetworkPacket(protocol, pkt)
|
||||
|
||||
@@ -4,8 +4,10 @@ package tun
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/gvisor/pkg/tcpip"
|
||||
@@ -17,19 +19,25 @@ import (
|
||||
)
|
||||
|
||||
type gLazyConn struct {
|
||||
tcpConn *gonet.TCPConn
|
||||
parentCtx context.Context
|
||||
stack *stack.Stack
|
||||
request *tcp.ForwarderRequest
|
||||
localAddr net.Addr
|
||||
remoteAddr net.Addr
|
||||
handshakeDone bool
|
||||
handshakeErr error
|
||||
tcpConn *gonet.TCPConn
|
||||
parentCtx context.Context
|
||||
stack *stack.Stack
|
||||
request *tcp.ForwarderRequest
|
||||
localAddr net.Addr
|
||||
remoteAddr net.Addr
|
||||
handshakeAccess sync.Mutex
|
||||
handshakeDone bool
|
||||
handshakeErr error
|
||||
}
|
||||
|
||||
func (c *gLazyConn) HandshakeContext(ctx context.Context) error {
|
||||
if c.handshakeDone {
|
||||
return nil
|
||||
return c.handshakeErr
|
||||
}
|
||||
c.handshakeAccess.Lock()
|
||||
defer c.handshakeAccess.Unlock()
|
||||
if c.handshakeDone {
|
||||
return c.handshakeErr
|
||||
}
|
||||
defer func() {
|
||||
c.handshakeDone = true
|
||||
@@ -67,7 +75,12 @@ func (c *gLazyConn) HandshakeFailure(err error) error {
|
||||
if c.handshakeDone {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
c.request.Complete(err != ErrDrop)
|
||||
c.handshakeAccess.Lock()
|
||||
defer c.handshakeAccess.Unlock()
|
||||
if c.handshakeDone {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
c.request.Complete(!errors.Is(err, ErrDrop))
|
||||
c.handshakeDone = true
|
||||
c.handshakeErr = err
|
||||
return nil
|
||||
@@ -78,25 +91,17 @@ func (c *gLazyConn) HandshakeSuccess() error {
|
||||
}
|
||||
|
||||
func (c *gLazyConn) Read(b []byte) (n int, err error) {
|
||||
if !c.handshakeDone {
|
||||
err = c.HandshakeContext(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else if c.handshakeErr != nil {
|
||||
return 0, c.handshakeErr
|
||||
err = c.HandshakeContext(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return c.tcpConn.Read(b)
|
||||
}
|
||||
|
||||
func (c *gLazyConn) Write(b []byte) (n int, err error) {
|
||||
if !c.handshakeDone {
|
||||
err = c.HandshakeContext(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else if c.handshakeErr != nil {
|
||||
return 0, c.handshakeErr
|
||||
err = c.HandshakeContext(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return c.tcpConn.Write(b)
|
||||
}
|
||||
@@ -110,46 +115,41 @@ func (c *gLazyConn) RemoteAddr() net.Addr {
|
||||
}
|
||||
|
||||
func (c *gLazyConn) SetDeadline(t time.Time) error {
|
||||
if !c.handshakeDone {
|
||||
err := c.HandshakeContext(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if c.handshakeErr != nil {
|
||||
return c.handshakeErr
|
||||
err := c.HandshakeContext(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.tcpConn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (c *gLazyConn) SetReadDeadline(t time.Time) error {
|
||||
if !c.handshakeDone {
|
||||
err := c.HandshakeContext(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if c.handshakeErr != nil {
|
||||
return c.handshakeErr
|
||||
err := c.HandshakeContext(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.tcpConn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *gLazyConn) SetWriteDeadline(t time.Time) error {
|
||||
if !c.handshakeDone {
|
||||
err := c.HandshakeContext(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if c.handshakeErr != nil {
|
||||
return c.handshakeErr
|
||||
err := c.HandshakeContext(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.tcpConn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (c *gLazyConn) Close() error {
|
||||
if !c.handshakeDone {
|
||||
c.request.Complete(true)
|
||||
c.handshakeErr = net.ErrClosed
|
||||
return nil
|
||||
c.handshakeAccess.Lock()
|
||||
if !c.handshakeDone {
|
||||
c.request.Complete(true)
|
||||
c.handshakeErr = net.ErrClosed
|
||||
c.handshakeDone = true
|
||||
return nil
|
||||
} else if c.handshakeErr != nil {
|
||||
return nil
|
||||
}
|
||||
c.handshakeAccess.Unlock()
|
||||
} else if c.handshakeErr != nil {
|
||||
return nil
|
||||
}
|
||||
@@ -158,9 +158,16 @@ func (c *gLazyConn) Close() error {
|
||||
|
||||
func (c *gLazyConn) CloseRead() error {
|
||||
if !c.handshakeDone {
|
||||
c.request.Complete(true)
|
||||
c.handshakeErr = net.ErrClosed
|
||||
return nil
|
||||
c.handshakeAccess.Lock()
|
||||
if !c.handshakeDone {
|
||||
c.request.Complete(true)
|
||||
c.handshakeErr = net.ErrClosed
|
||||
c.handshakeDone = true
|
||||
return nil
|
||||
} else if c.handshakeErr != nil {
|
||||
return nil
|
||||
}
|
||||
c.handshakeAccess.Unlock()
|
||||
} else if c.handshakeErr != nil {
|
||||
return nil
|
||||
}
|
||||
@@ -169,9 +176,16 @@ func (c *gLazyConn) CloseRead() error {
|
||||
|
||||
func (c *gLazyConn) CloseWrite() error {
|
||||
if !c.handshakeDone {
|
||||
c.request.Complete(true)
|
||||
c.handshakeErr = net.ErrClosed
|
||||
return nil
|
||||
c.handshakeAccess.Lock()
|
||||
if !c.handshakeDone {
|
||||
c.request.Complete(true)
|
||||
c.handshakeErr = net.ErrClosed
|
||||
c.handshakeDone = true
|
||||
return nil
|
||||
} else if c.handshakeErr != nil {
|
||||
return nil
|
||||
}
|
||||
c.handshakeAccess.Unlock()
|
||||
} else if c.handshakeErr != nil {
|
||||
return nil
|
||||
}
|
||||
@@ -179,10 +193,14 @@ func (c *gLazyConn) CloseWrite() error {
|
||||
}
|
||||
|
||||
func (c *gLazyConn) ReaderReplaceable() bool {
|
||||
c.handshakeAccess.Lock()
|
||||
defer c.handshakeAccess.Unlock()
|
||||
return c.handshakeDone && c.handshakeErr == nil
|
||||
}
|
||||
|
||||
func (c *gLazyConn) WriterReplaceable() bool {
|
||||
c.handshakeAccess.Lock()
|
||||
defer c.handshakeAccess.Unlock()
|
||||
return c.handshakeDone && c.handshakeErr == nil
|
||||
}
|
||||
|
||||
|
||||
@@ -4,31 +4,75 @@ package tun
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/netip"
|
||||
|
||||
"github.com/sagernet/gvisor/pkg/tcpip"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/header"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
|
||||
"github.com/sagernet/sing/common"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type TCPForwarder struct {
|
||||
ctx context.Context
|
||||
stack *stack.Stack
|
||||
handler Handler
|
||||
forwarder *tcp.Forwarder
|
||||
ctx context.Context
|
||||
stack *stack.Stack
|
||||
handler Handler
|
||||
inet4LoopbackAddress []tcpip.Address
|
||||
inet6LoopbackAddress []tcpip.Address
|
||||
tun GVisorTun
|
||||
forwarder *tcp.Forwarder
|
||||
}
|
||||
|
||||
func NewTCPForwarder(ctx context.Context, stack *stack.Stack, handler Handler) *TCPForwarder {
|
||||
return NewTCPForwarderWithLoopback(ctx, stack, handler, nil, nil, nil)
|
||||
}
|
||||
|
||||
func NewTCPForwarderWithLoopback(ctx context.Context, stack *stack.Stack, handler Handler, inet4LoopbackAddress []netip.Addr, inet6LoopbackAddress []netip.Addr, tun GVisorTun) *TCPForwarder {
|
||||
forwarder := &TCPForwarder{
|
||||
ctx: ctx,
|
||||
stack: stack,
|
||||
handler: handler,
|
||||
ctx: ctx,
|
||||
stack: stack,
|
||||
handler: handler,
|
||||
inet4LoopbackAddress: common.Map(inet4LoopbackAddress, AddressFromAddr),
|
||||
inet6LoopbackAddress: common.Map(inet6LoopbackAddress, AddressFromAddr),
|
||||
tun: tun,
|
||||
}
|
||||
forwarder.forwarder = tcp.NewForwarder(stack, 0, 1024, forwarder.Forward)
|
||||
return forwarder
|
||||
}
|
||||
|
||||
func (f *TCPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
|
||||
for _, inet4LoopbackAddress := range f.inet4LoopbackAddress {
|
||||
if id.LocalAddress == inet4LoopbackAddress {
|
||||
ipHdr := pkt.Network().(header.IPv4)
|
||||
ipHdr.SetDestinationAddressWithChecksumUpdate(ipHdr.SourceAddress())
|
||||
ipHdr.SetSourceAddressWithChecksumUpdate(inet4LoopbackAddress)
|
||||
tcpHdr := header.TCP(pkt.TransportHeader().Slice())
|
||||
tcpHdr.SetChecksum(0)
|
||||
tcpHdr.SetChecksum(^checksum.Combine(pkt.Data().Checksum(), tcpHdr.CalculateChecksum(
|
||||
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddress(), ipHdr.DestinationAddress(), ipHdr.PayloadLength()),
|
||||
)))
|
||||
f.tun.WritePacket(pkt)
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, inet6LoopbackAddress := range f.inet6LoopbackAddress {
|
||||
if id.LocalAddress == inet6LoopbackAddress {
|
||||
ipHdr := pkt.Network().(header.IPv6)
|
||||
ipHdr.SetDestinationAddress(ipHdr.SourceAddress())
|
||||
ipHdr.SetSourceAddress(inet6LoopbackAddress)
|
||||
tcpHdr := header.TCP(pkt.TransportHeader().Slice())
|
||||
tcpHdr.SetChecksum(0)
|
||||
tcpHdr.SetChecksum(^checksum.Combine(pkt.Data().Checksum(), tcpHdr.CalculateChecksum(
|
||||
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddress(), ipHdr.DestinationAddress(), ipHdr.PayloadLength()),
|
||||
)))
|
||||
f.tun.WritePacket(pkt)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return f.forwarder.HandlePacket(id, pkt)
|
||||
}
|
||||
|
||||
@@ -37,7 +81,7 @@ func (f *TCPForwarder) Forward(r *tcp.ForwarderRequest) {
|
||||
destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort)
|
||||
pErr := f.handler.PrepareConnection(N.NetworkTCP, source, destination)
|
||||
if pErr != nil {
|
||||
r.Complete(pErr != ErrDrop)
|
||||
r.Complete(!errors.Is(pErr, ErrDrop))
|
||||
return
|
||||
}
|
||||
conn := &gLazyConn{
|
||||
|
||||
@@ -4,6 +4,7 @@ package tun
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"math"
|
||||
"net/netip"
|
||||
"os"
|
||||
@@ -59,7 +60,7 @@ func rangeIterate(r stack.Range, fn func(*buffer.View))
|
||||
func (f *UDPForwarder) PreparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) {
|
||||
pErr := f.handler.PrepareConnection(N.NetworkUDP, source, destination)
|
||||
if pErr != nil {
|
||||
if pErr != ErrDrop {
|
||||
if !errors.Is(pErr, ErrDrop) {
|
||||
gWriteUnreachable(f.stack, userData.(*stack.PacketBuffer))
|
||||
}
|
||||
return false, nil, nil, nil
|
||||
|
||||
@@ -10,12 +10,13 @@ import (
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip/header"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
type Mixed struct {
|
||||
*System
|
||||
tun GVisorTun
|
||||
stack *stack.Stack
|
||||
endpoint *channel.Endpoint
|
||||
}
|
||||
@@ -29,6 +30,7 @@ func NewMixed(
|
||||
}
|
||||
return &Mixed{
|
||||
System: system.(*System),
|
||||
tun: system.(*System).tun.(GVisorTun),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -72,10 +74,14 @@ func (m *Mixed) tunLoop() {
|
||||
m.txChecksumOffload = linuxTUN.TXChecksumOffload()
|
||||
batchSize := linuxTUN.BatchSize()
|
||||
if batchSize > 1 {
|
||||
m.batchLoop(linuxTUN, batchSize)
|
||||
m.batchLoopLinux(linuxTUN, batchSize)
|
||||
return
|
||||
}
|
||||
}
|
||||
if darwinTUN, isDarwinTUN := m.tun.(DarwinTUN); isDarwinTUN && m.multiPendingPackets {
|
||||
m.batchLoopDarwin(darwinTUN)
|
||||
return
|
||||
}
|
||||
packetBuffer := make([]byte, m.mtu+PacketOffset)
|
||||
for {
|
||||
n, err := m.tun.Read(packetBuffer)
|
||||
@@ -119,12 +125,12 @@ func (m *Mixed) wintunLoop(winTun WinTun) {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Mixed) batchLoop(linuxTUN LinuxTUN, batchSize int) {
|
||||
func (m *Mixed) batchLoopLinux(linuxTUN LinuxTUN, batchSize int) {
|
||||
packetBuffers := make([][]byte, batchSize)
|
||||
writeBuffers := make([][]byte, batchSize)
|
||||
packetSizes := make([]int, batchSize)
|
||||
for i := range packetBuffers {
|
||||
packetBuffers[i] = make([]byte, m.mtu+m.frontHeadroom)
|
||||
packetBuffers[i] = make([]byte, m.mtu+PacketOffset+m.frontHeadroom)
|
||||
}
|
||||
for {
|
||||
n, err := linuxTUN.BatchRead(packetBuffers, m.frontHeadroom, packetSizes)
|
||||
@@ -158,6 +164,40 @@ func (m *Mixed) batchLoop(linuxTUN LinuxTUN, batchSize int) {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Mixed) batchLoopDarwin(darwinTUN DarwinTUN) {
|
||||
var writeBuffers []*buf.Buffer
|
||||
for {
|
||||
buffers, err := darwinTUN.BatchRead()
|
||||
if err != nil {
|
||||
if E.IsClosed(err) {
|
||||
return
|
||||
}
|
||||
m.logger.Error(E.Cause(err, "batch read packet"))
|
||||
}
|
||||
if len(buffers) == 0 {
|
||||
continue
|
||||
}
|
||||
writeBuffers = writeBuffers[:0]
|
||||
for _, buffer := range buffers {
|
||||
packetSize := buffer.Len()
|
||||
if packetSize < header.IPv4MinimumSize {
|
||||
continue
|
||||
}
|
||||
if m.processPacket(buffer.Bytes()) {
|
||||
writeBuffers = append(writeBuffers, buffer)
|
||||
} else {
|
||||
buffer.Release()
|
||||
}
|
||||
}
|
||||
if len(writeBuffers) > 0 {
|
||||
err = darwinTUN.BatchWrite(writeBuffers)
|
||||
if err != nil {
|
||||
m.logger.Trace(E.Cause(err, "batch write packet"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Mixed) processPacket(packet []byte) bool {
|
||||
var (
|
||||
writeBack bool
|
||||
@@ -226,11 +266,11 @@ func (m *Mixed) processIPv6(ipHdr header.IPv6) (writeBack bool, err error) {
|
||||
|
||||
func (m *Mixed) packetLoop() {
|
||||
for {
|
||||
packet := m.endpoint.ReadContext(m.ctx)
|
||||
if packet == nil {
|
||||
pkt := m.endpoint.ReadContext(m.ctx)
|
||||
if pkt == nil {
|
||||
break
|
||||
}
|
||||
bufio.WriteVectorised(m.tun, packet.AsSlices())
|
||||
packet.DecRef()
|
||||
m.tun.WritePacket(pkt)
|
||||
pkt.DecRef()
|
||||
}
|
||||
}
|
||||
|
||||
203
stack_system.go
203
stack_system.go
@@ -2,6 +2,7 @@ package tun
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
@@ -22,30 +23,33 @@ import (
|
||||
var ErrIncludeAllNetworks = E.New("`system` and `mixed` stack are not available when `includeAllNetworks` is enabled. See https://github.com/SagerNet/sing-tun/issues/25")
|
||||
|
||||
type System struct {
|
||||
ctx context.Context
|
||||
tun Tun
|
||||
tunName string
|
||||
mtu int
|
||||
handler Handler
|
||||
logger logger.Logger
|
||||
inet4Prefixes []netip.Prefix
|
||||
inet6Prefixes []netip.Prefix
|
||||
inet4ServerAddress netip.Addr
|
||||
inet4Address netip.Addr
|
||||
inet6ServerAddress netip.Addr
|
||||
inet6Address netip.Addr
|
||||
broadcastAddr netip.Addr
|
||||
udpTimeout time.Duration
|
||||
tcpListener net.Listener
|
||||
tcpListener6 net.Listener
|
||||
tcpPort uint16
|
||||
tcpPort6 uint16
|
||||
tcpNat *TCPNat
|
||||
udpNat *udpnat.Service
|
||||
bindInterface bool
|
||||
interfaceFinder control.InterfaceFinder
|
||||
frontHeadroom int
|
||||
txChecksumOffload bool
|
||||
ctx context.Context
|
||||
tun Tun
|
||||
tunName string
|
||||
mtu int
|
||||
handler Handler
|
||||
logger logger.Logger
|
||||
inet4Prefixes []netip.Prefix
|
||||
inet6Prefixes []netip.Prefix
|
||||
inet4ServerAddress netip.Addr
|
||||
inet4Address netip.Addr
|
||||
inet6ServerAddress netip.Addr
|
||||
inet6Address netip.Addr
|
||||
broadcastAddr netip.Addr
|
||||
inet4LoopbackAddress []netip.Addr
|
||||
inet6LoopbackAddress []netip.Addr
|
||||
udpTimeout time.Duration
|
||||
tcpListener net.Listener
|
||||
tcpListener6 net.Listener
|
||||
tcpPort uint16
|
||||
tcpPort6 uint16
|
||||
tcpNat *TCPNat
|
||||
udpNat *udpnat.Service
|
||||
bindInterface bool
|
||||
interfaceFinder control.InterfaceFinder
|
||||
frontHeadroom int
|
||||
txChecksumOffload bool
|
||||
multiPendingPackets bool
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
@@ -57,18 +61,21 @@ type Session struct {
|
||||
|
||||
func NewSystem(options StackOptions) (Stack, error) {
|
||||
stack := &System{
|
||||
ctx: options.Context,
|
||||
tun: options.Tun,
|
||||
tunName: options.TunOptions.Name,
|
||||
mtu: int(options.TunOptions.MTU),
|
||||
udpTimeout: options.UDPTimeout,
|
||||
handler: options.Handler,
|
||||
logger: options.Logger,
|
||||
inet4Prefixes: options.TunOptions.Inet4Address,
|
||||
inet6Prefixes: options.TunOptions.Inet6Address,
|
||||
broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address),
|
||||
bindInterface: options.ForwarderBindInterface,
|
||||
interfaceFinder: options.InterfaceFinder,
|
||||
ctx: options.Context,
|
||||
tun: options.Tun,
|
||||
tunName: options.TunOptions.Name,
|
||||
mtu: int(options.TunOptions.MTU),
|
||||
inet4LoopbackAddress: options.TunOptions.Inet4LoopbackAddress,
|
||||
inet6LoopbackAddress: options.TunOptions.Inet6LoopbackAddress,
|
||||
udpTimeout: options.UDPTimeout,
|
||||
handler: options.Handler,
|
||||
logger: options.Logger,
|
||||
inet4Prefixes: options.TunOptions.Inet4Address,
|
||||
inet6Prefixes: options.TunOptions.Inet6Address,
|
||||
broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address),
|
||||
bindInterface: options.ForwarderBindInterface,
|
||||
interfaceFinder: options.InterfaceFinder,
|
||||
multiPendingPackets: options.TunOptions.EXP_MultiPendingPackets,
|
||||
}
|
||||
if len(options.TunOptions.Inet4Address) > 0 {
|
||||
if !HasNextAddress(options.TunOptions.Inet4Address[0], 1) {
|
||||
@@ -107,10 +114,7 @@ func (s *System) Start() error {
|
||||
}
|
||||
|
||||
func (s *System) start() error {
|
||||
err := fixWindowsFirewall()
|
||||
if err != nil {
|
||||
return E.Cause(err, "fix windows firewall for system stack")
|
||||
}
|
||||
_ = fixWindowsFirewall()
|
||||
var listener net.ListenConfig
|
||||
if s.bindInterface {
|
||||
listener.Control = control.Append(listener.Control, func(network, address string, conn syscall.RawConn) error {
|
||||
@@ -122,6 +126,7 @@ func (s *System) start() error {
|
||||
})
|
||||
}
|
||||
var tcpListener net.Listener
|
||||
var err error
|
||||
if s.inet4Address.IsValid() {
|
||||
for i := 0; i < 3; i++ {
|
||||
tcpListener, err = listener.Listen(s.ctx, "tcp4", net.JoinHostPort(s.inet4ServerAddress.String(), "0"))
|
||||
@@ -167,10 +172,14 @@ func (s *System) tunLoop() {
|
||||
s.txChecksumOffload = linuxTUN.TXChecksumOffload()
|
||||
batchSize := linuxTUN.BatchSize()
|
||||
if batchSize > 1 {
|
||||
s.batchLoop(linuxTUN, batchSize)
|
||||
s.batchLoopLinux(linuxTUN, batchSize)
|
||||
return
|
||||
}
|
||||
}
|
||||
if darwinTUN, isDarwinTUN := s.tun.(DarwinTUN); isDarwinTUN && s.multiPendingPackets {
|
||||
s.batchLoopDarwin(darwinTUN)
|
||||
return
|
||||
}
|
||||
packetBuffer := make([]byte, s.mtu+PacketOffset)
|
||||
for {
|
||||
n, err := s.tun.Read(packetBuffer)
|
||||
@@ -214,7 +223,7 @@ func (s *System) wintunLoop(winTun WinTun) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *System) batchLoop(linuxTUN LinuxTUN, batchSize int) {
|
||||
func (s *System) batchLoopLinux(linuxTUN LinuxTUN, batchSize int) {
|
||||
packetBuffers := make([][]byte, batchSize)
|
||||
writeBuffers := make([][]byte, batchSize)
|
||||
packetSizes := make([]int, batchSize)
|
||||
@@ -253,6 +262,40 @@ func (s *System) batchLoop(linuxTUN LinuxTUN, batchSize int) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *System) batchLoopDarwin(darwinTUN DarwinTUN) {
|
||||
var writeBuffers []*buf.Buffer
|
||||
for {
|
||||
buffers, err := darwinTUN.BatchRead()
|
||||
if err != nil {
|
||||
if E.IsClosed(err) {
|
||||
return
|
||||
}
|
||||
s.logger.Error(E.Cause(err, "batch read packet"))
|
||||
}
|
||||
if len(buffers) == 0 {
|
||||
continue
|
||||
}
|
||||
writeBuffers = writeBuffers[:0]
|
||||
for _, buffer := range buffers {
|
||||
packetSize := buffer.Len()
|
||||
if packetSize < header.IPv4MinimumSize {
|
||||
continue
|
||||
}
|
||||
if s.processPacket(buffer.Bytes()) {
|
||||
writeBuffers = append(writeBuffers, buffer)
|
||||
} else {
|
||||
buffer.Release()
|
||||
}
|
||||
}
|
||||
if len(writeBuffers) > 0 {
|
||||
err = darwinTUN.BatchWrite(writeBuffers)
|
||||
if err != nil {
|
||||
s.logger.Trace(E.Cause(err, "batch write packet"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *System) processPacket(packet []byte) bool {
|
||||
var (
|
||||
writeBack bool
|
||||
@@ -354,28 +397,37 @@ func (s *System) processIPv4TCP(ipHdr header.IPv4, tcpHdr header.TCP) (bool, err
|
||||
ipHdr.SetDestinationAddr(session.Source.Addr())
|
||||
tcpHdr.SetDestinationPort(session.Source.Port())
|
||||
} else {
|
||||
natPort, err := s.tcpNat.Lookup(source, destination, s.handler)
|
||||
if err != nil {
|
||||
if err == ErrDrop {
|
||||
return false, nil
|
||||
} else {
|
||||
return false, s.resetIPv4TCP(ipHdr, tcpHdr)
|
||||
var loopback bool
|
||||
for _, inet4LoopbackAddress := range s.inet4LoopbackAddress {
|
||||
if destination.Addr() == inet4LoopbackAddress {
|
||||
ipHdr.SetDestinationAddr(ipHdr.SourceAddr())
|
||||
ipHdr.SetSourceAddr(inet4LoopbackAddress)
|
||||
loopback = true
|
||||
break
|
||||
}
|
||||
}
|
||||
ipHdr.SetSourceAddr(s.inet4Address)
|
||||
tcpHdr.SetSourcePort(natPort)
|
||||
ipHdr.SetDestinationAddr(s.inet4ServerAddress)
|
||||
tcpHdr.SetDestinationPort(s.tcpPort)
|
||||
if !loopback {
|
||||
natPort, err := s.tcpNat.Lookup(source, destination, s.handler)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrDrop) {
|
||||
return false, nil
|
||||
} else {
|
||||
return false, s.resetIPv4TCP(ipHdr, tcpHdr)
|
||||
}
|
||||
}
|
||||
ipHdr.SetSourceAddr(s.inet4Address)
|
||||
tcpHdr.SetSourcePort(natPort)
|
||||
ipHdr.SetDestinationAddr(s.inet4ServerAddress)
|
||||
tcpHdr.SetDestinationPort(s.tcpPort)
|
||||
}
|
||||
}
|
||||
if !s.txChecksumOffload {
|
||||
tcpHdr.SetChecksum(0)
|
||||
tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum(
|
||||
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
|
||||
)))
|
||||
} else {
|
||||
tcpHdr.SetChecksum(0)
|
||||
}
|
||||
ipHdr.SetChecksum(0)
|
||||
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||
return true, nil
|
||||
}
|
||||
@@ -416,7 +468,6 @@ func (s *System) resetIPv4TCP(origIPHdr header.IPv4, origTCPHdr header.TCP) erro
|
||||
if !s.txChecksumOffload {
|
||||
tcpHdr.SetChecksum(^tcpHdr.CalculateChecksum(header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), header.TCPMinimumSize)))
|
||||
}
|
||||
ipHdr.SetChecksum(0)
|
||||
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||
if PacketOffset > 0 {
|
||||
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version)
|
||||
@@ -441,21 +492,31 @@ func (s *System) processIPv6TCP(ipHdr header.IPv6, tcpHdr header.TCP) (bool, err
|
||||
ipHdr.SetDestinationAddr(session.Source.Addr())
|
||||
tcpHdr.SetDestinationPort(session.Source.Port())
|
||||
} else {
|
||||
natPort, err := s.tcpNat.Lookup(source, destination, s.handler)
|
||||
if err != nil {
|
||||
if err == ErrDrop {
|
||||
return false, nil
|
||||
} else {
|
||||
return false, s.resetIPv6TCP(ipHdr, tcpHdr)
|
||||
var loopback bool
|
||||
for _, inet6LoopbackAddress := range s.inet6LoopbackAddress {
|
||||
if destination.Addr() == inet6LoopbackAddress {
|
||||
ipHdr.SetDestinationAddr(ipHdr.SourceAddr())
|
||||
ipHdr.SetSourceAddr(inet6LoopbackAddress)
|
||||
loopback = true
|
||||
break
|
||||
}
|
||||
}
|
||||
ipHdr.SetSourceAddr(s.inet6Address)
|
||||
tcpHdr.SetSourcePort(natPort)
|
||||
ipHdr.SetDestinationAddr(s.inet6ServerAddress)
|
||||
tcpHdr.SetDestinationPort(s.tcpPort6)
|
||||
if !loopback {
|
||||
natPort, err := s.tcpNat.Lookup(source, destination, s.handler)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrDrop) {
|
||||
return false, nil
|
||||
} else {
|
||||
return false, s.resetIPv6TCP(ipHdr, tcpHdr)
|
||||
}
|
||||
}
|
||||
ipHdr.SetSourceAddr(s.inet6Address)
|
||||
tcpHdr.SetSourcePort(natPort)
|
||||
ipHdr.SetDestinationAddr(s.inet6ServerAddress)
|
||||
tcpHdr.SetDestinationPort(s.tcpPort6)
|
||||
}
|
||||
}
|
||||
if !s.txChecksumOffload {
|
||||
tcpHdr.SetChecksum(0)
|
||||
tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum(
|
||||
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
|
||||
)))
|
||||
@@ -538,7 +599,7 @@ func (s *System) processIPv6UDP(ipHdr header.IPv6, udpHdr header.UDP) error {
|
||||
func (s *System) preparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) {
|
||||
pErr := s.handler.PrepareConnection(N.NetworkUDP, source, destination)
|
||||
if pErr != nil {
|
||||
if pErr != ErrDrop {
|
||||
if !errors.Is(pErr, ErrDrop) {
|
||||
if source.IsIPv4() {
|
||||
ipHdr := userData.(header.IPv4)
|
||||
s.rejectIPv4WithICMP(ipHdr, header.ICMPv4PortUnreachable)
|
||||
@@ -586,8 +647,7 @@ func (s *System) processIPv4ICMP(ipHdr header.IPv4, icmpHdr header.ICMPv4) error
|
||||
sourceAddress := ipHdr.SourceAddr()
|
||||
ipHdr.SetSourceAddr(ipHdr.DestinationAddr())
|
||||
ipHdr.SetDestinationAddr(sourceAddress)
|
||||
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, checksum.Checksum(icmpHdr.Payload(), 0)))
|
||||
ipHdr.SetChecksum(0)
|
||||
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
|
||||
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||
return nil
|
||||
}
|
||||
@@ -621,7 +681,7 @@ func (s *System) rejectIPv4WithICMP(ipHdr header.IPv4, code header.ICMPv4Code) e
|
||||
icmpHdr := header.ICMPv4(newIPHdr.Payload())
|
||||
icmpHdr.SetType(header.ICMPv4DstUnreachable)
|
||||
icmpHdr.SetCode(code)
|
||||
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(payload, 0)))
|
||||
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
|
||||
copy(icmpHdr.Payload(), payload)
|
||||
if PacketOffset > 0 {
|
||||
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET
|
||||
@@ -714,14 +774,12 @@ func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.S
|
||||
udpHdr.SetSourcePort(destination.Port)
|
||||
udpHdr.SetLength(uint16(buffer.Len() + header.UDPMinimumSize))
|
||||
if !w.txChecksumOffload {
|
||||
udpHdr.SetChecksum(0)
|
||||
udpHdr.SetChecksum(^checksum.Checksum(udpHdr.Payload(), udpHdr.CalculateChecksum(
|
||||
header.PseudoHeaderChecksum(header.UDPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
|
||||
)))
|
||||
} else {
|
||||
udpHdr.SetChecksum(0)
|
||||
}
|
||||
ipHdr.SetChecksum(0)
|
||||
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||
if PacketOffset > 0 {
|
||||
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version)
|
||||
@@ -755,7 +813,6 @@ func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.S
|
||||
udpHdr.SetSourcePort(destination.Port)
|
||||
udpHdr.SetLength(udpLen)
|
||||
if !w.txChecksumOffload {
|
||||
udpHdr.SetChecksum(0)
|
||||
udpHdr.SetChecksum(^checksum.Checksum(udpHdr.Payload(), udpHdr.CalculateChecksum(
|
||||
header.PseudoHeaderChecksum(header.UDPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
|
||||
)))
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
)
|
||||
|
||||
type TCPNat struct {
|
||||
timeout time.Duration
|
||||
portIndex uint16
|
||||
portAccess sync.RWMutex
|
||||
addrAccess sync.RWMutex
|
||||
@@ -19,6 +20,7 @@ type TCPNat struct {
|
||||
}
|
||||
|
||||
type TCPSession struct {
|
||||
sync.Mutex
|
||||
Source netip.AddrPort
|
||||
Destination netip.AddrPort
|
||||
LastActive time.Time
|
||||
@@ -26,38 +28,41 @@ type TCPSession struct {
|
||||
|
||||
func NewNat(ctx context.Context, timeout time.Duration) *TCPNat {
|
||||
natMap := &TCPNat{
|
||||
timeout: timeout,
|
||||
portIndex: 10000,
|
||||
addrMap: make(map[netip.AddrPort]uint16),
|
||||
portMap: make(map[uint16]*TCPSession),
|
||||
}
|
||||
go natMap.loopCheckTimeout(ctx, timeout)
|
||||
go natMap.loopCheckTimeout(ctx)
|
||||
return natMap
|
||||
}
|
||||
|
||||
func (n *TCPNat) loopCheckTimeout(ctx context.Context, timeout time.Duration) {
|
||||
ticker := time.NewTicker(timeout)
|
||||
func (n *TCPNat) loopCheckTimeout(ctx context.Context) {
|
||||
ticker := time.NewTicker(n.timeout)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
n.checkTimeout(timeout)
|
||||
n.checkTimeout()
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (n *TCPNat) checkTimeout(timeout time.Duration) {
|
||||
func (n *TCPNat) checkTimeout() {
|
||||
now := time.Now()
|
||||
n.portAccess.Lock()
|
||||
defer n.portAccess.Unlock()
|
||||
n.addrAccess.Lock()
|
||||
defer n.addrAccess.Unlock()
|
||||
for natPort, session := range n.portMap {
|
||||
if now.Sub(session.LastActive) > timeout {
|
||||
session.Lock()
|
||||
if now.Sub(session.LastActive) > n.timeout {
|
||||
delete(n.addrMap, session.Source)
|
||||
delete(n.portMap, natPort)
|
||||
}
|
||||
session.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,7 +71,11 @@ func (n *TCPNat) LookupBack(port uint16) *TCPSession {
|
||||
session := n.portMap[port]
|
||||
n.portAccess.RUnlock()
|
||||
if session != nil {
|
||||
session.LastActive = time.Now()
|
||||
session.Lock()
|
||||
if time.Since(session.LastActive) > time.Second {
|
||||
session.LastActive = time.Now()
|
||||
}
|
||||
session.Unlock()
|
||||
}
|
||||
return session
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ func fixWindowsFirewall() error {
|
||||
Protocol: winfw.NET_FW_IP_PROTOCOL_TCP,
|
||||
Direction: winfw.NET_FW_RULE_DIR_IN,
|
||||
Action: winfw.NET_FW_ACTION_ALLOW,
|
||||
Profiles: winfw.NET_FW_PROFILE2_ALL,
|
||||
}
|
||||
_, err = winfw.FirewallRuleAddAdvanced(rule)
|
||||
return err
|
||||
|
||||
17
tun.go
17
tun.go
@@ -8,6 +8,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/control"
|
||||
F "github.com/sagernet/sing/common/format"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
@@ -24,7 +25,6 @@ type Handler interface {
|
||||
|
||||
type Tun interface {
|
||||
io.ReadWriter
|
||||
N.VectorisedWriter
|
||||
Name() (string, error)
|
||||
Start() error
|
||||
Close() error
|
||||
@@ -45,6 +45,12 @@ type LinuxTUN interface {
|
||||
TXChecksumOffload() bool
|
||||
}
|
||||
|
||||
type DarwinTUN interface {
|
||||
Tun
|
||||
BatchRead() ([]*buf.Buffer, error)
|
||||
BatchWrite(buffers []*buf.Buffer) error
|
||||
}
|
||||
|
||||
const (
|
||||
DefaultIPRoute2TableIndex = 2022
|
||||
DefaultIPRoute2RuleIndex = 9000
|
||||
@@ -66,6 +72,8 @@ type Options struct {
|
||||
AutoRedirectMarkMode bool
|
||||
AutoRedirectInputMark uint32
|
||||
AutoRedirectOutputMark uint32
|
||||
Inet4LoopbackAddress []netip.Addr
|
||||
Inet6LoopbackAddress []netip.Addr
|
||||
StrictRoute bool
|
||||
Inet4RouteAddress []netip.Prefix
|
||||
Inet6RouteAddress []netip.Prefix
|
||||
@@ -88,6 +96,13 @@ type Options struct {
|
||||
|
||||
// For library usages.
|
||||
EXP_DisableDNSHijack bool
|
||||
|
||||
// For gvisor stack, it should be enabled when MTU is less than 32768; otherwise it should be less than or equal to 8192.
|
||||
// The above condition is just an estimate and not exact, calculated on M4 pro.
|
||||
EXP_MultiPendingPackets bool
|
||||
|
||||
// Will cause the darwin network to die, do not use.
|
||||
EXP_SendMsgX bool
|
||||
}
|
||||
|
||||
func (o *Options) Inet4GatewayAddr() netip.Addr {
|
||||
|
||||
277
tun_darwin.go
277
tun_darwin.go
@@ -10,26 +10,74 @@ import (
|
||||
"unsafe"
|
||||
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip/header"
|
||||
"github.com/sagernet/sing-tun/internal/rawfile_darwin"
|
||||
"github.com/sagernet/sing-tun/internal/stopfd_darwin"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/shell"
|
||||
|
||||
"golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var _ DarwinTUN = (*NativeTun)(nil)
|
||||
|
||||
const PacketOffset = 4
|
||||
|
||||
type NativeTun struct {
|
||||
tunFile *os.File
|
||||
tunWriter N.VectorisedWriter
|
||||
options Options
|
||||
inet4Address [4]byte
|
||||
inet6Address [16]byte
|
||||
routeSet bool
|
||||
tunFd int
|
||||
tunFile *os.File
|
||||
batchSize int
|
||||
iovecs []iovecBuffer
|
||||
iovecsOutput []iovecBuffer
|
||||
iovecsOutputDefault []unix.Iovec
|
||||
msgHdrs []rawfile.MsgHdrX
|
||||
msgHdrsOutput []rawfile.MsgHdrX
|
||||
buffers []*buf.Buffer
|
||||
stopFd stopfd.StopFD
|
||||
options Options
|
||||
inet4Address [4]byte
|
||||
inet6Address [16]byte
|
||||
routeSet bool
|
||||
sendMsgX bool
|
||||
}
|
||||
|
||||
type iovecBuffer struct {
|
||||
mtu int
|
||||
buffer *buf.Buffer
|
||||
iovecs []unix.Iovec
|
||||
}
|
||||
|
||||
func newIovecBuffer(mtu int) iovecBuffer {
|
||||
return iovecBuffer{
|
||||
mtu: mtu,
|
||||
iovecs: make([]unix.Iovec, 2),
|
||||
}
|
||||
}
|
||||
|
||||
func (b *iovecBuffer) nextIovecs() []unix.Iovec {
|
||||
if b.iovecs[0].Len == 0 {
|
||||
headBuffer := make([]byte, PacketOffset)
|
||||
b.iovecs[0].Base = &headBuffer[0]
|
||||
b.iovecs[0].SetLen(PacketOffset)
|
||||
}
|
||||
if b.buffer == nil {
|
||||
b.buffer = buf.NewSize(b.mtu)
|
||||
b.iovecs[1] = b.buffer.Iovec(b.buffer.Cap())
|
||||
}
|
||||
return b.iovecs
|
||||
}
|
||||
|
||||
func (b *iovecBuffer) nextIovecsOutput(buffer *buf.Buffer) []unix.Iovec {
|
||||
switch header.IPVersion(buffer.Bytes()) {
|
||||
case header.IPv4Version:
|
||||
b.iovecs[0] = packetHeaderVec4
|
||||
case header.IPv6Version:
|
||||
b.iovecs[0] = packetHeaderVec6
|
||||
}
|
||||
b.iovecs[1] = buffer.Iovec(buffer.Len())
|
||||
return b.iovecs
|
||||
}
|
||||
|
||||
func (t *NativeTun) Name() (string, error) {
|
||||
@@ -42,6 +90,7 @@ func (t *NativeTun) Name() (string, error) {
|
||||
|
||||
func New(options Options) (Tun, error) {
|
||||
var tunFd int
|
||||
batchSize := ((512 * 1024) / int(options.MTU)) + 1
|
||||
if options.FileDescriptor == 0 {
|
||||
ifIndex := -1
|
||||
_, err := fmt.Sscanf(options.Name, "utun%d", &ifIndex)
|
||||
@@ -54,18 +103,38 @@ func New(options Options) (Tun, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = configure(tunFd, ifIndex, options.Name, options)
|
||||
err = create(tunFd, ifIndex, options.Name, options)
|
||||
if err != nil {
|
||||
unix.Close(tunFd)
|
||||
return nil, err
|
||||
}
|
||||
err = configure(tunFd, options.EXP_MultiPendingPackets, batchSize)
|
||||
if err != nil {
|
||||
unix.Close(tunFd)
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
tunFd = options.FileDescriptor
|
||||
err := configure(tunFd, options.EXP_MultiPendingPackets, batchSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
nativeTun := &NativeTun{
|
||||
tunFile: os.NewFile(uintptr(tunFd), "utun"),
|
||||
options: options,
|
||||
tunFd: tunFd,
|
||||
tunFile: os.NewFile(uintptr(tunFd), "utun"),
|
||||
options: options,
|
||||
batchSize: batchSize,
|
||||
iovecs: make([]iovecBuffer, batchSize),
|
||||
iovecsOutput: make([]iovecBuffer, batchSize),
|
||||
msgHdrs: make([]rawfile.MsgHdrX, batchSize),
|
||||
msgHdrsOutput: make([]rawfile.MsgHdrX, batchSize),
|
||||
stopFd: common.Must1(stopfd.New()),
|
||||
sendMsgX: options.EXP_SendMsgX,
|
||||
}
|
||||
for i := 0; i < batchSize; i++ {
|
||||
nativeTun.iovecs[i] = newIovecBuffer(int(options.MTU))
|
||||
nativeTun.iovecsOutput[i] = newIovecBuffer(int(options.MTU))
|
||||
}
|
||||
if len(options.Inet4Address) > 0 {
|
||||
nativeTun.inet4Address = options.Inet4Address[0].Addr().As4()
|
||||
@@ -73,21 +142,20 @@ func New(options Options) (Tun, error) {
|
||||
if len(options.Inet6Address) > 0 {
|
||||
nativeTun.inet6Address = options.Inet6Address[0].Addr().As16()
|
||||
}
|
||||
var ok bool
|
||||
nativeTun.tunWriter, ok = bufio.CreateVectorisedWriter(nativeTun.tunFile)
|
||||
if !ok {
|
||||
panic("create vectorised writer")
|
||||
}
|
||||
return nativeTun, nil
|
||||
}
|
||||
|
||||
func (t *NativeTun) Start() error {
|
||||
t.options.InterfaceMonitor.RegisterMyInterface(t.options.Name)
|
||||
return t.setRoutes()
|
||||
}
|
||||
|
||||
func (t *NativeTun) Close() error {
|
||||
defer flushDNSCache()
|
||||
return E.Errors(t.unsetRoutes(), t.tunFile.Close())
|
||||
t.stopFd.Stop()
|
||||
err := E.Errors(t.unsetRoutes(), t.tunFile.Close())
|
||||
t.stopFd.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *NativeTun) Read(p []byte) (n int, err error) {
|
||||
@@ -99,19 +167,15 @@ func (t *NativeTun) Write(p []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
var (
|
||||
packetHeader4 = [4]byte{0x00, 0x00, 0x00, unix.AF_INET}
|
||||
packetHeader6 = [4]byte{0x00, 0x00, 0x00, unix.AF_INET6}
|
||||
packetHeader4 = []byte{0x00, 0x00, 0x00, unix.AF_INET}
|
||||
packetHeader6 = []byte{0x00, 0x00, 0x00, unix.AF_INET6}
|
||||
packetHeaderVec4 = unix.Iovec{Base: &packetHeader4[0]}
|
||||
packetHeaderVec6 = unix.Iovec{Base: &packetHeader6[0]}
|
||||
)
|
||||
|
||||
func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
|
||||
var packetHeader []byte
|
||||
switch header.IPVersion(buffers[0].Bytes()) {
|
||||
case header.IPv4Version:
|
||||
packetHeader = packetHeader4[:]
|
||||
case header.IPv6Version:
|
||||
packetHeader = packetHeader6[:]
|
||||
}
|
||||
return t.tunWriter.WriteVectorised(append([]*buf.Buffer{buf.As(packetHeader)}, buffers...))
|
||||
func init() {
|
||||
packetHeaderVec4.SetLen(4)
|
||||
packetHeaderVec6.SetLen(4)
|
||||
}
|
||||
|
||||
const utunControlName = "com.apple.net.utun_control"
|
||||
@@ -146,7 +210,7 @@ type addrLifetime6 struct {
|
||||
Pltime uint32
|
||||
}
|
||||
|
||||
func configure(tunFd int, ifIndex int, name string, options Options) error {
|
||||
func create(tunFd int, ifIndex int, name string, options Options) error {
|
||||
ctlInfo := &unix.CtlInfo{}
|
||||
copy(ctlInfo.Name[:], utunControlName)
|
||||
err := unix.IoctlCtlInfo(tunFd, ctlInfo)
|
||||
@@ -162,11 +226,6 @@ func configure(tunFd int, ifIndex int, name string, options Options) error {
|
||||
return os.NewSyscallError("Connect", err)
|
||||
}
|
||||
|
||||
err = unix.SetNonblock(tunFd, true)
|
||||
if err != nil {
|
||||
return os.NewSyscallError("SetNonblock", err)
|
||||
}
|
||||
|
||||
err = useSocket(unix.AF_INET, unix.SOCK_DGRAM, 0, func(socketFd int) error {
|
||||
var ifr unix.IfreqMTU
|
||||
copy(ifr.Name[:], name)
|
||||
@@ -258,6 +317,90 @@ func configure(tunFd int, ifIndex int, name string, options Options) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func configure(tunFd int, multiPendingPackets bool, batchSize int) error {
|
||||
err := unix.SetNonblock(tunFd, true)
|
||||
if err != nil {
|
||||
return os.NewSyscallError("SetNonblock", err)
|
||||
}
|
||||
if multiPendingPackets {
|
||||
const UTUN_OPT_MAX_PENDING_PACKETS = 16
|
||||
err = unix.SetsockoptInt(tunFd, 2, UTUN_OPT_MAX_PENDING_PACKETS, batchSize)
|
||||
if err != nil {
|
||||
return os.NewSyscallError("SetsockoptInt UTUN_OPT_MAX_PENDING_PACKETS", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *NativeTun) BatchRead() ([]*buf.Buffer, error) {
|
||||
for i := 0; i < t.batchSize; i++ {
|
||||
iovecs := t.iovecs[i].nextIovecs()
|
||||
// Cannot clear only the length field. Older versions of the darwin kernel will check whether other data is empty.
|
||||
// https://github.com/Darm64/XNU/blob/xnu-2782.40.9/bsd/kern/uipc_syscalls.c#L2026-L2048
|
||||
t.msgHdrs[i] = rawfile.MsgHdrX{}
|
||||
t.msgHdrs[i].Msg.Iov = &iovecs[0]
|
||||
t.msgHdrs[i].Msg.Iovlen = 2
|
||||
}
|
||||
n, errno := rawfile.BlockingRecvMMsgUntilStopped(t.stopFd.ReadFD, t.tunFd, t.msgHdrs)
|
||||
if errno != 0 {
|
||||
for k := 0; k < n; k++ {
|
||||
t.iovecs[k].buffer.Release()
|
||||
t.iovecs[k].buffer = nil
|
||||
}
|
||||
t.buffers = t.buffers[:0]
|
||||
return nil, errno
|
||||
}
|
||||
if n < 0 {
|
||||
return nil, os.ErrClosed
|
||||
}
|
||||
if n < 1 {
|
||||
return nil, nil
|
||||
}
|
||||
buffers := t.buffers
|
||||
for k := 0; k < n; k++ {
|
||||
buffer := t.iovecs[k].buffer
|
||||
t.iovecs[k].buffer = nil
|
||||
buffer.Truncate(int(t.msgHdrs[k].DataLen) - PacketOffset)
|
||||
buffers = append(buffers, buffer)
|
||||
}
|
||||
t.buffers = buffers[:0]
|
||||
return buffers, nil
|
||||
}
|
||||
|
||||
func (t *NativeTun) BatchWrite(buffers []*buf.Buffer) error {
|
||||
if !t.sendMsgX {
|
||||
for i, buffer := range buffers {
|
||||
t.iovecsOutput[i].nextIovecsOutput(buffer)
|
||||
}
|
||||
for i := range buffers {
|
||||
errno := rawfile.NonBlockingWriteIovec(t.tunFd, t.iovecsOutput[i].iovecs)
|
||||
if errno != 0 {
|
||||
return errno
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i, buffer := range buffers {
|
||||
iovecs := t.iovecsOutput[i].nextIovecsOutput(buffer)
|
||||
t.msgHdrsOutput[i] = rawfile.MsgHdrX{}
|
||||
t.msgHdrsOutput[i].Msg.Iov = &iovecs[0]
|
||||
t.msgHdrsOutput[i].Msg.Iovlen = 2
|
||||
}
|
||||
var n int
|
||||
for n != len(buffers) {
|
||||
sent, errno := rawfile.NonBlockingSendMMsg(t.tunFd, t.msgHdrsOutput[n:len(buffers)])
|
||||
if errno != 0 {
|
||||
return errno
|
||||
}
|
||||
n += sent
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *NativeTun) TXChecksumOffload() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *NativeTun) UpdateRouteOptions(tunOptions Options) error {
|
||||
err := t.unsetRoutes()
|
||||
if err != nil {
|
||||
@@ -268,45 +411,47 @@ func (t *NativeTun) UpdateRouteOptions(tunOptions Options) error {
|
||||
}
|
||||
|
||||
func (t *NativeTun) setRoutes() error {
|
||||
if t.options.AutoRoute && t.options.FileDescriptor == 0 {
|
||||
if t.options.FileDescriptor == 0 {
|
||||
routeRanges, err := t.options.BuildAutoRouteRanges(false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
gateway4, gateway6 := t.options.Inet4GatewayAddr(), t.options.Inet6GatewayAddr()
|
||||
for _, destination := range routeRanges {
|
||||
var gateway netip.Addr
|
||||
if destination.Addr().Is4() {
|
||||
gateway = gateway4
|
||||
} else {
|
||||
gateway = gateway6
|
||||
}
|
||||
var interfaceIndex int
|
||||
if t.options.InterfaceScope {
|
||||
iff, err := t.options.InterfaceFinder.ByName(t.options.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
interfaceIndex = iff.Index
|
||||
}
|
||||
err = execRoute(unix.RTM_ADD, t.options.InterfaceScope, interfaceIndex, destination, gateway)
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EEXIST) {
|
||||
err = execRoute(unix.RTM_DELETE, false, 0, destination, gateway)
|
||||
if err != nil {
|
||||
return E.Cause(err, "remove existing route: ", destination)
|
||||
}
|
||||
err = execRoute(unix.RTM_ADD, t.options.InterfaceScope, interfaceIndex, destination, gateway)
|
||||
if err != nil {
|
||||
return E.Cause(err, "re-add route: ", destination)
|
||||
}
|
||||
if len(routeRanges) > 0 {
|
||||
gateway4, gateway6 := t.options.Inet4GatewayAddr(), t.options.Inet6GatewayAddr()
|
||||
for _, destination := range routeRanges {
|
||||
var gateway netip.Addr
|
||||
if destination.Addr().Is4() {
|
||||
gateway = gateway4
|
||||
} else {
|
||||
return E.Cause(err, "add route: ", destination)
|
||||
gateway = gateway6
|
||||
}
|
||||
var interfaceIndex int
|
||||
if t.options.InterfaceScope {
|
||||
iff, err := t.options.InterfaceFinder.ByName(t.options.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
interfaceIndex = iff.Index
|
||||
}
|
||||
err = execRoute(unix.RTM_ADD, t.options.InterfaceScope, interfaceIndex, destination, gateway)
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EEXIST) {
|
||||
err = execRoute(unix.RTM_DELETE, false, 0, destination, gateway)
|
||||
if err != nil {
|
||||
return E.Cause(err, "remove existing route: ", destination)
|
||||
}
|
||||
err = execRoute(unix.RTM_ADD, t.options.InterfaceScope, interfaceIndex, destination, gateway)
|
||||
if err != nil {
|
||||
return E.Cause(err, "re-add route: ", destination)
|
||||
}
|
||||
} else {
|
||||
return E.Cause(err, "add route: ", destination)
|
||||
}
|
||||
}
|
||||
}
|
||||
flushDNSCache()
|
||||
t.routeSet = true
|
||||
}
|
||||
flushDNSCache()
|
||||
t.routeSet = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,132 +3,57 @@
|
||||
package tun
|
||||
|
||||
import (
|
||||
"github.com/sagernet/gvisor/pkg/buffer"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/header"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/link/qdisc/fifo"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
"github.com/sagernet/sing-tun/internal/fdbased_darwin"
|
||||
"github.com/sagernet/sing-tun/internal/rawfile_darwin"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var _ GVisorTun = (*NativeTun)(nil)
|
||||
|
||||
func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, error) {
|
||||
return &DarwinEndpoint{tun: t}, nil
|
||||
}
|
||||
|
||||
var _ stack.LinkEndpoint = (*DarwinEndpoint)(nil)
|
||||
|
||||
type DarwinEndpoint struct {
|
||||
tun *NativeTun
|
||||
dispatcher stack.NetworkDispatcher
|
||||
}
|
||||
|
||||
func (e *DarwinEndpoint) MTU() uint32 {
|
||||
return e.tun.options.MTU
|
||||
}
|
||||
|
||||
func (e *DarwinEndpoint) SetMTU(mtu uint32) {
|
||||
}
|
||||
|
||||
func (e *DarwinEndpoint) MaxHeaderLength() uint16 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (e *DarwinEndpoint) LinkAddress() tcpip.LinkAddress {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (e *DarwinEndpoint) SetLinkAddress(addr tcpip.LinkAddress) {
|
||||
}
|
||||
|
||||
func (e *DarwinEndpoint) Capabilities() stack.LinkEndpointCapabilities {
|
||||
return stack.CapabilityRXChecksumOffload
|
||||
}
|
||||
|
||||
func (e *DarwinEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
|
||||
if dispatcher == nil && e.dispatcher != nil {
|
||||
e.dispatcher = nil
|
||||
return
|
||||
func (t *NativeTun) WritePacket(pkt *stack.PacketBuffer) (int, error) {
|
||||
iovecs := t.iovecsOutputDefault
|
||||
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
|
||||
iovecs = append(iovecs, packetHeaderVec4)
|
||||
} else {
|
||||
iovecs = append(iovecs, packetHeaderVec6)
|
||||
}
|
||||
if dispatcher != nil && e.dispatcher == nil {
|
||||
e.dispatcher = dispatcher
|
||||
go e.dispatchLoop()
|
||||
var dataLen int
|
||||
for _, packetSlice := range pkt.AsSlices() {
|
||||
dataLen += len(packetSlice)
|
||||
iovec := unix.Iovec{
|
||||
Base: &packetSlice[0],
|
||||
}
|
||||
iovec.SetLen(len(packetSlice))
|
||||
iovecs = append(iovecs, iovec)
|
||||
}
|
||||
if cap(iovecs) > cap(t.iovecsOutputDefault) {
|
||||
t.iovecsOutputDefault = iovecs[:0]
|
||||
}
|
||||
errno := rawfile.NonBlockingWriteIovec(t.tunFd, iovecs)
|
||||
if errno == 0 {
|
||||
return dataLen, nil
|
||||
} else {
|
||||
return 0, errno
|
||||
}
|
||||
}
|
||||
|
||||
func (e *DarwinEndpoint) dispatchLoop() {
|
||||
packetBuffer := make([]byte, e.tun.options.MTU+PacketOffset)
|
||||
for {
|
||||
n, err := e.tun.tunFile.Read(packetBuffer)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
packet := packetBuffer[PacketOffset:n]
|
||||
var networkProtocol tcpip.NetworkProtocolNumber
|
||||
switch header.IPVersion(packet) {
|
||||
case header.IPv4Version:
|
||||
networkProtocol = header.IPv4ProtocolNumber
|
||||
if header.IPv4(packet).DestinationAddress().As4() == e.tun.inet4Address {
|
||||
e.tun.tunFile.Write(packetBuffer[:n])
|
||||
continue
|
||||
}
|
||||
case header.IPv6Version:
|
||||
networkProtocol = header.IPv6ProtocolNumber
|
||||
if header.IPv6(packet).DestinationAddress().As16() == e.tun.inet6Address {
|
||||
e.tun.tunFile.Write(packetBuffer[:n])
|
||||
continue
|
||||
}
|
||||
default:
|
||||
e.tun.tunFile.Write(packetBuffer[:n])
|
||||
continue
|
||||
}
|
||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(packetBuffer[4:n]),
|
||||
IsForwardedPacket: true,
|
||||
})
|
||||
pkt.NetworkProtocolNumber = networkProtocol
|
||||
dispatcher := e.dispatcher
|
||||
if dispatcher == nil {
|
||||
pkt.DecRef()
|
||||
return
|
||||
}
|
||||
dispatcher.DeliverNetworkPacket(networkProtocol, pkt)
|
||||
pkt.DecRef()
|
||||
func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, stack.NICOptions, error) {
|
||||
ep, err := fdbased.New(&fdbased.Options{
|
||||
FDs: []int{t.tunFd},
|
||||
MTU: t.options.MTU,
|
||||
RXChecksumOffload: true,
|
||||
PacketDispatchMode: fdbased.RecvMMsg,
|
||||
MultiPendingPackets: t.options.EXP_MultiPendingPackets,
|
||||
SendMsgX: t.options.EXP_SendMsgX,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, stack.NICOptions{}, err
|
||||
}
|
||||
}
|
||||
|
||||
func (e *DarwinEndpoint) IsAttached() bool {
|
||||
return e.dispatcher != nil
|
||||
}
|
||||
|
||||
func (e *DarwinEndpoint) Wait() {
|
||||
}
|
||||
|
||||
func (e *DarwinEndpoint) ARPHardwareType() header.ARPHardwareType {
|
||||
return header.ARPHardwareNone
|
||||
}
|
||||
|
||||
func (e *DarwinEndpoint) AddHeader(buffer *stack.PacketBuffer) {
|
||||
}
|
||||
|
||||
func (e *DarwinEndpoint) ParseHeader(ptr *stack.PacketBuffer) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *DarwinEndpoint) WritePackets(packetBufferList stack.PacketBufferList) (int, tcpip.Error) {
|
||||
var n int
|
||||
for _, packet := range packetBufferList.AsSlice() {
|
||||
_, err := bufio.WriteVectorised(e.tun, packet.AsSlices())
|
||||
if err != nil {
|
||||
return n, &tcpip.ErrAborted{}
|
||||
}
|
||||
n++
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (e *DarwinEndpoint) Close() {
|
||||
}
|
||||
|
||||
func (e *DarwinEndpoint) SetOnCloseAction(f func()) {
|
||||
return ep, stack.NICOptions{
|
||||
QDisc: fifo.New(ep, 1, 1000),
|
||||
}, nil
|
||||
}
|
||||
|
||||
133
tun_linux.go
133
tun_linux.go
@@ -18,10 +18,8 @@ import (
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip/header"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
"github.com/sagernet/sing/common/control"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/rw"
|
||||
"github.com/sagernet/sing/common/shell"
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
@@ -32,22 +30,22 @@ import (
|
||||
var _ LinuxTUN = (*NativeTun)(nil)
|
||||
|
||||
type NativeTun struct {
|
||||
tunFd int
|
||||
tunFile *os.File
|
||||
tunWriter N.VectorisedWriter
|
||||
interfaceCallback *list.Element[DefaultInterfaceUpdateCallback]
|
||||
options Options
|
||||
ruleIndex6 []int
|
||||
readAccess sync.Mutex
|
||||
writeAccess sync.Mutex
|
||||
vnetHdr bool
|
||||
writeBuffer []byte
|
||||
gsoToWrite []int
|
||||
tcpGROTable *tcpGROTable
|
||||
udpGroAccess sync.Mutex
|
||||
udpGROTable *udpGROTable
|
||||
gro groDisablementFlags
|
||||
txChecksumOffload bool
|
||||
tunFd int
|
||||
tunFile *os.File
|
||||
iovecsOutputDefault []unix.Iovec
|
||||
interfaceCallback *list.Element[DefaultInterfaceUpdateCallback]
|
||||
options Options
|
||||
ruleIndex6 []int
|
||||
readAccess sync.Mutex
|
||||
writeAccess sync.Mutex
|
||||
vnetHdr bool
|
||||
writeBuffer []byte
|
||||
gsoToWrite []int
|
||||
tcpGROTable *tcpGROTable
|
||||
udpGroAccess sync.Mutex
|
||||
udpGROTable *udpGROTable
|
||||
gro groDisablementFlags
|
||||
txChecksumOffload bool
|
||||
}
|
||||
|
||||
func New(options Options) (Tun, error) {
|
||||
@@ -77,11 +75,6 @@ func New(options Options) (Tun, error) {
|
||||
options: options,
|
||||
}
|
||||
}
|
||||
var ok bool
|
||||
nativeTun.tunWriter, ok = bufio.CreateVectorisedWriter(nativeTun.tunFile)
|
||||
if !ok {
|
||||
panic("create vectorised writer")
|
||||
}
|
||||
return nativeTun, nil
|
||||
}
|
||||
|
||||
@@ -137,7 +130,7 @@ func (t *NativeTun) configure(tunLink netlink.Link) error {
|
||||
for _, address := range t.options.Inet4Address {
|
||||
addr4, _ := netlink.ParseAddr(address.String())
|
||||
err = netlink.AddrAdd(tunLink, addr4)
|
||||
if err != nil {
|
||||
if err != nil && !errors.Is(err, unix.EEXIST) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -146,7 +139,7 @@ func (t *NativeTun) configure(tunLink netlink.Link) error {
|
||||
for _, address := range t.options.Inet6Address {
|
||||
addr6, _ := netlink.ParseAddr(address.String())
|
||||
err = netlink.AddrAdd(tunLink, addr6)
|
||||
if err != nil {
|
||||
if err != nil && !errors.Is(err, unix.EEXIST) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -202,7 +195,6 @@ func (t *NativeTun) enableGSO() error {
|
||||
err = setUDPOffload(t.tunFd)
|
||||
if err != nil {
|
||||
t.gro.disableUDPGRO()
|
||||
return E.Cause(err, "enable UDP offload")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -265,7 +257,7 @@ func (t *NativeTun) Start() error {
|
||||
if t.options.FileDescriptor != 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
t.options.InterfaceMonitor.RegisterMyInterface(t.options.Name)
|
||||
tunLink, err := netlink.LinkByName(t.options.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -323,6 +315,7 @@ func (t *NativeTun) Close() error {
|
||||
if t.interfaceCallback != nil {
|
||||
t.options.InterfaceMonitor.UnregisterCallback(t.interfaceCallback)
|
||||
}
|
||||
t.unsetAddresses()
|
||||
return E.Errors(t.unsetRoute(), t.unsetRules(), common.Close(common.PtrOrNil(t.tunFile)))
|
||||
}
|
||||
|
||||
@@ -403,20 +396,6 @@ func (t *NativeTun) Write(p []byte) (n int, err error) {
|
||||
return t.tunFile.Write(p)
|
||||
}
|
||||
|
||||
func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
|
||||
if t.vnetHdr {
|
||||
n := buf.LenMulti(buffers)
|
||||
buffer := buf.NewSize(virtioNetHdrLen + n)
|
||||
buffer.Truncate(virtioNetHdrLen)
|
||||
buf.CopyMulti(buffer.Extend(n), buffers)
|
||||
_, err := t.tunFile.Write(buffer.Bytes())
|
||||
buffer.Release()
|
||||
return err
|
||||
} else {
|
||||
return t.tunWriter.WriteVectorised(buffers)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *NativeTun) FrontHeadroom() int {
|
||||
if t.vnetHdr {
|
||||
return virtioNetHdrLen
|
||||
@@ -638,6 +617,23 @@ func (t *NativeTun) rules() []*netlink.Rule {
|
||||
it.Family = unix.AF_INET6
|
||||
rules = append(rules, it)
|
||||
}
|
||||
// Fallback rules after system default rules (32766: main, 32767: default)
|
||||
// Only reached when main and default tables have no route
|
||||
const fallbackPriority = 32768
|
||||
if p4 {
|
||||
it = netlink.NewRule()
|
||||
it.Priority = fallbackPriority
|
||||
it.Table = t.options.IPRoute2TableIndex
|
||||
it.Family = unix.AF_INET
|
||||
rules = append(rules, it)
|
||||
}
|
||||
if p6 {
|
||||
it = netlink.NewRule()
|
||||
it.Priority = fallbackPriority
|
||||
it.Table = t.options.IPRoute2TableIndex
|
||||
it.Family = unix.AF_INET6
|
||||
rules = append(rules, it)
|
||||
}
|
||||
return rules
|
||||
}
|
||||
|
||||
@@ -669,7 +665,7 @@ func (t *NativeTun) rules() []*netlink.Rule {
|
||||
}
|
||||
}
|
||||
if len(t.options.IncludeInterface) > 0 {
|
||||
matchPriority := priority + 2*len(t.options.IncludeInterface) + 1
|
||||
matchPriority := priority + 2
|
||||
for _, includeInterface := range t.options.IncludeInterface {
|
||||
if p4 {
|
||||
it = netlink.NewRule()
|
||||
@@ -678,7 +674,6 @@ func (t *NativeTun) rules() []*netlink.Rule {
|
||||
it.Goto = matchPriority
|
||||
it.Family = unix.AF_INET
|
||||
rules = append(rules, it)
|
||||
priority++
|
||||
}
|
||||
if p6 {
|
||||
it = netlink.NewRule()
|
||||
@@ -687,9 +682,14 @@ func (t *NativeTun) rules() []*netlink.Rule {
|
||||
it.Goto = matchPriority
|
||||
it.Family = unix.AF_INET6
|
||||
rules = append(rules, it)
|
||||
priority6++
|
||||
}
|
||||
}
|
||||
if p4 {
|
||||
priority++
|
||||
}
|
||||
if p6 {
|
||||
priority6++
|
||||
}
|
||||
if p4 {
|
||||
it = netlink.NewRule()
|
||||
it.Priority = priority
|
||||
@@ -727,7 +727,6 @@ func (t *NativeTun) rules() []*netlink.Rule {
|
||||
it.Goto = nopPriority
|
||||
it.Family = unix.AF_INET
|
||||
rules = append(rules, it)
|
||||
priority++
|
||||
}
|
||||
if p6 {
|
||||
it = netlink.NewRule()
|
||||
@@ -736,9 +735,15 @@ func (t *NativeTun) rules() []*netlink.Rule {
|
||||
it.Goto = nopPriority
|
||||
it.Family = unix.AF_INET6
|
||||
rules = append(rules, it)
|
||||
priority6++
|
||||
}
|
||||
}
|
||||
|
||||
if p4 {
|
||||
priority++
|
||||
}
|
||||
if p6 {
|
||||
priority6++
|
||||
}
|
||||
}
|
||||
|
||||
if runtime.GOOS == "android" && t.options.InterfaceMonitor.AndroidVPNEnabled() {
|
||||
@@ -829,14 +834,6 @@ func (t *NativeTun) rules() []*netlink.Rule {
|
||||
it.Family = unix.AF_INET
|
||||
rules = append(rules, it)
|
||||
}
|
||||
if p4 && !t.options.StrictRoute {
|
||||
it = netlink.NewRule()
|
||||
it.Priority = priority
|
||||
it.IPProto = syscall.IPPROTO_ICMP
|
||||
it.Goto = nopPriority
|
||||
it.Family = unix.AF_INET
|
||||
rules = append(rules, it)
|
||||
}
|
||||
if p6 {
|
||||
it = netlink.NewRule()
|
||||
it.Priority = priority6
|
||||
@@ -847,16 +844,6 @@ func (t *NativeTun) rules() []*netlink.Rule {
|
||||
it.Family = unix.AF_INET6
|
||||
rules = append(rules, it)
|
||||
}
|
||||
|
||||
if p6 && !t.options.StrictRoute {
|
||||
it = netlink.NewRule()
|
||||
it.Priority = priority6
|
||||
it.IPProto = syscall.IPPROTO_ICMPV6
|
||||
it.Goto = nopPriority
|
||||
it.Family = unix.AF_INET6
|
||||
rules = append(rules, it)
|
||||
priority6++
|
||||
}
|
||||
}
|
||||
if p4 {
|
||||
it = netlink.NewRule()
|
||||
@@ -1034,6 +1021,24 @@ func (t *NativeTun) unsetRules() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *NativeTun) unsetAddresses() {
|
||||
if t.options.FileDescriptor > 0 {
|
||||
return
|
||||
}
|
||||
tunLink, err := netlink.LinkByName(t.options.Name)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, address := range t.options.Inet4Address {
|
||||
addr, _ := netlink.ParseAddr(address.String())
|
||||
_ = netlink.AddrDel(tunLink, addr)
|
||||
}
|
||||
for _, address := range t.options.Inet6Address {
|
||||
addr, _ := netlink.ParseAddr(address.String())
|
||||
_ = netlink.AddrDel(tunLink, addr)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *NativeTun) resetRules() error {
|
||||
t.unsetRules()
|
||||
return t.setRules()
|
||||
|
||||
@@ -3,15 +3,44 @@
|
||||
package tun
|
||||
|
||||
import (
|
||||
"github.com/sagernet/gvisor/pkg/rawfile"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/link/fdbased"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func init() {
|
||||
fdbased.BufConfig = []int{65535}
|
||||
}
|
||||
|
||||
var _ GVisorTun = (*NativeTun)(nil)
|
||||
|
||||
func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, error) {
|
||||
func (t *NativeTun) WritePacket(pkt *stack.PacketBuffer) (int, error) {
|
||||
iovecs := t.iovecsOutputDefault
|
||||
var dataLen int
|
||||
for _, packetSlice := range pkt.AsSlices() {
|
||||
dataLen += len(packetSlice)
|
||||
iovec := unix.Iovec{
|
||||
Base: &packetSlice[0],
|
||||
}
|
||||
iovec.SetLen(len(packetSlice))
|
||||
iovecs = append(iovecs, iovec)
|
||||
}
|
||||
if cap(iovecs) > cap(t.iovecsOutputDefault) {
|
||||
t.iovecsOutputDefault = iovecs[:0]
|
||||
}
|
||||
errno := rawfile.NonBlockingWriteIovec(t.tunFd, iovecs)
|
||||
if errno == 0 {
|
||||
return dataLen, nil
|
||||
} else {
|
||||
return 0, errno
|
||||
}
|
||||
}
|
||||
|
||||
func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, stack.NICOptions, error) {
|
||||
if t.vnetHdr {
|
||||
return fdbased.New(&fdbased.Options{
|
||||
ep, err := fdbased.New(&fdbased.Options{
|
||||
FDs: []int{t.tunFd},
|
||||
MTU: t.options.MTU,
|
||||
GSOMaxSize: gsoMaxSize,
|
||||
@@ -19,11 +48,20 @@ func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, error) {
|
||||
RXChecksumOffload: true,
|
||||
TXChecksumOffload: t.txChecksumOffload,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, stack.NICOptions{}, err
|
||||
}
|
||||
return ep, stack.NICOptions{}, nil
|
||||
} else {
|
||||
ep, err := fdbased.New(&fdbased.Options{
|
||||
FDs: []int{t.tunFd},
|
||||
MTU: t.options.MTU,
|
||||
RXChecksumOffload: true,
|
||||
TXChecksumOffload: t.txChecksumOffload,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, stack.NICOptions{}, err
|
||||
}
|
||||
return ep, stack.NICOptions{}, nil
|
||||
}
|
||||
return fdbased.New(&fdbased.Options{
|
||||
FDs: []int{t.tunFd},
|
||||
MTU: t.options.MTU,
|
||||
RXChecksumOffload: true,
|
||||
TXChecksumOffload: t.txChecksumOffload,
|
||||
})
|
||||
}
|
||||
|
||||
66
tun_rules.go
66
tun_rules.go
@@ -108,7 +108,7 @@ const autoRouteUseSubRanges = runtime.GOOS == "darwin"
|
||||
|
||||
func (o *Options) BuildAutoRouteRanges(underNetworkExtension bool) ([]netip.Prefix, error) {
|
||||
var routeRanges []netip.Prefix
|
||||
if o.AutoRoute && len(o.Inet4Address) > 0 {
|
||||
if len(o.Inet4Address) > 0 {
|
||||
var inet4Ranges []netip.Prefix
|
||||
if len(o.Inet4RouteAddress) > 0 {
|
||||
inet4Ranges = o.Inet4RouteAddress
|
||||
@@ -119,19 +119,27 @@ func (o *Options) BuildAutoRouteRanges(underNetworkExtension bool) ([]netip.Pref
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if autoRouteUseSubRanges && !underNetworkExtension {
|
||||
inet4Ranges = []netip.Prefix{
|
||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 1}), 8),
|
||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 2}), 7),
|
||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 4}), 6),
|
||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 8}), 5),
|
||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 16}), 4),
|
||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 32}), 3),
|
||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 64}), 2),
|
||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 128}), 1),
|
||||
} else if o.AutoRoute {
|
||||
if autoRouteUseSubRanges && !underNetworkExtension {
|
||||
inet4Ranges = []netip.Prefix{
|
||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 1}), 8),
|
||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 2}), 7),
|
||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 4}), 6),
|
||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 8}), 5),
|
||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 16}), 4),
|
||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 32}), 3),
|
||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 64}), 2),
|
||||
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 128}), 1),
|
||||
}
|
||||
} else {
|
||||
inet4Ranges = []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||
}
|
||||
} else if runtime.GOOS == "darwin" {
|
||||
for _, address := range o.Inet4Address {
|
||||
if address.Bits() < 32 {
|
||||
inet4Ranges = append(inet4Ranges, address.Masked())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
inet4Ranges = []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
|
||||
}
|
||||
if len(o.Inet4RouteExcludeAddress) == 0 {
|
||||
routeRanges = append(routeRanges, inet4Ranges...)
|
||||
@@ -161,19 +169,27 @@ func (o *Options) BuildAutoRouteRanges(underNetworkExtension bool) ([]netip.Pref
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if autoRouteUseSubRanges && !underNetworkExtension {
|
||||
inet6Ranges = []netip.Prefix{
|
||||
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 1}), 8),
|
||||
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 2}), 7),
|
||||
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 4}), 6),
|
||||
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 8}), 5),
|
||||
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 16}), 4),
|
||||
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 32}), 3),
|
||||
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 64}), 2),
|
||||
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 128}), 1),
|
||||
} else if o.AutoRoute {
|
||||
if autoRouteUseSubRanges && !underNetworkExtension {
|
||||
inet6Ranges = []netip.Prefix{
|
||||
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 1}), 8),
|
||||
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 2}), 7),
|
||||
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 4}), 6),
|
||||
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 8}), 5),
|
||||
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 16}), 4),
|
||||
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 32}), 3),
|
||||
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 64}), 2),
|
||||
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 128}), 1),
|
||||
}
|
||||
} else {
|
||||
inet6Ranges = []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
||||
}
|
||||
} else if runtime.GOOS == "darwin" {
|
||||
for _, address := range o.Inet6Address {
|
||||
if address.Bits() < 32 {
|
||||
inet6Ranges = append(inet6Ranges, address.Masked())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
inet6Ranges = []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
|
||||
}
|
||||
if len(o.Inet6RouteExcludeAddress) == 0 {
|
||||
routeRanges = append(routeRanges, inet6Ranges...)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
@@ -16,8 +17,6 @@ import (
|
||||
"github.com/sagernet/sing-tun/internal/winsys"
|
||||
"github.com/sagernet/sing-tun/internal/wintun"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/windnsapi"
|
||||
|
||||
@@ -163,6 +162,7 @@ func (t *NativeTun) Name() (string, error) {
|
||||
}
|
||||
|
||||
func (t *NativeTun) Start() error {
|
||||
t.options.InterfaceMonitor.RegisterMyInterface(t.options.Name)
|
||||
if !t.options.AutoRoute {
|
||||
return nil
|
||||
}
|
||||
@@ -181,6 +181,13 @@ func (t *NativeTun) Start() error {
|
||||
return err
|
||||
}
|
||||
if t.options.StrictRoute {
|
||||
major, _, _ := windows.RtlGetNtVersionNumbers()
|
||||
if major < 10 {
|
||||
if t.options.Logger != nil {
|
||||
t.options.Logger.Warn("strict routing is not supported on Windows versions below 10")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
var engine uintptr
|
||||
session := &winsys.FWPM_SESSION0{Flags: winsys.FWPM_SESSION_FLAG_DYNAMIC}
|
||||
err := winsys.FwpmEngineOpen0(nil, winsys.RPC_C_AUTHN_DEFAULT, nil, session, unsafe.Pointer(&engine))
|
||||
@@ -395,15 +402,16 @@ retry:
|
||||
|
||||
func (t *NativeTun) ReadPacket() ([]byte, func(), error) {
|
||||
t.running.Add(1)
|
||||
defer t.running.Done()
|
||||
retry:
|
||||
if t.close.Load() == 1 {
|
||||
t.running.Done()
|
||||
return nil, nil, os.ErrClosed
|
||||
}
|
||||
start := nanotime()
|
||||
shouldSpin := t.rate.current.Load() >= spinloopRateThreshold && uint64(start-t.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
|
||||
for {
|
||||
if t.close.Load() == 1 {
|
||||
t.running.Done()
|
||||
return nil, nil, os.ErrClosed
|
||||
}
|
||||
packet, err := t.session.ReceivePacket()
|
||||
@@ -411,7 +419,10 @@ retry:
|
||||
case nil:
|
||||
packetSize := len(packet)
|
||||
t.rate.update(uint64(packetSize))
|
||||
return packet, func() { t.session.ReleaseReceivePacket(packet) }, nil
|
||||
return packet, func() {
|
||||
t.session.ReleaseReceivePacket(packet)
|
||||
t.running.Done()
|
||||
}, nil
|
||||
case windows.ERROR_NO_MORE_ITEMS:
|
||||
if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
|
||||
windows.WaitForSingleObject(t.readWait, windows.INFINITE)
|
||||
@@ -420,10 +431,13 @@ retry:
|
||||
procyield(1)
|
||||
continue
|
||||
case windows.ERROR_HANDLE_EOF:
|
||||
t.running.Done()
|
||||
return nil, nil, os.ErrClosed
|
||||
case windows.ERROR_INVALID_DATA:
|
||||
t.running.Done()
|
||||
return nil, nil, errors.New("send ring corrupt")
|
||||
}
|
||||
t.running.Done()
|
||||
return nil, nil, fmt.Errorf("read failed: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -516,11 +530,6 @@ func (t *NativeTun) write(packetElementList [][]byte) (n int, err error) {
|
||||
return 0, fmt.Errorf("write failed: %w", err)
|
||||
}
|
||||
|
||||
func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
|
||||
defer buf.ReleaseMulti(buffers)
|
||||
return common.Error(t.write(buf.ToSliceMulti(buffers)))
|
||||
}
|
||||
|
||||
func (t *NativeTun) Close() error {
|
||||
var err error
|
||||
t.closeOnce.Do(func() {
|
||||
|
||||
@@ -11,8 +11,12 @@ import (
|
||||
|
||||
var _ GVisorTun = (*NativeTun)(nil)
|
||||
|
||||
func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, error) {
|
||||
return &WintunEndpoint{tun: t}, nil
|
||||
func (t *NativeTun) WritePacket(pkt *stack.PacketBuffer) (int, error) {
|
||||
return t.write(pkt.AsSlices())
|
||||
}
|
||||
|
||||
func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, stack.NICOptions, error) {
|
||||
return &WintunEndpoint{tun: t}, stack.NICOptions{}, nil
|
||||
}
|
||||
|
||||
var _ stack.LinkEndpoint = (*WintunEndpoint)(nil)
|
||||
|
||||
Reference in New Issue
Block a user