Compare commits

..

81 Commits

Author SHA1 Message Date
世界
4efde6372e Do not submit EventNoRoute repeatedly 2024-06-06 22:29:41 +08:00
世界
1c6d2891ab Fix "Fix darwin routes" 2024-06-06 22:29:38 +08:00
世界
5d9bd04495 Update dependencies 2024-06-06 22:29:34 +08:00
世界
5bf54dc69a Fix netlink 2024-05-23 14:54:56 +08:00
世界
840f3758f9 Prioritize *_route_address in linux auto-route 2024-05-20 22:19:46 +08:00
XYenon
d923e5d10a Fix darwin routes 2024-05-20 19:07:24 +08:00
世界
779d1c7db2 Fix linux auto-route sequence 2024-05-15 22:22:08 +08:00
世界
3f128a4a6a Fix Remove bad suppress_prefixlength iproute2 rule 2024-05-10 17:50:56 +08:00
世界
fb6e917a2c Add StackOptions.IncludeAllNetworks 2024-05-07 20:20:44 +08:00
世界
e272ff0ad3 Remove bad suppress_prefixlength iproute2 rule
This change gives tun priority over DHCP 121 rules
2024-05-07 19:52:59 +08:00
世界
5584917e52 Update gVisor to 20240422.0 2024-05-07 19:52:59 +08:00
世界
e0ddbbb84f Update gVisor to 20240212.0-65-g71212d503 2024-05-07 19:52:59 +08:00
世界
a9895a7d88 Update gVisor to 20240206.0 2024-05-07 19:52:59 +08:00
世界
9380493c39 Fix bad usage for exec 2024-05-07 19:52:44 +08:00
世界
63f6630a0a Fix darwin monitor 2024-05-03 15:33:14 +08:00
世界
d174625727 Update dependencies 2024-04-06 22:39:43 +08:00
世界
520d1bc9bb Remove dependency on comshim 2024-04-06 22:38:48 +08:00
wwqgtxx
fc63ec9388 avoid netlink dos networkUpdateMonitor 2024-04-06 22:23:48 +08:00
世界
cddf60537d Fix timer usage for monitor check update 2024-04-02 22:53:57 +08:00
世界
8bfb64cf04 Fix GSO batch size 2024-03-22 14:52:39 +08:00
世界
689e60891c Fix darwin monitor 2024-03-14 13:37:55 +08:00
世界
6ef2a6cdaa Fix deadlock on network update 2024-02-26 22:52:11 +08:00
世界
8d285f70fb Fix darwin monitor 2024-02-26 13:22:23 +08:00
世界
e8633c66d2 Update .gitignore 2024-02-26 13:22:15 +08:00
世界
951af3ca7a Update dependencies 2024-02-26 13:21:30 +08:00
世界
9b7c2a0a3c Fix unaligned panic on windows 2024-02-10 21:17:30 +08:00
世界
38c945fec5 Remove duplicated rules 2024-02-02 14:29:06 +08:00
世界
a276461b88 Update dependencies 2024-01-07 16:43:45 +08:00
世界
ebb3908ecf Fix bind forwarder to interface for systems stack 2023-12-21 16:51:25 +08:00
世界
5b50c61b72 Add GSO support 2023-12-21 16:51:06 +08:00
世界
fa89d2c0a5 Update gVisor to 20231204.0 2023-12-12 14:09:03 +08:00
世界
cad35277a2 Update unbind packet conn usage 2023-12-12 14:09:03 +08:00
世界
69c3b72eec FIx error handling for netlink banned in Android 2023-12-12 13:52:55 +08:00
世界
62f2d98190 Fix auto-route IPv6 on darwin 2023-12-11 21:37:40 +08:00
世界
d275b4a0fd Update dependencies 2023-12-11 21:37:40 +08:00
世界
46adeb9b5d Update dependencies 2023-12-04 21:00:08 +08:00
世界
f6ea97c5af Update gVisor to 20231113.0 2023-11-19 11:55:13 +08:00
世界
3fa4ee409a Add route exclude support 2023-11-16 18:27:36 +08:00
世界
958d6a25a4 Fix "Fix Linux IPv6 auto route rules" 2023-11-16 18:21:14 +08:00
世界
86322a3fe1 Update dependency 2023-11-14 19:28:42 +08:00
世界
78e0dfa18f Fix broadcast filter not applied to mixed stack 2023-11-14 17:16:24 +08:00
世界
ce9c864d89 system: Check UDP invalid packet 2023-11-14 17:16:24 +08:00
世界
da350ecc72 Add broadcast filter 2023-11-06 21:33:35 +08:00
世界
1a00992d06 Update dependencies 2023-11-05 16:08:15 +08:00
世界
150b116231 Add multicast filter 2023-11-04 08:01:41 +08:00
世界
b93db9639d Remove network requirement on start 2023-11-03 10:59:51 +08:00
i40e
efd9884154 Disable Windows DNS registration 2023-10-26 14:08:25 +08:00
世界
1a85bd3ef4 Add DefaultInterface func 2023-10-26 14:08:21 +08:00
世界
56a9b85cf5 Update gVisor to 20230814.0 2023-10-26 14:08:19 +08:00
世界
660222a0dd Fix Linux IPv6 auto route rules 2023-10-26 11:59:14 +08:00
世界
fee2614ae3 Update dependencies 2023-10-26 11:58:57 +08:00
世界
2b625a47c0 Update dependencies 2023-10-06 17:07:20 +08:00
世界
dcf7d50379 Fix gVisor UDP 6to4 check 2023-10-06 17:07:04 +08:00
世界
4979f75513 Update dependencies 2023-09-30 22:33:39 +08:00
世界
2a0a0ab228 android: Fix netlink check 2023-09-26 17:39:31 +08:00
世界
8adce0ea02 android: Check netlink available on monitor create 2023-09-25 17:15:15 +08:00
世界
b6d323004e Remove use of Write Unreachable as SendRejectionError panics when passing invalid packet 2023-09-22 11:50:04 +08:00
世界
e212724bac Update dependencies 2023-09-20 22:18:32 +08:00
世界
9c933ea553 Remove defaultInterfaceName when no route 2023-09-20 14:08:16 +08:00
世界
7545dc2d56 Fix darwin monitor socket leak 2023-08-21 14:55:22 +08:00
世界
db70908d61 Update dependencies 2023-08-20 17:19:22 +08:00
世界
824b903ebd Add [include/exclude]_interface iproute2 options 2023-08-20 17:18:29 +08:00
世界
10d98f2679 Add mixed stack 2023-08-12 19:38:06 +08:00
世界
aa8760b454 Add handshake interface support for gVisor UDP 2023-08-07 20:32:32 +08:00
世界
0a68b9f1d8 Fix monitor 2023-08-07 20:31:52 +08:00
世界
59b86002c4 Add no route event 2023-08-07 19:51:07 +08:00
世界
688d4da4b7 Fix gVisor UDP 2023-07-25 08:12:56 +08:00
世界
28db424ae8 Update dependencies 2023-07-23 14:16:08 +08:00
世界
bbf542f01a Improve gVisor UDP 2023-07-23 14:01:36 +08:00
世界
fd850d00e5 Fix buffer usage 2023-07-03 21:44:24 +08:00
世界
3b558f113c Update gVisor to 20230621.0 2023-06-27 11:12:05 +08:00
世界
d51abeb6c7 Prevent panic when write packet with bad address type 2023-06-21 13:27:17 +08:00
世界
323b9564f0 Update dependencies 2023-06-17 12:11:41 +08:00
世界
4bc8dc7f27 Configure default-route for systemd-resolved
Even though the documentation says this parameter doesn't matter, some people have reported that not configuring it can cause problems.
2023-06-12 19:28:29 +08:00
世界
41b2639e13 Update gVisor to 20230605.0-33-g8ec8dbe7e 2023-06-11 22:06:33 +08:00
世界
605266e65e Update gci usage 2023-06-10 08:50:20 +08:00
世界
e881f21013 Update gVisor to release-20230605.0-21-g457c1c36d 2023-06-10 08:45:55 +08:00
dyhkwong
b02f252916 Use api to create windows firewall rules 2023-05-20 12:11:00 +08:00
世界
91df97aee2 Fix macos monitor 2023-05-09 18:20:26 +08:00
世界
6999634511 Fix windows firewall for system stack 2023-05-09 12:12:00 +08:00
世界
209ec123ca Update gVisor to 20230417.0 2023-04-22 20:14:32 +08:00
50 changed files with 2954 additions and 1476 deletions

View File

@@ -1,6 +1,5 @@
#!/usr/bin/env bash
PROJECTS=$(dirname "$0")/../..
go get -x github.com/sagernet/sing@$(git -C $PROJECTS/sing rev-parse HEAD)
go get -x github.com/sagernet/$1@$(git -C $PROJECTS/$1 rev-parse HEAD)
go mod tidy

1
.gitignore vendored
View File

@@ -1,2 +1,3 @@
/.idea/
/vendor/
.DS_Store

View File

@@ -1,7 +1,16 @@
build:
GOOS=darwin GOARCH=arm64 go build -v -tags with_gvisor .
GOOS=ios GOARCH=arm64 go build -v -tags with_gvisor .
GOOS=linux GOARCH=amd64 go build -v -tags with_gvisor .
GOOS=linux GOARCH=arm64 go build -v -tags with_gvisor .
GOOS=linux GOARCH=386 go build -v -tags with_gvisor .
GOOS=linux GOARCH=arm go build -v -tags with_gvisor .
GOOS=windows GOARCH=amd64 go build -v -tags with_gvisor .
fmt:
@gofumpt -l -w .
@gofmt -s -w .
@gci write --custom-order -s "standard,prefix(github.com/sagernet/),default" .
@gci write --custom-order -s standard -s "prefix(github.com/sagernet/)" -s "default" .
fmt_install:
go install -v mvdan.cc/gofumpt@latest

19
go.mod
View File

@@ -3,17 +3,18 @@ module github.com/sagernet/sing-tun
go 1.18
require (
github.com/fsnotify/fsnotify v1.6.0
github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61
github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97
github.com/sagernet/sing v0.2.4
golang.org/x/net v0.9.0
golang.org/x/sys v0.7.0
gvisor.dev/gvisor v0.0.0-20220901235040-6ca97ef2ce1c
github.com/fsnotify/fsnotify v1.7.0
github.com/go-ole/go-ole v1.3.0
github.com/sagernet/gvisor v0.0.0-20240428053021-e691de28565f
github.com/sagernet/netlink v0.0.0-20240523065131-45e60152f9ba
github.com/sagernet/sing v0.4.1
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
golang.org/x/net v0.26.0
golang.org/x/sys v0.21.0
)
require (
github.com/google/btree v1.0.1 // indirect
github.com/google/btree v1.1.2 // indirect
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect
golang.org/x/time v0.5.0 // indirect
)

46
go.sum
View File

@@ -1,24 +1,28 @@
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61 h1:5+m7c6AkmAylhauulqN/c5dnh8/KssrE9c93TQrXldA=
github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61/go.mod h1:QUQ4RRHD6hGGHdFMEtR8T2P6GS6R3D/CXKdaYHKKXms=
github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97 h1:iL5gZI3uFp0X6EslacyapiRz7LLSJyr4RajF/BhMVyE=
github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM=
github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY=
github.com/sagernet/sing v0.2.4 h1:gC8BR5sglbJZX23RtMyFa8EETP9YEUADhfbEzU1yVbo=
github.com/sagernet/sing v0.2.4/go.mod h1:Ta8nHnDLAwqySzKhGoKk4ZIB+vJ3GTKj7UPrWYvM+4w=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78=
github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/sagernet/gvisor v0.0.0-20240428053021-e691de28565f h1:NkhuupzH5ch7b/Y/6ZHJWrnNLoiNnSJaow6DPb8VW2I=
github.com/sagernet/gvisor v0.0.0-20240428053021-e691de28565f/go.mod h1:KXmw+ouSJNOsuRpg4wgwwCQuunrGz4yoAqQjsLjc6N0=
github.com/sagernet/netlink v0.0.0-20240523065131-45e60152f9ba h1:EY5AS7CCtfmARNv2zXUOrsEMPFDGYxaw65JzA2p51Vk=
github.com/sagernet/netlink v0.0.0-20240523065131-45e60152f9ba/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM=
github.com/sagernet/sing v0.4.1 h1:zVlpE+7k7AFoC2pv6ReqLf0PIHjihL/jsBl5k05PQFk=
github.com/sagernet/sing v0.4.1/go.mod h1:ieZHA/+Y9YZfXs2I3WtuwgyCZ6GPsIR7HdKb1SdEnls=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 h1:gga7acRE695APm9hlsSMoOoE65U4/TcqNj90mc69Rlg=
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM=
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M=
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y=
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU=
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
gvisor.dev/gvisor v0.0.0-20220901235040-6ca97ef2ce1c h1:m5lcgWnL3OElQNVyp3qcncItJ2c0sQlSGjYK2+nJTA4=
gvisor.dev/gvisor v0.0.0-20220901235040-6ca97ef2ce1c/go.mod h1:TIvkJD0sxe8pIob3p6T8IzxXunlp6yfgktvTNp+DGNM=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

View File

@@ -1,119 +0,0 @@
//go:build with_gvisor
package tun
import (
"context"
"math"
"net"
"net/netip"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/udpnat"
"gvisor.dev/gvisor/pkg/bufferv2"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
type UDPForwarder struct {
ctx context.Context
stack *stack.Stack
udpNat *udpnat.Service[netip.AddrPort]
}
func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, udpTimeout int64) *UDPForwarder {
return &UDPForwarder{
ctx: ctx,
stack: stack,
udpNat: udpnat.New[netip.AddrPort](udpTimeout, handler),
}
}
func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
var upstreamMetadata M.Metadata
upstreamMetadata.Source = M.SocksaddrFrom(M.AddrFromIP(net.IP(id.RemoteAddress)), id.RemotePort)
upstreamMetadata.Destination = M.SocksaddrFrom(M.AddrFromIP(net.IP(id.LocalAddress)), id.LocalPort)
var netProto tcpip.NetworkProtocolNumber
if upstreamMetadata.Source.IsIPv4() {
netProto = header.IPv4ProtocolNumber
} else {
netProto = header.IPv6ProtocolNumber
}
f.udpNat.NewPacket(
f.ctx,
upstreamMetadata.Source.AddrPort(),
buf.As(pkt.Data().AsRange().ToSlice()),
upstreamMetadata,
func(natConn N.PacketConn) N.PacketWriter {
return &UDPBackWriter{f.stack, id.RemoteAddress, id.RemotePort, netProto}
},
)
return true
}
type UDPBackWriter struct {
stack *stack.Stack
source tcpip.Address
sourcePort uint16
sourceNetwork tcpip.NetworkProtocolNumber
}
func (w *UDPBackWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
defer buffer.Release()
route, err := w.stack.FindRoute(
defaultNIC,
tcpip.Address(destination.Addr.AsSlice()),
w.source,
w.sourceNetwork,
false,
)
if err != nil {
return wrapStackError(err)
}
defer route.Release()
packet := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: header.UDPMinimumSize + int(route.MaxHeaderLength()),
Payload: bufferv2.MakeWithData(buffer.Bytes()),
})
defer packet.DecRef()
packet.TransportProtocolNumber = header.UDPProtocolNumber
udpHdr := header.UDP(packet.TransportHeader().Push(header.UDPMinimumSize))
pLen := uint16(packet.Size())
udpHdr.Encode(&header.UDPFields{
SrcPort: destination.Port,
DstPort: w.sourcePort,
Length: pLen,
})
if route.RequiresTXTransportChecksum() && w.sourceNetwork == header.IPv6ProtocolNumber {
xsum := udpHdr.CalculateChecksum(header.ChecksumCombine(
route.PseudoHeaderChecksum(header.UDPProtocolNumber, pLen),
packet.Data().AsRange().Checksum(),
))
if xsum != math.MaxUint16 {
xsum = ^xsum
}
udpHdr.SetChecksum(xsum)
}
err = route.WritePacket(stack.NetworkHeaderParams{
Protocol: header.UDPProtocolNumber,
TTL: route.DefaultTTL(),
TOS: 0,
}, packet)
if err != nil {
route.Stats().UDP.PacketSendErrors.Increment()
return wrapStackError(err)
}
route.Stats().UDP.PacketsSent.Increment()
return nil
}

View File

@@ -50,6 +50,10 @@ func (p TCPPacket) SetChecksum(sum [2]byte) {
p[17] = sum[1]
}
func (p TCPPacket) OffloadChecksum() {
p.SetChecksum(zeroChecksum)
}
func (p TCPPacket) ResetChecksum(psum uint32) {
p.SetChecksum(zeroChecksum)
p.SetChecksum(Checksum(psum, p))

View File

@@ -45,6 +45,10 @@ func (p UDPPacket) SetChecksum(sum [2]byte) {
p[7] = sum[1]
}
func (p UDPPacket) OffloadChecksum() {
p.SetChecksum(zeroChecksum)
}
func (p UDPPacket) ResetChecksum(psum uint32) {
p.SetChecksum(zeroChecksum)
p.SetChecksum(Checksum(psum, p))

276
internal/winfw/winfw.go Normal file
View File

@@ -0,0 +1,276 @@
// Copyright (c) 2018 Samuel Melrose
// SPDX-License-Identifier: MIT
// https://github.com/iamacarpet/go-win64api/blob/ef6dbdd6db97301ae08a55eedea773476985a602/firewall.go
//go:build windows
package winfw
import (
"fmt"
"runtime"
"github.com/go-ole/go-ole"
"github.com/go-ole/go-ole/oleutil"
)
// Firewall related API constants.
const (
NET_FW_IP_PROTOCOL_TCP = 6
NET_FW_IP_PROTOCOL_UDP = 17
NET_FW_IP_PROTOCOL_ICMPv4 = 1
NET_FW_IP_PROTOCOL_ICMPv6 = 58
NET_FW_IP_PROTOCOL_ANY = 256
NET_FW_RULE_DIR_IN = 1
NET_FW_RULE_DIR_OUT = 2
NET_FW_ACTION_BLOCK = 0
NET_FW_ACTION_ALLOW = 1
// NET_FW_PROFILE2_CURRENT is not real API constant, just helper used in FW functions.
// It can mean one profile or multiple (even all) profiles. It depends on which profiles
// are currently in use. Every active interface can have it's own profile. F.e.: Public for Wifi,
// Domain for VPN, and Private for LAN. All at the same time.
NET_FW_PROFILE2_CURRENT = 0
NET_FW_PROFILE2_DOMAIN = 1
NET_FW_PROFILE2_PRIVATE = 2
NET_FW_PROFILE2_PUBLIC = 4
NET_FW_PROFILE2_ALL = 2147483647
)
// Firewall Rule Groups
// Use this magical strings instead of group names. It will work on all language Windows versions.
// You can find more string locations here:
// https://windows10dll.nirsoft.net/firewallapi_dll.html
const (
NET_FW_FILE_AND_PRINTER_SHARING = "@FirewallAPI.dll,-28502"
NET_FW_REMOTE_DESKTOP = "@FirewallAPI.dll,-28752"
)
// FWRule represents Firewall Rule.
type FWRule struct {
Name, Description, ApplicationName, ServiceName string
LocalPorts, RemotePorts string
// LocalAddresses, RemoteAddresses are always returned with netmask, f.e.:
// `10.10.1.1/255.255.255.0`
LocalAddresses, RemoteAddresses string
// ICMPTypesAndCodes is string. You can find define multiple codes separated by ":" (colon).
// Types are listed here:
// https://www.iana.org/assignments/icmp-parameters/icmp-parameters.xhtml
// So to allow ping set it to:
// "0"
ICMPTypesAndCodes string
Grouping string
// InterfaceTypes can be:
// "LAN", "Wireless", "RemoteAccess", "All"
// You can add multiple deviding with comma:
// "LAN, Wireless"
InterfaceTypes string
Protocol, Direction, Action, Profiles int32
Enabled, EdgeTraversal bool
}
// FirewallRuleAddAdvanced allows to modify almost all available FW Rule parameters.
// You probably do not want to use this, as function allows to create any rule, even opening all ports
// in given profile. So use with caution.
func FirewallRuleAddAdvanced(rule FWRule) (bool, error) {
return firewallRuleAdd(rule.Name, rule.Description, rule.Grouping, rule.ApplicationName, rule.ServiceName,
rule.LocalPorts, rule.RemotePorts, rule.LocalAddresses, rule.RemoteAddresses, rule.ICMPTypesAndCodes,
rule.Protocol, rule.Direction, rule.Action, rule.Profiles, rule.Enabled, rule.EdgeTraversal)
}
// firewallRuleAdd is universal function to add all kinds of rules.
func firewallRuleAdd(name, description, group, appPath, serviceName, ports, remotePorts, localAddresses, remoteAddresses, icmpTypes string, protocol, direction, action, profile int32, enabled, edgeTraversal bool) (bool, error) {
if name == "" {
return false, fmt.Errorf("empty FW Rule name, name is mandatory")
}
runtime.LockOSThread()
defer runtime.UnlockOSThread()
u, fwPolicy, err := firewallAPIInit()
if err != nil {
return false, err
}
defer firewallAPIRelease(u, fwPolicy)
if profile == NET_FW_PROFILE2_CURRENT {
currentProfiles, err := oleutil.GetProperty(fwPolicy, "CurrentProfileTypes")
if err != nil {
return false, fmt.Errorf("Failed to get CurrentProfiles: %s", err)
}
profile = currentProfiles.Value().(int32)
}
unknownRules, err := oleutil.GetProperty(fwPolicy, "Rules")
if err != nil {
return false, fmt.Errorf("Failed to get Rules: %s", err)
}
rules := unknownRules.ToIDispatch()
if ok, err := FirewallRuleExistsByName(rules, name); err != nil {
return false, fmt.Errorf("Error while checking rules for duplicate: %s", err)
} else if ok {
return false, nil
}
unknown2, err := oleutil.CreateObject("HNetCfg.FWRule")
if err != nil {
return false, fmt.Errorf("Error creating Rule object: %s", err)
}
defer unknown2.Release()
fwRule, err := unknown2.QueryInterface(ole.IID_IDispatch)
if err != nil {
return false, fmt.Errorf("Error creating Rule object (2): %s", err)
}
defer fwRule.Release()
if _, err := oleutil.PutProperty(fwRule, "Name", name); err != nil {
return false, fmt.Errorf("Error setting property (Name) of Rule: %s", err)
}
if _, err := oleutil.PutProperty(fwRule, "Description", description); err != nil {
return false, fmt.Errorf("Error setting property (Description) of Rule: %s", err)
}
if appPath != "" {
if _, err := oleutil.PutProperty(fwRule, "Applicationname", appPath); err != nil {
return false, fmt.Errorf("Error setting property (Applicationname) of Rule: %s", err)
}
}
if serviceName != "" {
if _, err := oleutil.PutProperty(fwRule, "ServiceName", serviceName); err != nil {
return false, fmt.Errorf("Error setting property (ServiceName) of Rule: %s", err)
}
}
if protocol != 0 {
if _, err := oleutil.PutProperty(fwRule, "Protocol", protocol); err != nil {
return false, fmt.Errorf("Error setting property (Protocol) of Rule: %s", err)
}
}
if icmpTypes != "" {
if _, err := oleutil.PutProperty(fwRule, "IcmpTypesAndCodes", icmpTypes); err != nil {
return false, fmt.Errorf("Error setting property (IcmpTypesAndCodes) of Rule: %s", err)
}
}
if ports != "" {
if _, err := oleutil.PutProperty(fwRule, "LocalPorts", ports); err != nil {
return false, fmt.Errorf("Error setting property (LocalPorts) of Rule: %s", err)
}
}
if remotePorts != "" {
if _, err := oleutil.PutProperty(fwRule, "RemotePorts", remotePorts); err != nil {
return false, fmt.Errorf("Error setting property (RemotePorts) of Rule: %s", err)
}
}
if localAddresses != "" {
if _, err := oleutil.PutProperty(fwRule, "LocalAddresses", localAddresses); err != nil {
return false, fmt.Errorf("Error setting property (LocalAddresses) of Rule: %s", err)
}
}
if remoteAddresses != "" {
if _, err := oleutil.PutProperty(fwRule, "RemoteAddresses", remoteAddresses); err != nil {
return false, fmt.Errorf("Error setting property (RemoteAddresses) of Rule: %s", err)
}
}
if direction != 0 {
if _, err := oleutil.PutProperty(fwRule, "Direction", direction); err != nil {
return false, fmt.Errorf("Error setting property (Direction) of Rule: %s", err)
}
}
if _, err := oleutil.PutProperty(fwRule, "Enabled", enabled); err != nil {
return false, fmt.Errorf("Error setting property (Enabled) of Rule: %s", err)
}
if _, err := oleutil.PutProperty(fwRule, "Grouping", group); err != nil {
return false, fmt.Errorf("Error setting property (Grouping) of Rule: %s", err)
}
if _, err := oleutil.PutProperty(fwRule, "Profiles", profile); err != nil {
return false, fmt.Errorf("Error setting property (Profiles) of Rule: %s", err)
}
if _, err := oleutil.PutProperty(fwRule, "Action", action); err != nil {
return false, fmt.Errorf("Error setting property (Action) of Rule: %s", err)
}
if edgeTraversal {
if _, err := oleutil.PutProperty(fwRule, "EdgeTraversal", edgeTraversal); err != nil {
return false, fmt.Errorf("Error setting property (EdgeTraversal) of Rule: %s", err)
}
}
if _, err := oleutil.CallMethod(rules, "Add", fwRule); err != nil {
return false, fmt.Errorf("Error adding Rule: %s", err)
}
return true, nil
}
func FirewallRuleExistsByName(rules *ole.IDispatch, name string) (bool, error) {
enumProperty, err := rules.GetProperty("_NewEnum")
if err != nil {
return false, fmt.Errorf("Failed to get enumeration property on Rules: %s", err)
}
defer enumProperty.Clear()
enum, err := enumProperty.ToIUnknown().IEnumVARIANT(ole.IID_IEnumVariant)
if err != nil {
return false, fmt.Errorf("Failed to cast enum to correct type: %s", err)
}
if enum == nil {
return false, fmt.Errorf("can't get IEnumVARIANT, enum is nil")
}
for itemRaw, length, err := enum.Next(1); length > 0; itemRaw, length, err = enum.Next(1) {
if err != nil {
return false, fmt.Errorf("Failed to seek next Rule item: %s", err)
}
t, err := func() (bool, error) {
item := itemRaw.ToIDispatch()
defer item.Release()
if item, err := oleutil.GetProperty(item, "Name"); err != nil {
return false, fmt.Errorf("Failed to get Property (Name) of Rule")
} else if item.ToString() == name {
return true, nil
}
return false, nil
}()
if err != nil {
return false, err
} else if t {
return true, nil
}
}
return false, nil
}
// firewallAPIInit initialize common fw api.
// then:
// dispatch firewallAPIRelease(u, fwp)
func firewallAPIInit() (*ole.IUnknown, *ole.IDispatch, error) {
err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED)
if err != nil {
return nil, nil, fmt.Errorf("Failed to initialize COM: %s", err)
}
unknown, err := oleutil.CreateObject("HNetCfg.FwPolicy2")
if err != nil {
return nil, nil, fmt.Errorf("Failed to create FwPolicy Object: %s", err)
}
fwPolicy, err := unknown.QueryInterface(ole.IID_IDispatch)
if err != nil {
unknown.Release()
return nil, nil, fmt.Errorf("Failed to create FwPolicy Object (2): %s", err)
}
return unknown, fwPolicy, nil
}
// firewallAPIRelease cleans memory.
func firewallAPIRelease(u *ole.IUnknown, fwp *ole.IDispatch) {
fwp.Release()
u.Release()
ole.CoUninitialize()
}

View File

@@ -385,3 +385,25 @@ func (luid LUID) SetDNS(family AddressFamily, servers []netip.Addr, domains []st
func (luid LUID) FlushDNS(family AddressFamily) error {
return luid.SetDNS(family, nil, nil)
}
func (luid LUID) DisableDNSRegistration() error {
guid, err := luid.GUID()
if err != nil {
return err
}
dnsInterfaceSettings := &DnsInterfaceSettings{
Version: DnsInterfaceSettingsVersion1,
Flags: DnsInterfaceSettingsFlagRegistrationEnabled,
RegistrationEnabled: 0,
}
// For >= Windows 10 1809
err = SetInterfaceDnsSettings(*guid, dnsInterfaceSettings)
if err == nil || !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) {
return err
}
// For < Windows 10 1809
return luid.fallbackDisableDNSRegistration()
}

View File

@@ -51,10 +51,11 @@ func runNetsh(cmds []string) error {
}
const (
netshCmdTemplateFlush4 = "interface ipv4 set dnsservers name=%d source=static address=none validate=no register=both"
netshCmdTemplateFlush6 = "interface ipv6 set dnsservers name=%d source=static address=none validate=no register=both"
netshCmdTemplateAdd4 = "interface ipv4 add dnsservers name=%d address=%s validate=no"
netshCmdTemplateAdd6 = "interface ipv6 add dnsservers name=%d address=%s validate=no"
netshCmdTemplateFlush4 = "interface ipv4 set dnsservers name=%d source=static address=none validate=no"
netshCmdTemplateFlush6 = "interface ipv6 set dnsservers name=%d source=static address=none validate=no"
netshCmdTemplateAdd4 = "interface ipv4 add dnsservers name=%d address=%s validate=no"
netshCmdTemplateAdd6 = "interface ipv6 add dnsservers name=%d address=%s validate=no"
netshCmdTemplateDisableRegistration = "interface ipv6 set dnsservers name=%d register=none"
)
func (luid LUID) fallbackSetDNSForFamily(family AddressFamily, dnses []netip.Addr) error {
@@ -106,3 +107,13 @@ func (luid LUID) fallbackSetDNSDomain(domain string) error {
key.Close()
return err
}
func (luid LUID) fallbackDisableDNSRegistration() error {
// the DNS registration setting is shared for both IPv4 and IPv6
ipif, err := luid.IPInterface(windows.AF_INET)
if err != nil {
return err
}
cmd := fmt.Sprintf(netshCmdTemplateDisableRegistration, ipif.InterfaceIndex)
return runNetsh([]string{cmd})
}

149
lwip.go
View File

@@ -1,149 +0,0 @@
//go:build with_lwip
package tun
import (
"context"
"net"
"net/netip"
"os"
lwip "github.com/sagernet/go-tun2socks/core"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/udpnat"
)
type LWIP struct {
ctx context.Context
tun Tun
tunMtu uint32
udpTimeout int64
handler Handler
stack lwip.LWIPStack
udpNat *udpnat.Service[netip.AddrPort]
}
func NewLWIP(
options StackOptions,
) (Stack, error) {
return &LWIP{
ctx: options.Context,
tun: options.Tun,
tunMtu: options.MTU,
handler: options.Handler,
stack: lwip.NewLWIPStack(),
udpNat: udpnat.New[netip.AddrPort](options.UDPTimeout, options.Handler),
}, nil
}
func (l *LWIP) Start() error {
lwip.RegisterTCPConnHandler(l)
lwip.RegisterUDPConnHandler(l)
lwip.RegisterOutputFn(l.tun.Write)
go l.loopIn()
return nil
}
func (l *LWIP) loopIn() {
if winTun, isWintun := l.tun.(WinTun); isWintun {
l.loopInWintun(winTun)
return
}
mtu := int(l.tunMtu) + PacketOffset
_buffer := buf.StackNewSize(mtu)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
data := buffer.FreeBytes()
for {
n, err := l.tun.Read(data)
if err != nil {
return
}
_, err = l.stack.Write(data[PacketOffset:n])
if err != nil {
if err.Error() == "stack closed" {
return
}
l.handler.NewError(context.Background(), err)
}
}
}
func (l *LWIP) loopInWintun(tun WinTun) {
for {
packet, release, err := tun.ReadPacket()
if err != nil {
return
}
_, err = l.stack.Write(packet)
release()
if err != nil {
if err.Error() == "stack closed" {
return
}
l.handler.NewError(context.Background(), err)
}
}
}
func (l *LWIP) Close() error {
lwip.RegisterTCPConnHandler(nil)
lwip.RegisterUDPConnHandler(nil)
lwip.RegisterOutputFn(func(bytes []byte) (int, error) {
return 0, os.ErrClosed
})
return l.stack.Close()
}
func (l *LWIP) Handle(conn net.Conn) error {
lAddr := conn.LocalAddr()
rAddr := conn.RemoteAddr()
if lAddr == nil || rAddr == nil {
conn.Close()
return nil
}
go func() {
var metadata M.Metadata
metadata.Source = M.SocksaddrFromNet(lAddr)
metadata.Destination = M.SocksaddrFromNet(rAddr)
hErr := l.handler.NewConnection(l.ctx, conn, metadata)
if hErr != nil {
conn.(lwip.TCPConn).Abort()
}
}()
return nil
}
func (l *LWIP) ReceiveTo(conn lwip.UDPConn, data []byte, addr M.Socksaddr) error {
var upstreamMetadata M.Metadata
upstreamMetadata.Source = conn.LocalAddr()
upstreamMetadata.Destination = addr
l.udpNat.NewPacket(
l.ctx,
upstreamMetadata.Source.AddrPort(),
buf.As(data).ToOwned(),
upstreamMetadata,
func(natConn N.PacketConn) N.PacketWriter {
return &LWIPUDPBackWriter{conn}
},
)
return nil
}
type LWIPUDPBackWriter struct {
conn lwip.UDPConn
}
func (w *LWIPUDPBackWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
defer buffer.Release()
return common.Error(w.conn.WriteFrom(buffer.Bytes(), destination))
}
func (w *LWIPUDPBackWriter) Close() error {
return w.conn.Close()
}

View File

@@ -1,11 +0,0 @@
//go:build !with_lwip
package tun
import E "github.com/sagernet/sing/common/exceptions"
func NewLWIP(
options StackOptions,
) (Stack, error) {
return nil, E.New(`LWIP is not included in this build, rebuild with -tags with_lwip`)
}

View File

@@ -10,13 +10,14 @@ import (
var ErrNoRoute = E.New("no route to internet")
type (
NetworkUpdateCallback = func() error
DefaultInterfaceUpdateCallback = func(event int) error
NetworkUpdateCallback = func()
DefaultInterfaceUpdateCallback = func(event int)
)
const (
EventInterfaceUpdate = 1
EventAndroidVPNUpdate = 2
EventNoRoute = 4
)
type NetworkUpdateMonitor interface {
@@ -24,7 +25,6 @@ type NetworkUpdateMonitor interface {
Close() error
RegisterCallback(callback NetworkUpdateCallback) *list.Element[NetworkUpdateCallback]
UnregisterCallback(element *list.Element[NetworkUpdateCallback])
E.Handler
}
type DefaultInterfaceMonitor interface {
@@ -32,6 +32,7 @@ type DefaultInterfaceMonitor interface {
Close() error
DefaultInterfaceName(destination netip.Addr) string
DefaultInterfaceIndex(destination netip.Addr) int
DefaultInterface(destination netip.Addr) (string, int)
OverrideAndroidVPN() bool
AndroidVPNEnabled() bool
RegisterCallback(callback DefaultInterfaceUpdateCallback) *list.Element[DefaultInterfaceUpdateCallback]
@@ -39,5 +40,6 @@ type DefaultInterfaceMonitor interface {
}
type DefaultInterfaceMonitorOptions struct {
OverrideAndroidVPN bool
OverrideAndroidVPN bool
UnderNetworkExtension bool
}

View File

@@ -1,16 +1,15 @@
package tun
import (
"context"
"net"
"net/netip"
"os"
"sync"
"syscall"
"time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/common/x/list"
"golang.org/x/net/route"
@@ -18,114 +17,152 @@ import (
)
type networkUpdateMonitor struct {
errorHandler E.Handler
access sync.Mutex
callbacks list.List[NetworkUpdateCallback]
routeSocket *os.File
access sync.Mutex
callbacks list.List[NetworkUpdateCallback]
routeSocketFile *os.File
closeOnce sync.Once
done chan struct{}
logger logger.Logger
}
func NewNetworkUpdateMonitor(errorHandler E.Handler) (NetworkUpdateMonitor, error) {
func NewNetworkUpdateMonitor(logger logger.Logger) (NetworkUpdateMonitor, error) {
return &networkUpdateMonitor{
errorHandler: errorHandler,
logger: logger,
done: make(chan struct{}),
}, nil
}
func (m *networkUpdateMonitor) Start() error {
go m.loopUpdate()
return nil
}
func (m *networkUpdateMonitor) loopUpdate() {
for {
select {
case <-m.done:
return
default:
}
err := m.loopUpdate0()
if err != nil {
m.logger.Error("listen network update: ", err)
return
}
}
}
func (m *networkUpdateMonitor) loopUpdate0() error {
routeSocket, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, 0)
if err != nil {
return err
}
err = unix.SetNonblock(routeSocket, true)
if err != nil {
unix.Close(routeSocket)
return err
}
m.routeSocket = os.NewFile(uintptr(routeSocket), "route")
go m.loopUpdate()
routeSocketFile := os.NewFile(uintptr(routeSocket), "route")
defer routeSocketFile.Close()
m.routeSocketFile = routeSocketFile
m.loopUpdate1(routeSocketFile)
return nil
}
func (m *networkUpdateMonitor) loopUpdate() {
rawConn, err := m.routeSocket.SyscallConn()
func (m *networkUpdateMonitor) loopUpdate1(routeSocketFile *os.File) {
buffer := buf.NewPacket()
defer buffer.Release()
done := make(chan struct{})
go func() {
select {
case <-m.done:
routeSocketFile.Close()
case <-done:
}
}()
n, err := routeSocketFile.Read(buffer.FreeBytes())
close(done)
if err != nil {
m.errorHandler.NewError(context.Background(), E.Cause(err, "create raw route connection"))
return
}
for {
var innerErr error
err = rawConn.Read(func(fd uintptr) (done bool) {
var msg [2048]byte
_, innerErr = unix.Read(int(fd), msg[:])
return innerErr != unix.EWOULDBLOCK
})
if innerErr != nil {
err = innerErr
}
if err != nil {
break
}
m.emit()
buffer.Truncate(n)
messages, err := route.ParseRIB(route.RIBTypeRoute, buffer.Bytes())
if err != nil {
return
}
if err != syscall.EAGAIN {
m.errorHandler.NewError(context.Background(), E.Cause(err, "read route message"))
for _, message := range messages {
if _, isRouteMessage := message.(*route.RouteMessage); isRouteMessage {
m.emit()
return
}
}
}
func (m *networkUpdateMonitor) Close() error {
return common.Close(common.PtrOrNil(m.routeSocket))
m.closeOnce.Do(func() {
close(m.done)
})
return nil
}
func (m *defaultInterfaceMonitor) checkUpdate() error {
ribMessage, err := route.FetchRIB(unix.AF_UNSPEC, route.RIBTypeRoute, 0)
if err != nil {
return err
}
routeMessages, err := route.ParseRIB(route.RIBTypeRoute, ribMessage)
if err != nil {
return err
}
var defaultInterface *net.Interface
for _, rawRouteMessage := range routeMessages {
routeMessage := rawRouteMessage.(*route.RouteMessage)
if len(routeMessage.Addrs) <= unix.RTAX_NETMASK {
continue
}
destination, isIPv4Destination := routeMessage.Addrs[unix.RTAX_DST].(*route.Inet4Addr)
if !isIPv4Destination {
continue
}
if destination.IP != netip.IPv4Unspecified().As4() {
continue
}
mask, isIPv4Mask := routeMessage.Addrs[unix.RTAX_NETMASK].(*route.Inet4Addr)
if !isIPv4Mask {
continue
}
ones, _ := net.IPMask(mask.IP[:]).Size()
if ones != 0 {
continue
}
routeInterface, err := net.InterfaceByIndex(routeMessage.Index)
if err != nil {
return err
}
if routeMessage.Flags&unix.RTF_UP == 0 {
continue
}
if routeMessage.Flags&unix.RTF_GATEWAY == 0 {
continue
}
if routeMessage.Flags&unix.RTF_IFSCOPE != 0 {
continue
}
defaultInterface = routeInterface
break
}
if defaultInterface == nil {
var (
defaultInterface *net.Interface
err error
)
if m.options.UnderNetworkExtension {
defaultInterface, err = getDefaultInterfaceBySocket()
if err != nil {
return err
}
} else {
ribMessage, err := route.FetchRIB(unix.AF_UNSPEC, route.RIBTypeRoute, 0)
if err != nil {
return err
}
routeMessages, err := route.ParseRIB(route.RIBTypeRoute, ribMessage)
if err != nil {
return err
}
for _, rawRouteMessage := range routeMessages {
routeMessage := rawRouteMessage.(*route.RouteMessage)
if len(routeMessage.Addrs) <= unix.RTAX_NETMASK {
continue
}
destination, isIPv4Destination := routeMessage.Addrs[unix.RTAX_DST].(*route.Inet4Addr)
if !isIPv4Destination {
continue
}
if destination.IP != netip.IPv4Unspecified().As4() {
continue
}
mask, isIPv4Mask := routeMessage.Addrs[unix.RTAX_NETMASK].(*route.Inet4Addr)
if !isIPv4Mask {
continue
}
ones, _ := net.IPMask(mask.IP[:]).Size()
if ones != 0 {
continue
}
routeInterface, err := net.InterfaceByIndex(routeMessage.Index)
if err != nil {
return err
}
if routeMessage.Flags&unix.RTF_UP == 0 {
continue
}
if routeMessage.Flags&unix.RTF_GATEWAY == 0 {
continue
}
if routeMessage.Flags&unix.RTF_IFSCOPE != 0 {
// continue
}
defaultInterface = routeInterface
break
}
}
if defaultInterface == nil {
return ErrNoRoute
}
oldInterface := m.defaultInterfaceName
oldIndex := m.defaultInterfaceIndex
@@ -149,6 +186,8 @@ func getDefaultInterfaceBySocket() (*net.Interface, error) {
Port: 80,
})
result := make(chan netip.Addr, 1)
done := make(chan struct{})
defer close(done)
go func() {
for {
sockname, sockErr := unix.Getsockname(socketFd)
@@ -161,8 +200,13 @@ func getDefaultInterfaceBySocket() (*net.Interface, error) {
}
addr := netip.AddrFrom4(sockaddr.Addr)
if addr.IsUnspecified() {
time.Sleep(time.Millisecond)
continue
select {
case <-done:
break
default:
time.Sleep(10 * time.Millisecond)
continue
}
}
result <- addr
break
@@ -172,7 +216,7 @@ func getDefaultInterfaceBySocket() (*net.Interface, error) {
select {
case selectedAddr = <-result:
case <-time.After(time.Second):
return nil, os.ErrDeadlineExceeded
return nil, nil
}
interfaces, err := net.Interfaces()
if err != nil {

View File

@@ -2,30 +2,56 @@ package tun
import (
"os"
"runtime"
"sync"
"time"
"github.com/sagernet/netlink"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/common/x/list"
"golang.org/x/sys/unix"
)
type networkUpdateMonitor struct {
routeUpdate chan netlink.RouteUpdate
linkUpdate chan netlink.LinkUpdate
close chan struct{}
errorHandler E.Handler
routeUpdate chan netlink.RouteUpdate
linkUpdate chan netlink.LinkUpdate
close chan struct{}
access sync.Mutex
callbacks list.List[NetworkUpdateCallback]
logger logger.Logger
}
func NewNetworkUpdateMonitor(errorHandler E.Handler) (NetworkUpdateMonitor, error) {
return &networkUpdateMonitor{
routeUpdate: make(chan netlink.RouteUpdate, 2),
linkUpdate: make(chan netlink.LinkUpdate, 2),
close: make(chan struct{}),
errorHandler: errorHandler,
}, nil
var ErrNetlinkBanned = E.New(
"netlink socket in Android is banned by Google, " +
"use the root or system (ADB) user to run sing-box, " +
"or switch to the sing-box Adnroid graphical interface client",
)
func NewNetworkUpdateMonitor(logger logger.Logger) (NetworkUpdateMonitor, error) {
monitor := &networkUpdateMonitor{
routeUpdate: make(chan netlink.RouteUpdate, 2),
linkUpdate: make(chan netlink.LinkUpdate, 2),
close: make(chan struct{}),
logger: logger,
}
// check is netlink banned by google
if runtime.GOOS == "android" {
netlinkSocket, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_DGRAM, unix.NETLINK_ROUTE)
if err != nil {
return nil, ErrNetlinkBanned
}
err = unix.Bind(netlinkSocket, &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
})
unix.Close(netlinkSocket)
if err != nil {
return nil, ErrNetlinkBanned
}
}
return monitor, nil
}
func (m *networkUpdateMonitor) Start() error {
@@ -42,6 +68,9 @@ func (m *networkUpdateMonitor) Start() error {
}
func (m *networkUpdateMonitor) loopUpdate() {
const minDuration = time.Second
timer := time.NewTimer(minDuration)
defer timer.Stop()
for {
select {
case <-m.close:
@@ -50,6 +79,12 @@ func (m *networkUpdateMonitor) loopUpdate() {
case <-m.linkUpdate:
}
m.emit()
select {
case <-m.close:
return
case <-timer.C:
timer.Reset(minDuration)
}
}
}

View File

@@ -4,7 +4,6 @@ package tun
import (
"github.com/sagernet/netlink"
E "github.com/sagernet/sing/common/exceptions"
"golang.org/x/sys/unix"
)
@@ -37,5 +36,5 @@ func (m *defaultInterfaceMonitor) checkUpdate() error {
m.emit(EventInterfaceUpdate)
return nil
}
return E.New("no route to internet")
return ErrNoRoute
}

View File

@@ -5,13 +5,13 @@ package tun
import (
"os"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
)
func NewNetworkUpdateMonitor(errorHandler E.Handler) (NetworkUpdateMonitor, error) {
func NewNetworkUpdateMonitor(logger logger.Logger) (NetworkUpdateMonitor, error) {
return nil, os.ErrInvalid
}
func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor, options DefaultInterfaceMonitorOptions) (DefaultInterfaceMonitor, error) {
func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor, logger logger.Logger, options DefaultInterfaceMonitorOptions) (DefaultInterfaceMonitor, error) {
return nil, os.ErrInvalid
}

View File

@@ -3,14 +3,14 @@
package tun
import (
"context"
"errors"
"net"
"net/netip"
"sync"
"time"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/x/list"
)
@@ -32,27 +32,23 @@ func (m *networkUpdateMonitor) emit() {
callbacks := m.callbacks.Array()
m.access.Unlock()
for _, callback := range callbacks {
err := callback()
if err != nil {
m.NewError(context.Background(), err)
}
callback()
}
}
func (m *networkUpdateMonitor) NewError(ctx context.Context, err error) {
m.errorHandler.NewError(ctx, err)
}
type defaultInterfaceMonitor struct {
options DefaultInterfaceMonitorOptions
networkAddresses []networkAddress
defaultInterfaceName string
defaultInterfaceIndex int
androidVPNEnabled bool
noRoute bool
networkMonitor NetworkUpdateMonitor
checkUpdateTimer *time.Timer
element *list.Element[NetworkUpdateCallback]
access sync.Mutex
callbacks list.List[DefaultInterfaceUpdateCallback]
logger logger.Logger
}
type networkAddress struct {
@@ -61,30 +57,47 @@ type networkAddress struct {
addresses []netip.Prefix
}
func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor, options DefaultInterfaceMonitorOptions) (DefaultInterfaceMonitor, error) {
func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor, logger logger.Logger, options DefaultInterfaceMonitorOptions) (DefaultInterfaceMonitor, error) {
return &defaultInterfaceMonitor{
options: options,
networkMonitor: networkMonitor,
defaultInterfaceIndex: -1,
logger: logger,
}, nil
}
func (m *defaultInterfaceMonitor) Start() error {
err := m.checkUpdate()
if err != nil {
m.networkMonitor.NewError(context.Background(), err)
}
_ = m.checkUpdate()
m.element = m.networkMonitor.RegisterCallback(m.delayCheckUpdate)
return nil
}
func (m *defaultInterfaceMonitor) delayCheckUpdate() error {
time.Sleep(time.Second)
func (m *defaultInterfaceMonitor) delayCheckUpdate() {
if m.checkUpdateTimer == nil {
m.checkUpdateTimer = time.AfterFunc(time.Second, m.postCheckUpdate)
} else {
m.checkUpdateTimer.Reset(time.Second)
}
}
func (m *defaultInterfaceMonitor) postCheckUpdate() {
err := m.updateInterfaces()
if err != nil {
m.networkMonitor.NewError(context.Background(), E.Cause(err, "update interfaces"))
m.logger.Error("update interfaces: ", err)
}
err = m.checkUpdate()
if errors.Is(err, ErrNoRoute) {
if !m.noRoute {
m.noRoute = true
m.defaultInterfaceName = ""
m.defaultInterfaceIndex = -1
m.emit(EventNoRoute)
}
} else if err != nil {
m.logger.Error("check interface: ", err)
} else {
m.noRoute = false
}
return m.checkUpdate()
}
func (m *defaultInterfaceMonitor) updateInterfaces() error {
@@ -130,9 +143,6 @@ func (m *defaultInterfaceMonitor) DefaultInterfaceName(destination netip.Addr) s
}
}
}
if m.defaultInterfaceIndex == -1 {
m.checkUpdate()
}
return m.defaultInterfaceName
}
@@ -144,12 +154,20 @@ func (m *defaultInterfaceMonitor) DefaultInterfaceIndex(destination netip.Addr)
}
}
}
if m.defaultInterfaceIndex == -1 {
m.checkUpdate()
}
return m.defaultInterfaceIndex
}
func (m *defaultInterfaceMonitor) DefaultInterface(destination netip.Addr) (string, int) {
for _, address := range m.networkAddresses {
for _, prefix := range address.addresses {
if prefix.Contains(destination) {
return address.interfaceName, address.interfaceIndex
}
}
}
return m.defaultInterfaceName, m.defaultInterfaceIndex
}
func (m *defaultInterfaceMonitor) OverrideAndroidVPN() bool {
return m.options.OverrideAndroidVPN
}
@@ -175,9 +193,6 @@ func (m *defaultInterfaceMonitor) emit(event int) {
callbacks := m.callbacks.Array()
m.access.Unlock()
for _, callback := range callbacks {
err := callback(event)
if err != nil {
m.networkMonitor.NewError(context.Background(), err)
}
callback(event)
}
}

View File

@@ -5,6 +5,7 @@ import (
"github.com/sagernet/sing-tun/internal/winipcfg"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/common/x/list"
"golang.org/x/sys/windows"
@@ -17,11 +18,12 @@ type networkUpdateMonitor struct {
access sync.Mutex
callbacks list.List[NetworkUpdateCallback]
logger logger.Logger
}
func NewNetworkUpdateMonitor(errorHandler E.Handler) (NetworkUpdateMonitor, error) {
func NewNetworkUpdateMonitor(logger logger.Logger) (NetworkUpdateMonitor, error) {
return &networkUpdateMonitor{
errorHandler: errorHandler,
logger: logger,
}, nil
}

View File

@@ -1,92 +0,0 @@
package tun
import (
"net/netip"
E "github.com/sagernet/sing/common/exceptions"
)
type ActionType = uint8
const (
ActionTypeUnknown ActionType = iota
ActionTypeReturn
ActionTypeBlock
ActionTypeDirect
)
func ParseActionType(action string) (ActionType, error) {
switch action {
case "return":
return ActionTypeReturn, nil
case "block":
return ActionTypeBlock, nil
case "direct":
return ActionTypeDirect, nil
default:
return 0, E.New("unknown action: ", action)
}
}
func ActionTypeName(actionType ActionType) (string, error) {
switch actionType {
case ActionTypeUnknown:
return "", nil
case ActionTypeReturn:
return "return", nil
case ActionTypeBlock:
return "block", nil
case ActionTypeDirect:
return "direct", nil
default:
return "", E.New("unknown action: ", actionType)
}
}
type RouteSession struct {
IPVersion uint8
Network uint8
Source netip.AddrPort
Destination netip.AddrPort
}
type RouteContext interface {
WritePacket(packet []byte) error
}
type Router interface {
RouteConnection(session RouteSession, context RouteContext) RouteAction
}
type RouteAction interface {
ActionType() ActionType
Timeout() bool
}
type ActionReturn struct{}
func (r *ActionReturn) ActionType() ActionType {
return ActionTypeReturn
}
func (r *ActionReturn) Timeout() bool {
return false
}
type ActionBlock struct{}
func (r *ActionBlock) ActionType() ActionType {
return ActionTypeBlock
}
func (r *ActionBlock) Timeout() bool {
return false
}
type ActionDirect struct {
DirectDestination
}
func (r *ActionDirect) ActionType() ActionType {
return ActionTypeDirect
}

View File

@@ -1,16 +0,0 @@
//go:build with_gvisor
package tun
import (
"github.com/sagernet/sing/common/buf"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
type DirectDestination interface {
WritePacket(buffer *buf.Buffer) error
WritePacketBuffer(buffer *stack.PacketBuffer) error
Close() error
Timeout() bool
}

View File

@@ -1,32 +0,0 @@
package tun
import (
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/cache"
)
type RouteMapping struct {
status *cache.LruCache[RouteSession, RouteAction]
}
func NewRouteMapping(maxAge int64) *RouteMapping {
return &RouteMapping{
status: cache.New(
cache.WithAge[RouteSession, RouteAction](maxAge),
cache.WithUpdateAgeOnGet[RouteSession, RouteAction](),
cache.WithEvict[RouteSession, RouteAction](func(key RouteSession, conn RouteAction) {
common.Close(conn)
}),
),
}
}
func (m *RouteMapping) Lookup(session RouteSession, constructor func() RouteAction) RouteAction {
action, _ := m.status.LoadOrStore(session, constructor)
if action.Timeout() {
common.Close(action)
action = constructor()
m.status.Store(session, action)
}
return action
}

View File

@@ -1,119 +0,0 @@
package tun
import (
"net/netip"
"sync"
"github.com/sagernet/sing-tun/internal/clashtcpip"
)
type NatMapping struct {
access sync.RWMutex
sessions map[RouteSession]RouteContext
ipRewrite bool
}
func NewNatMapping(ipRewrite bool) *NatMapping {
return &NatMapping{
sessions: make(map[RouteSession]RouteContext),
ipRewrite: ipRewrite,
}
}
func (m *NatMapping) CreateSession(session RouteSession, context RouteContext) {
if m.ipRewrite {
session.Source = netip.AddrPort{}
}
m.access.Lock()
m.sessions[session] = context
m.access.Unlock()
}
func (m *NatMapping) DeleteSession(session RouteSession) {
if m.ipRewrite {
session.Source = netip.AddrPort{}
}
m.access.Lock()
delete(m.sessions, session)
m.access.Unlock()
}
func (m *NatMapping) WritePacket(packet []byte) (bool, error) {
var routeSession RouteSession
var ipHdr clashtcpip.IP
switch ipVersion := packet[0] >> 4; ipVersion {
case 4:
routeSession.IPVersion = 4
ipHdr = clashtcpip.IPv4Packet(packet)
case 6:
routeSession.IPVersion = 6
ipHdr = clashtcpip.IPv6Packet(packet)
default:
return false, nil
}
routeSession.Network = ipHdr.Protocol()
switch routeSession.Network {
case clashtcpip.TCP:
tcpHdr := clashtcpip.TCPPacket(ipHdr.Payload())
routeSession.Destination = netip.AddrPortFrom(ipHdr.SourceIP(), tcpHdr.SourcePort())
if !m.ipRewrite {
routeSession.Source = netip.AddrPortFrom(ipHdr.DestinationIP(), tcpHdr.DestinationPort())
}
case clashtcpip.UDP:
udpHdr := clashtcpip.UDPPacket(ipHdr.Payload())
routeSession.Destination = netip.AddrPortFrom(ipHdr.SourceIP(), udpHdr.SourcePort())
if !m.ipRewrite {
routeSession.Source = netip.AddrPortFrom(ipHdr.DestinationIP(), udpHdr.DestinationPort())
}
default:
routeSession.Destination = netip.AddrPortFrom(ipHdr.SourceIP(), 0)
if !m.ipRewrite {
routeSession.Source = netip.AddrPortFrom(ipHdr.DestinationIP(), 0)
}
}
m.access.RLock()
context, loaded := m.sessions[routeSession]
m.access.RUnlock()
if !loaded {
return false, nil
}
return true, context.WritePacket(packet)
}
type NatWriter struct {
inet4Address netip.Addr
inet6Address netip.Addr
}
func NewNatWriter(inet4Address netip.Addr, inet6Address netip.Addr) *NatWriter {
return &NatWriter{
inet4Address: inet4Address,
inet6Address: inet6Address,
}
}
func (w *NatWriter) RewritePacket(packet []byte) {
var ipHdr clashtcpip.IP
var bindAddr netip.Addr
switch ipVersion := packet[0] >> 4; ipVersion {
case 4:
ipHdr = clashtcpip.IPv4Packet(packet)
bindAddr = w.inet4Address
case 6:
ipHdr = clashtcpip.IPv6Packet(packet)
bindAddr = w.inet6Address
default:
return
}
ipHdr.SetSourceIP(bindAddr)
switch ipHdr.Protocol() {
case clashtcpip.TCP:
tcpHdr := clashtcpip.TCPPacket(ipHdr.Payload())
tcpHdr.ResetChecksum(ipHdr.PseudoSum())
case clashtcpip.UDP:
udpHdr := clashtcpip.UDPPacket(ipHdr.Payload())
udpHdr.ResetChecksum(ipHdr.PseudoSum())
default:
}
ipHdr.ResetChecksum()
}

View File

@@ -1,41 +0,0 @@
//go:build with_gvisor
package tun
import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
func (w *NatWriter) RewritePacketBuffer(packetBuffer *stack.PacketBuffer) {
var bindAddr tcpip.Address
if packetBuffer.NetworkProtocolNumber == header.IPv4ProtocolNumber {
bindAddr = tcpip.Address(w.inet4Address.AsSlice())
} else {
bindAddr = tcpip.Address(w.inet6Address.AsSlice())
}
var ipHdr header.Network
switch packetBuffer.NetworkProtocolNumber {
case header.IPv4ProtocolNumber:
ipHdr = header.IPv4(packetBuffer.NetworkHeader().Slice())
case header.IPv6ProtocolNumber:
ipHdr = header.IPv6(packetBuffer.NetworkHeader().Slice())
default:
return
}
oldAddr := ipHdr.SourceAddress()
if checksumHdr, needChecksum := ipHdr.(header.ChecksummableNetwork); needChecksum {
checksumHdr.SetSourceAddressWithChecksumUpdate(bindAddr)
} else {
ipHdr.SetSourceAddress(bindAddr)
}
switch packetBuffer.TransportProtocolNumber {
case header.TCPProtocolNumber:
tcpHdr := header.TCP(packetBuffer.TransportHeader().Slice())
tcpHdr.UpdateChecksumPseudoHeaderAddress(oldAddr, bindAddr, true)
case header.UDPProtocolNumber:
udpHdr := header.UDP(packetBuffer.TransportHeader().Slice())
udpHdr.UpdateChecksumPseudoHeaderAddress(oldAddr, bindAddr, true)
}
}

View File

@@ -1,11 +0,0 @@
//go:build !with_gvisor
package tun
import "github.com/sagernet/sing/common/buf"
type DirectDestination interface {
WritePacket(buffer *buf.Buffer) error
Close() error
Timeout() bool
}

View File

@@ -2,6 +2,8 @@ package tun
import (
"context"
"encoding/binary"
"net"
"net/netip"
"github.com/sagernet/sing/common/control"
@@ -17,16 +19,13 @@ type Stack interface {
type StackOptions struct {
Context context.Context
Tun Tun
Name string
MTU uint32
Inet4Address []netip.Prefix
Inet6Address []netip.Prefix
TunOptions Options
EndpointIndependentNat bool
UDPTimeout int64
Router Router
Handler Handler
Logger logger.Logger
ForwarderBindInterface bool
IncludeAllNetworks bool
InterfaceFinder control.InterfaceFinder
}
@@ -36,14 +35,36 @@ func NewStack(
) (Stack, error) {
switch stack {
case "":
return NewSystem(options)
if options.IncludeAllNetworks {
return NewGVisor(options)
} else if WithGVisor && !options.TunOptions.GSO {
return NewMixed(options)
} else {
return NewSystem(options)
}
case "gvisor":
return NewGVisor(options)
case "mixed":
if options.IncludeAllNetworks {
return nil, ErrIncludeAllNetworks
}
return NewMixed(options)
case "system":
if options.IncludeAllNetworks {
return nil, ErrIncludeAllNetworks
}
return NewSystem(options)
case "lwip":
return NewLWIP(options)
default:
return nil, E.New("unknown stack: ", stack)
}
}
func BroadcastAddr(inet4Address []netip.Prefix) netip.Addr {
if len(inet4Address) == 0 {
return netip.Addr{}
}
prefix := inet4Address[0]
var broadcastAddr [4]byte
binary.BigEndian.PutUint32(broadcastAddr[:], binary.BigEndian.Uint32(prefix.Masked().Addr().AsSlice())|^binary.BigEndian.Uint32(net.CIDRMask(prefix.Bits(), 32)))
return netip.AddrFrom4(broadcastAddr)
}

View File

@@ -4,26 +4,24 @@ package tun
import (
"context"
"net"
"syscall"
"net/netip"
"time"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
"github.com/sagernet/gvisor/pkg/waiter"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/canceler"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
const WithGVisor = true
@@ -33,15 +31,13 @@ const defaultNIC tcpip.NICID = 1
type GVisor struct {
ctx context.Context
tun GVisorTun
tunMtu uint32
endpointIndependentNat bool
udpTimeout int64
router Router
broadcastAddr netip.Addr
handler Handler
logger logger.Logger
stack *stack.Stack
endpoint stack.LinkEndpoint
routeMapping *RouteMapping
}
type GVisorTun interface {
@@ -60,16 +56,12 @@ func NewGVisor(
gStack := &GVisor{
ctx: options.Context,
tun: gTun,
tunMtu: options.MTU,
endpointIndependentNat: options.EndpointIndependentNat,
udpTimeout: options.UDPTimeout,
router: options.Router,
broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address),
handler: options.Handler,
logger: options.Logger,
}
if gStack.router != nil {
gStack.routeMapping = NewRouteMapping(options.UDPTimeout)
}
return gStack, nil
}
@@ -78,44 +70,11 @@ func (t *GVisor) Start() error {
if err != nil {
return err
}
ipStack := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,
ipv6.NewProtocol,
},
TransportProtocols: []stack.TransportProtocolFactory{
tcp.NewProtocol,
udp.NewProtocol,
icmp.NewProtocol4,
icmp.NewProtocol6,
},
})
tErr := ipStack.CreateNIC(defaultNIC, linkEndpoint)
if tErr != nil {
return E.New("create nic: ", wrapStackError(tErr))
linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, t.tun}
ipStack, err := newGVisorStack(linkEndpoint)
if err != nil {
return err
}
ipStack.SetRouteTable([]tcpip.Route{
{Destination: header.IPv4EmptySubnet, NIC: defaultNIC},
{Destination: header.IPv6EmptySubnet, NIC: defaultNIC},
})
ipStack.SetSpoofing(defaultNIC, true)
ipStack.SetPromiscuousMode(defaultNIC, true)
bufSize := 20 * 1024
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPReceiveBufferSizeRangeOption{
Min: 1,
Default: bufSize,
Max: bufSize,
})
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPSendBufferSizeRangeOption{
Min: 1,
Default: bufSize,
Max: bufSize,
})
sOpt := tcpip.TCPSACKEnabled(true)
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
mOpt := tcpip.TCPModerateReceiveBufferOption(true)
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &mOpt)
tcpForwarder := tcp.NewForwarder(ipStack, 0, 1024, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue
handshakeCtx, cancel := context.WithCancel(context.Background())
@@ -155,44 +114,7 @@ func (t *GVisor) Start() error {
}
}()
})
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, func(id stack.TransportEndpointID, buffer *stack.PacketBuffer) bool {
if t.router != nil {
var routeSession RouteSession
routeSession.Network = syscall.IPPROTO_TCP
var ipHdr header.Network
if buffer.NetworkProtocolNumber == header.IPv4ProtocolNumber {
routeSession.IPVersion = 4
ipHdr = header.IPv4(buffer.NetworkHeader().Slice())
} else {
routeSession.IPVersion = 6
ipHdr = header.IPv6(buffer.NetworkHeader().Slice())
}
tcpHdr := header.TCP(buffer.TransportHeader().Slice())
routeSession.Source = M.AddrPortFrom(net.IP(ipHdr.SourceAddress()), tcpHdr.SourcePort())
routeSession.Destination = M.AddrPortFrom(net.IP(ipHdr.DestinationAddress()), tcpHdr.DestinationPort())
action := t.routeMapping.Lookup(routeSession, func() RouteAction {
if routeSession.IPVersion == 4 {
return t.router.RouteConnection(routeSession, &systemTCPDirectPacketWriter4{t.tun, routeSession.Source})
} else {
return t.router.RouteConnection(routeSession, &systemTCPDirectPacketWriter6{t.tun, routeSession.Source})
}
})
switch actionType := action.(type) {
case *ActionBlock:
// TODO: send icmp unreachable
return true
case *ActionDirect:
buffer.IncRef()
err = actionType.WritePacketBuffer(buffer)
if err != nil {
t.logger.Trace("route gvisor tcp packet: ", err)
}
return true
}
}
return tcpForwarder.HandlePacket(id, buffer)
})
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
if !t.endpointIndependentNat {
udpForwarder := udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) {
var wq waiter.Queue
@@ -200,61 +122,26 @@ func (t *GVisor) Start() error {
if err != nil {
return
}
udpConn := gonet.NewUDPConn(ipStack, &wq, endpoint)
udpConn := gonet.NewUDPConn(&wq, endpoint)
lAddr := udpConn.RemoteAddr()
rAddr := udpConn.LocalAddr()
if lAddr == nil || rAddr == nil {
endpoint.Abort()
return
}
gConn := &gUDPConn{UDPConn: udpConn}
go func() {
var metadata M.Metadata
metadata.Source = M.SocksaddrFromNet(lAddr)
metadata.Destination = M.SocksaddrFromNet(rAddr)
ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewPacketConn(&bufio.UnbindPacketConn{ExtendedConn: bufio.NewExtendedConn(&gUDPConn{udpConn}), Addr: M.SocksaddrFromNet(rAddr)}), time.Duration(t.udpTimeout)*time.Second)
ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewUnbindPacketConnWithAddr(gConn, metadata.Destination), time.Duration(t.udpTimeout)*time.Second)
hErr := t.handler.NewPacketConnection(ctx, conn, metadata)
if hErr != nil {
endpoint.Abort()
}
}()
})
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, func(id stack.TransportEndpointID, buffer *stack.PacketBuffer) bool {
if t.router != nil {
var routeSession RouteSession
routeSession.Network = syscall.IPPROTO_UDP
var ipHdr header.Network
if buffer.NetworkProtocolNumber == header.IPv4ProtocolNumber {
routeSession.IPVersion = 4
ipHdr = header.IPv4(buffer.NetworkHeader().Slice())
} else {
routeSession.IPVersion = 6
ipHdr = header.IPv6(buffer.NetworkHeader().Slice())
}
udpHdr := header.UDP(buffer.TransportHeader().Slice())
routeSession.Source = M.AddrPortFrom(net.IP(ipHdr.SourceAddress()), udpHdr.SourcePort())
routeSession.Destination = M.AddrPortFrom(net.IP(ipHdr.DestinationAddress()), udpHdr.DestinationPort())
action := t.routeMapping.Lookup(routeSession, func() RouteAction {
if routeSession.IPVersion == 4 {
return t.router.RouteConnection(routeSession, &systemUDPDirectPacketWriter4{t.tun, routeSession.Source})
} else {
return t.router.RouteConnection(routeSession, &systemUDPDirectPacketWriter6{t.tun, routeSession.Source})
}
})
switch actionType := action.(type) {
case *ActionBlock:
// TODO: send icmp unreachable
return true
case *ActionDirect:
buffer.IncRef()
err = actionType.WritePacketBuffer(buffer)
if err != nil {
t.logger.Trace("route gvisor udp packet: ", err)
}
return true
}
}
return udpForwarder.HandlePacket(id, buffer)
})
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
} else {
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket)
}
@@ -272,3 +159,60 @@ func (t *GVisor) Close() error {
}
return nil
}
func AddressFromAddr(destination netip.Addr) tcpip.Address {
if destination.Is6() {
return tcpip.AddrFrom16(destination.As16())
} else {
return tcpip.AddrFrom4(destination.As4())
}
}
func AddrFromAddress(address tcpip.Address) netip.Addr {
if address.Len() == 16 {
return netip.AddrFrom16(address.As16())
} else {
return netip.AddrFrom4(address.As4())
}
}
func newGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
ipStack := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,
ipv6.NewProtocol,
},
TransportProtocols: []stack.TransportProtocolFactory{
tcp.NewProtocol,
udp.NewProtocol,
icmp.NewProtocol4,
icmp.NewProtocol6,
},
})
tErr := ipStack.CreateNIC(defaultNIC, ep)
if tErr != nil {
return nil, E.New("create nic: ", wrapStackError(tErr))
}
ipStack.SetRouteTable([]tcpip.Route{
{Destination: header.IPv4EmptySubnet, NIC: defaultNIC},
{Destination: header.IPv6EmptySubnet, NIC: defaultNIC},
})
ipStack.SetSpoofing(defaultNIC, true)
ipStack.SetPromiscuousMode(defaultNIC, true)
bufSize := 20 * 1024
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPReceiveBufferSizeRangeOption{
Min: 1,
Default: bufSize,
Max: bufSize,
})
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPSendBufferSizeRangeOption{
Min: 1,
Default: bufSize,
Max: bufSize,
})
sOpt := tcpip.TCPSACKEnabled(true)
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
mOpt := tcpip.TCPModerateReceiveBufferOption(true)
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &mOpt)
return ipStack, nil
}

View File

@@ -5,10 +5,9 @@ package tun
import (
"net"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
E "github.com/sagernet/sing/common/exceptions"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
)
type gTCPConn struct {
@@ -28,28 +27,6 @@ func (c *gTCPConn) Write(b []byte) (n int, err error) {
return
}
type gUDPConn struct {
*gonet.UDPConn
}
func (c *gUDPConn) Read(b []byte) (n int, err error) {
n, err = c.UDPConn.Read(b)
if err == nil {
return
}
err = wrapError(err)
return
}
func (c *gUDPConn) Write(b []byte) (n int, err error) {
n, err = c.UDPConn.Write(b)
if err == nil {
return
}
err = wrapError(err)
return
}
func wrapStackError(err tcpip.Error) error {
switch err.(type) {
case *tcpip.ErrClosedForSend,

56
stack_gvisor_filter.go Normal file
View File

@@ -0,0 +1,56 @@
//go:build with_gvisor
package tun
import (
"net/netip"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/sing/common/bufio"
N "github.com/sagernet/sing/common/network"
)
var _ stack.LinkEndpoint = (*LinkEndpointFilter)(nil)
type LinkEndpointFilter struct {
stack.LinkEndpoint
BroadcastAddress netip.Addr
Writer N.VectorisedWriter
}
func (w *LinkEndpointFilter) Attach(dispatcher stack.NetworkDispatcher) {
w.LinkEndpoint.Attach(&networkDispatcherFilter{dispatcher, w.BroadcastAddress, w.Writer})
}
var _ stack.NetworkDispatcher = (*networkDispatcherFilter)(nil)
type networkDispatcherFilter struct {
stack.NetworkDispatcher
broadcastAddress netip.Addr
writer N.VectorisedWriter
}
func (w *networkDispatcherFilter) DeliverNetworkPacket(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
var network header.Network
if protocol == header.IPv4ProtocolNumber {
if headerPackets, loaded := pkt.Data().PullUp(header.IPv4MinimumSize); loaded {
network = header.IPv4(headerPackets)
}
} else {
if headerPackets, loaded := pkt.Data().PullUp(header.IPv6MinimumSize); loaded {
network = header.IPv6(headerPackets)
}
}
if network == nil {
w.NetworkDispatcher.DeliverNetworkPacket(protocol, pkt)
return
}
destination := AddrFromAddress(network.DestinationAddress())
if destination == w.broadcastAddress || !destination.IsGlobalUnicast() {
_, _ = bufio.WriteVectorised(w.writer, pkt.AsSlices())
return
}
w.NetworkDispatcher.DeliverNetworkPacket(protocol, pkt)
}

View File

@@ -5,7 +5,7 @@ package tun
import (
"time"
gLog "gvisor.dev/gvisor/pkg/log"
gLog "github.com/sagernet/gvisor/pkg/log"
)
func init() {

View File

@@ -13,3 +13,9 @@ func NewGVisor(
) (Stack, error) {
return nil, ErrGVisorNotIncluded
}
func NewMixed(
options StackOptions,
) (Stack, error) {
return nil, ErrGVisorNotIncluded
}

214
stack_gvisor_udp.go Normal file
View File

@@ -0,0 +1,214 @@
//go:build with_gvisor
package tun
import (
"context"
"errors"
"math"
"net/netip"
"os"
"sync"
"syscall"
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
"github.com/sagernet/gvisor/pkg/tcpip/checksum"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/udpnat"
)
type UDPForwarder struct {
ctx context.Context
stack *stack.Stack
udpNat *udpnat.Service[netip.AddrPort]
// cache
cacheProto tcpip.NetworkProtocolNumber
cacheID stack.TransportEndpointID
}
func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, udpTimeout int64) *UDPForwarder {
return &UDPForwarder{
ctx: ctx,
stack: stack,
udpNat: udpnat.New[netip.AddrPort](udpTimeout, handler),
}
}
func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
var upstreamMetadata M.Metadata
upstreamMetadata.Source = M.SocksaddrFrom(AddrFromAddress(id.RemoteAddress), id.RemotePort)
upstreamMetadata.Destination = M.SocksaddrFrom(AddrFromAddress(id.LocalAddress), id.LocalPort)
if upstreamMetadata.Source.IsIPv4() {
f.cacheProto = header.IPv4ProtocolNumber
} else {
f.cacheProto = header.IPv6ProtocolNumber
}
gBuffer := pkt.Data().ToBuffer()
sBuffer := buf.NewSize(int(gBuffer.Size()))
gBuffer.Apply(func(view *buffer.View) {
sBuffer.Write(view.AsSlice())
})
f.cacheID = id
f.udpNat.NewPacket(
f.ctx,
upstreamMetadata.Source.AddrPort(),
sBuffer,
upstreamMetadata,
f.newUDPConn,
)
return true
}
func (f *UDPForwarder) newUDPConn(natConn N.PacketConn) N.PacketWriter {
return &UDPBackWriter{
stack: f.stack,
source: f.cacheID.RemoteAddress,
sourcePort: f.cacheID.RemotePort,
sourceNetwork: f.cacheProto,
}
}
type UDPBackWriter struct {
access sync.Mutex
stack *stack.Stack
source tcpip.Address
sourcePort uint16
sourceNetwork tcpip.NetworkProtocolNumber
}
func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Socksaddr) error {
if !destination.IsIP() {
return E.Cause(os.ErrInvalid, "invalid destination")
} else if destination.IsIPv4() && w.sourceNetwork == header.IPv6ProtocolNumber {
destination = M.SocksaddrFrom(netip.AddrFrom16(destination.Addr.As16()), destination.Port)
} else if destination.IsIPv6() && (w.sourceNetwork == header.IPv4ProtocolNumber) {
return E.New("send IPv6 packet to IPv4 connection")
}
defer packetBuffer.Release()
route, err := w.stack.FindRoute(
defaultNIC,
AddressFromAddr(destination.Addr),
w.source,
w.sourceNetwork,
false,
)
if err != nil {
return wrapStackError(err)
}
defer route.Release()
packet := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: header.UDPMinimumSize + int(route.MaxHeaderLength()),
Payload: buffer.MakeWithData(packetBuffer.Bytes()),
})
defer packet.DecRef()
packet.TransportProtocolNumber = header.UDPProtocolNumber
udpHdr := header.UDP(packet.TransportHeader().Push(header.UDPMinimumSize))
pLen := uint16(packet.Size())
udpHdr.Encode(&header.UDPFields{
SrcPort: destination.Port,
DstPort: w.sourcePort,
Length: pLen,
})
if route.RequiresTXTransportChecksum() && w.sourceNetwork == header.IPv6ProtocolNumber {
xsum := udpHdr.CalculateChecksum(checksum.Combine(
route.PseudoHeaderChecksum(header.UDPProtocolNumber, pLen),
packet.Data().Checksum(),
))
if xsum != math.MaxUint16 {
xsum = ^xsum
}
udpHdr.SetChecksum(xsum)
}
err = route.WritePacket(stack.NetworkHeaderParams{
Protocol: header.UDPProtocolNumber,
TTL: route.DefaultTTL(),
TOS: 0,
}, packet)
if err != nil {
route.Stats().UDP.PacketSendErrors.Increment()
return wrapStackError(err)
}
route.Stats().UDP.PacketsSent.Increment()
return nil
}
type gUDPConn struct {
*gonet.UDPConn
}
func (c *gUDPConn) Read(b []byte) (n int, err error) {
n, err = c.UDPConn.Read(b)
if err == nil {
return
}
err = wrapError(err)
return
}
func (c *gUDPConn) Write(b []byte) (n int, err error) {
n, err = c.UDPConn.Write(b)
if err == nil {
return
}
err = wrapError(err)
return
}
func (c *gUDPConn) Close() error {
return c.UDPConn.Close()
}
func gWriteUnreachable(gStack *stack.Stack, packet *stack.PacketBuffer, err error) (retErr error) {
if errors.Is(err, syscall.ENETUNREACH) {
if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPNetUnreachable)
} else {
return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute)
}
} else if errors.Is(err, syscall.EHOSTUNREACH) {
if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPHostUnreachable)
} else {
return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute)
}
} else if errors.Is(err, syscall.ECONNREFUSED) {
if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPPortUnreachable)
} else {
return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPPortUnreachable)
}
}
return nil
}
func gWriteUnreachable4(gStack *stack.Stack, packet *stack.PacketBuffer, icmpCode stack.RejectIPv4WithICMPType) error {
err := gStack.NetworkProtocolInstance(header.IPv4ProtocolNumber).(stack.RejectIPv4WithHandler).SendRejectionError(packet, icmpCode, true)
if err != nil {
return wrapStackError(err)
}
return nil
}
func gWriteUnreachable6(gStack *stack.Stack, packet *stack.PacketBuffer, icmpCode stack.RejectIPv6WithICMPType) error {
err := gStack.NetworkProtocolInstance(header.IPv6ProtocolNumber).(stack.RejectIPv6WithHandler).SendRejectionError(packet, icmpCode, true)
if err != nil {
return wrapStackError(err)
}
return nil
}

269
stack_mixed.go Normal file
View File

@@ -0,0 +1,269 @@
//go:build with_gvisor
package tun
import (
"time"
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/link/channel"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
"github.com/sagernet/gvisor/pkg/waiter"
"github.com/sagernet/sing-tun/internal/clashtcpip"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/canceler"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
)
type Mixed struct {
*System
endpointIndependentNat bool
stack *stack.Stack
endpoint *channel.Endpoint
}
func NewMixed(
options StackOptions,
) (Stack, error) {
system, err := NewSystem(options)
if err != nil {
return nil, err
}
return &Mixed{
System: system.(*System),
endpointIndependentNat: options.EndpointIndependentNat,
}, nil
}
func (m *Mixed) Start() error {
err := m.System.start()
if err != nil {
return err
}
endpoint := channel.New(1024, uint32(m.mtu), "")
ipStack, err := newGVisorStack(endpoint)
if err != nil {
return err
}
if !m.endpointIndependentNat {
udpForwarder := udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) {
var wq waiter.Queue
endpoint, err := request.CreateEndpoint(&wq)
if err != nil {
return
}
udpConn := gonet.NewUDPConn(&wq, endpoint)
lAddr := udpConn.RemoteAddr()
rAddr := udpConn.LocalAddr()
if lAddr == nil || rAddr == nil {
endpoint.Abort()
return
}
gConn := &gUDPConn{UDPConn: udpConn}
go func() {
var metadata M.Metadata
metadata.Source = M.SocksaddrFromNet(lAddr)
metadata.Destination = M.SocksaddrFromNet(rAddr)
ctx, conn := canceler.NewPacketConn(m.ctx, bufio.NewUnbindPacketConnWithAddr(gConn, metadata.Destination), time.Duration(m.udpTimeout)*time.Second)
hErr := m.handler.NewPacketConnection(ctx, conn, metadata)
if hErr != nil {
endpoint.Abort()
}
}()
})
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
} else {
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(m.ctx, ipStack, m.handler, m.udpTimeout).HandlePacket)
}
m.stack = ipStack
m.endpoint = endpoint
go m.tunLoop()
go m.packetLoop()
return nil
}
func (m *Mixed) tunLoop() {
if winTun, isWinTun := m.tun.(WinTun); isWinTun {
m.wintunLoop(winTun)
return
}
if linuxTUN, isLinuxTUN := m.tun.(LinuxTUN); isLinuxTUN {
m.frontHeadroom = linuxTUN.FrontHeadroom()
m.txChecksumOffload = linuxTUN.TXChecksumOffload()
batchSize := linuxTUN.BatchSize()
if batchSize > 1 {
m.batchLoop(linuxTUN, batchSize)
return
}
}
packetBuffer := make([]byte, m.mtu+PacketOffset)
for {
n, err := m.tun.Read(packetBuffer)
if err != nil {
if E.IsClosed(err) {
return
}
m.logger.Error(E.Cause(err, "read packet"))
}
if n < clashtcpip.IPv4PacketMinLength {
continue
}
rawPacket := packetBuffer[:n]
packet := packetBuffer[PacketOffset:n]
if m.processPacket(packet) {
_, err = m.tun.Write(rawPacket)
if err != nil {
m.logger.Trace(E.Cause(err, "write packet"))
}
}
}
}
func (m *Mixed) wintunLoop(winTun WinTun) {
for {
packet, release, err := winTun.ReadPacket()
if err != nil {
return
}
if len(packet) < clashtcpip.IPv4PacketMinLength {
release()
continue
}
if m.processPacket(packet) {
_, err = winTun.Write(packet)
if err != nil {
m.logger.Trace(E.Cause(err, "write packet"))
}
}
release()
}
}
func (m *Mixed) batchLoop(linuxTUN LinuxTUN, batchSize int) {
packetBuffers := make([][]byte, batchSize)
writeBuffers := make([][]byte, batchSize)
packetSizes := make([]int, batchSize)
for i := range packetBuffers {
packetBuffers[i] = make([]byte, m.mtu+m.frontHeadroom)
}
for {
n, err := linuxTUN.BatchRead(packetBuffers, m.frontHeadroom, packetSizes)
if err != nil {
if E.IsClosed(err) {
return
}
m.logger.Error(E.Cause(err, "batch read packet"))
}
if n == 0 {
continue
}
for i := 0; i < n; i++ {
packetSize := packetSizes[i]
if packetSize < clashtcpip.IPv4PacketMinLength {
continue
}
packetBuffer := packetBuffers[i]
packet := packetBuffer[m.frontHeadroom : m.frontHeadroom+packetSize]
if m.processPacket(packet) {
writeBuffers = append(writeBuffers, packetBuffer[:m.frontHeadroom+packetSize])
}
}
if len(writeBuffers) > 0 {
err = linuxTUN.BatchWrite(writeBuffers, m.frontHeadroom)
if err != nil {
m.logger.Trace(E.Cause(err, "batch write packet"))
}
writeBuffers = writeBuffers[:0]
}
}
}
func (m *Mixed) processPacket(packet []byte) bool {
var (
writeBack bool
err error
)
switch ipVersion := packet[0] >> 4; ipVersion {
case 4:
writeBack, err = m.processIPv4(packet)
case 6:
writeBack, err = m.processIPv6(packet)
default:
err = E.New("ip: unknown version: ", ipVersion)
}
if err != nil {
m.logger.Trace(err)
return false
}
return writeBack
}
func (m *Mixed) processIPv4(packet clashtcpip.IPv4Packet) (writeBack bool, err error) {
writeBack = true
destination := packet.DestinationIP()
if destination == m.broadcastAddr || !destination.IsGlobalUnicast() {
return
}
switch packet.Protocol() {
case clashtcpip.TCP:
err = m.processIPv4TCP(packet, packet.Payload())
case clashtcpip.UDP:
writeBack = false
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packet),
IsForwardedPacket: true,
})
m.endpoint.InjectInbound(header.IPv4ProtocolNumber, pkt)
pkt.DecRef()
return
case clashtcpip.ICMP:
err = m.processIPv4ICMP(packet, packet.Payload())
}
return
}
func (m *Mixed) processIPv6(packet clashtcpip.IPv6Packet) (writeBack bool, err error) {
writeBack = true
if !packet.DestinationIP().IsGlobalUnicast() {
return
}
switch packet.Protocol() {
case clashtcpip.TCP:
err = m.processIPv6TCP(packet, packet.Payload())
case clashtcpip.UDP:
writeBack = false
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packet),
IsForwardedPacket: true,
})
m.endpoint.InjectInbound(header.IPv6ProtocolNumber, pkt)
pkt.DecRef()
case clashtcpip.ICMPv6:
err = m.processIPv6ICMP(packet, packet.Payload())
}
return
}
func (m *Mixed) packetLoop() {
for {
packet := m.endpoint.ReadContext(m.ctx)
if packet == nil {
break
}
bufio.WriteVectorised(m.tun, packet.AsSlices())
packet.DecRef()
}
}
func (m *Mixed) Close() error {
m.endpoint.Attach(nil)
m.stack.Close()
for _, endpoint := range m.stack.CleanupEndpoints() {
endpoint.Abort()
}
return m.System.Close()
}

View File

@@ -18,12 +18,13 @@ import (
"github.com/sagernet/sing/common/udpnat"
)
var ErrIncludeAllNetworks = E.New("`system` and `mixed` stack are not available when `includeAllNetworks` is enabled. See https://github.com/SagerNet/sing-tun/issues/25")
type System struct {
ctx context.Context
tun Tun
tunName string
mtu uint32
router Router
mtu int
handler Handler
logger logger.Logger
inet4Prefixes []netip.Prefix
@@ -32,6 +33,7 @@ type System struct {
inet4Address netip.Addr
inet6ServerAddress netip.Addr
inet6Address netip.Addr
broadcastAddr netip.Addr
udpTimeout int64
tcpListener net.Listener
tcpListener6 net.Listener
@@ -39,9 +41,10 @@ type System struct {
tcpPort6 uint16
tcpNat *TCPNat
udpNat *udpnat.Service[netip.AddrPort]
routeMapping *RouteMapping
bindInterface bool
interfaceFinder control.InterfaceFinder
frontHeadroom int
txChecksumOffload bool
}
type Session struct {
@@ -55,32 +58,29 @@ func NewSystem(options StackOptions) (Stack, error) {
stack := &System{
ctx: options.Context,
tun: options.Tun,
tunName: options.Name,
mtu: options.MTU,
tunName: options.TunOptions.Name,
mtu: int(options.TunOptions.MTU),
udpTimeout: options.UDPTimeout,
router: options.Router,
handler: options.Handler,
logger: options.Logger,
inet4Prefixes: options.Inet4Address,
inet6Prefixes: options.Inet6Address,
inet4Prefixes: options.TunOptions.Inet4Address,
inet6Prefixes: options.TunOptions.Inet6Address,
broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address),
bindInterface: options.ForwarderBindInterface,
interfaceFinder: options.InterfaceFinder,
}
if stack.router != nil {
stack.routeMapping = NewRouteMapping(options.UDPTimeout)
}
if len(options.Inet4Address) > 0 {
if options.Inet4Address[0].Bits() == 32 {
if len(options.TunOptions.Inet4Address) > 0 {
if options.TunOptions.Inet4Address[0].Bits() == 32 {
return nil, E.New("need one more IPv4 address in first prefix for system stack")
}
stack.inet4ServerAddress = options.Inet4Address[0].Addr()
stack.inet4ServerAddress = options.TunOptions.Inet4Address[0].Addr()
stack.inet4Address = stack.inet4ServerAddress.Next()
}
if len(options.Inet6Address) > 0 {
if options.Inet6Address[0].Bits() == 128 {
if len(options.TunOptions.Inet6Address) > 0 {
if options.TunOptions.Inet6Address[0].Bits() == 128 {
return nil, E.New("need one more IPv6 address in first prefix for system stack")
}
stack.inet6ServerAddress = options.Inet6Address[0].Addr()
stack.inet6ServerAddress = options.TunOptions.Inet6Address[0].Addr()
stack.inet6Address = stack.inet6ServerAddress.Next()
}
if !stack.inet4Address.IsValid() && !stack.inet6Address.IsValid() {
@@ -97,12 +97,25 @@ func (s *System) Close() error {
}
func (s *System) Start() error {
err := s.start()
if err != nil {
return err
}
go s.tunLoop()
return nil
}
func (s *System) start() error {
err := fixWindowsFirewall()
if err != nil {
return E.Cause(err, "fix windows firewall for system stack")
}
var listener net.ListenConfig
if s.bindInterface {
listener.Control = control.Append(listener.Control, func(network, address string, conn syscall.RawConn) error {
err := control.BindToInterface(s.interfaceFinder, s.tunName, -1)(network, address, conn)
if err != nil {
s.logger.Warn("bind forwarder to interface: ", err)
bindErr := control.BindToInterface0(s.interfaceFinder, conn, network, address, s.tunName, -1, true)
if bindErr != nil {
s.logger.Warn("bind forwarder to interface: ", bindErr)
}
return nil
})
@@ -127,7 +140,6 @@ func (s *System) Start() error {
}
s.tcpNat = NewNat(s.ctx, time.Second*time.Duration(s.udpTimeout))
s.udpNat = udpnat.New[netip.AddrPort](s.udpTimeout, s.handler)
go s.tunLoop()
return nil
}
@@ -136,30 +148,34 @@ func (s *System) tunLoop() {
s.wintunLoop(winTun)
return
}
_packetBuffer := buf.StackNewSize(int(s.mtu))
defer common.KeepAlive(_packetBuffer)
packetBuffer := common.Dup(_packetBuffer)
defer packetBuffer.Release()
packetSlice := packetBuffer.Slice()
for {
n, err := s.tun.Read(packetSlice)
if err != nil {
if linuxTUN, isLinuxTUN := s.tun.(LinuxTUN); isLinuxTUN {
s.frontHeadroom = linuxTUN.FrontHeadroom()
s.txChecksumOffload = linuxTUN.TXChecksumOffload()
batchSize := linuxTUN.BatchSize()
if batchSize > 1 {
s.batchLoop(linuxTUN, batchSize)
return
}
}
packetBuffer := make([]byte, s.mtu+PacketOffset)
for {
n, err := s.tun.Read(packetBuffer)
if err != nil {
if E.IsClosed(err) {
return
}
s.logger.Error(E.Cause(err, "read packet"))
}
if n < clashtcpip.IPv4PacketMinLength {
continue
}
packet := packetSlice[PacketOffset:n]
switch ipVersion := packet[0] >> 4; ipVersion {
case 4:
err = s.processIPv4(packet)
case 6:
err = s.processIPv6(packet)
default:
err = E.New("ip: unknown version: ", ipVersion)
}
if err != nil {
s.logger.Trace(err)
rawPacket := packetBuffer[:n]
packet := packetBuffer[PacketOffset:n]
if s.processPacket(packet) {
_, err = s.tun.Write(rawPacket)
if err != nil {
s.logger.Trace(E.Cause(err, "write packet"))
}
}
}
}
@@ -174,21 +190,75 @@ func (s *System) wintunLoop(winTun WinTun) {
release()
continue
}
switch ipVersion := packet[0] >> 4; ipVersion {
case 4:
err = s.processIPv4(packet)
case 6:
err = s.processIPv6(packet)
default:
err = E.New("ip: unknown version: ", ipVersion)
}
if err != nil {
s.logger.Trace(err)
if s.processPacket(packet) {
_, err = winTun.Write(packet)
if err != nil {
s.logger.Trace(E.Cause(err, "write packet"))
}
}
release()
}
}
func (s *System) batchLoop(linuxTUN LinuxTUN, batchSize int) {
packetBuffers := make([][]byte, batchSize)
writeBuffers := make([][]byte, batchSize)
packetSizes := make([]int, batchSize)
for i := range packetBuffers {
packetBuffers[i] = make([]byte, s.mtu+s.frontHeadroom)
}
for {
n, err := linuxTUN.BatchRead(packetBuffers, s.frontHeadroom, packetSizes)
if err != nil {
if E.IsClosed(err) {
return
}
s.logger.Error(E.Cause(err, "batch read packet"))
}
if n == 0 {
continue
}
for i := 0; i < n; i++ {
packetSize := packetSizes[i]
if packetSize < clashtcpip.IPv4PacketMinLength {
continue
}
packetBuffer := packetBuffers[i]
packet := packetBuffer[s.frontHeadroom : s.frontHeadroom+packetSize]
if s.processPacket(packet) {
writeBuffers = append(writeBuffers, packetBuffer[:s.frontHeadroom+packetSize])
}
}
if len(writeBuffers) > 0 {
err = linuxTUN.BatchWrite(writeBuffers, s.frontHeadroom)
if err != nil {
s.logger.Trace(E.Cause(err, "batch write packet"))
}
writeBuffers = writeBuffers[:0]
}
}
}
func (s *System) processPacket(packet []byte) bool {
var (
writeBack bool
err error
)
switch ipVersion := packet[0] >> 4; ipVersion {
case 4:
writeBack, err = s.processIPv4(packet)
case 6:
writeBack, err = s.processIPv6(packet)
default:
err = E.New("ip: unknown version: ", ipVersion)
}
if err != nil {
s.logger.Trace(err)
return false
}
return writeBack
}
func (s *System) acceptLoop(listener net.Listener) {
for {
conn, err := listener.Accept()
@@ -230,37 +300,46 @@ func (s *System) acceptLoop(listener net.Listener) {
}
}
func (s *System) processIPv4(packet clashtcpip.IPv4Packet) error {
func (s *System) processIPv4(packet clashtcpip.IPv4Packet) (writeBack bool, err error) {
writeBack = true
destination := packet.DestinationIP()
if destination == s.broadcastAddr || !destination.IsGlobalUnicast() {
return
}
switch packet.Protocol() {
case clashtcpip.TCP:
return s.processIPv4TCP(packet, packet.Payload())
err = s.processIPv4TCP(packet, packet.Payload())
case clashtcpip.UDP:
return s.processIPv4UDP(packet, packet.Payload())
writeBack = false
err = s.processIPv4UDP(packet, packet.Payload())
case clashtcpip.ICMP:
return s.processIPv4ICMP(packet, packet.Payload())
default:
return common.Error(s.tun.Write(packet))
err = s.processIPv4ICMP(packet, packet.Payload())
}
return
}
func (s *System) processIPv6(packet clashtcpip.IPv6Packet) error {
func (s *System) processIPv6(packet clashtcpip.IPv6Packet) (writeBack bool, err error) {
writeBack = true
if !packet.DestinationIP().IsGlobalUnicast() {
return
}
switch packet.Protocol() {
case clashtcpip.TCP:
return s.processIPv6TCP(packet, packet.Payload())
err = s.processIPv6TCP(packet, packet.Payload())
case clashtcpip.UDP:
return s.processIPv6UDP(packet, packet.Payload())
writeBack = false
err = s.processIPv6UDP(packet, packet.Payload())
case clashtcpip.ICMPv6:
return s.processIPv6ICMP(packet, packet.Payload())
default:
return common.Error(s.tun.Write(packet))
err = s.processIPv6ICMP(packet, packet.Payload())
}
return
}
func (s *System) processIPv4TCP(packet clashtcpip.IPv4Packet, header clashtcpip.TCPPacket) error {
source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort())
destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort())
if !destination.Addr().IsGlobalUnicast() {
return common.Error(s.tun.Write(packet))
return nil
} else if source.Addr() == s.inet4ServerAddress && source.Port() == s.tcpPort {
session := s.tcpNat.LookupBack(destination.Port())
if session == nil {
@@ -271,37 +350,27 @@ func (s *System) processIPv4TCP(packet clashtcpip.IPv4Packet, header clashtcpip.
packet.SetDestinationIP(session.Source.Addr())
header.SetDestinationPort(session.Source.Port())
} else {
if s.router != nil {
session := RouteSession{4, syscall.IPPROTO_TCP, source, destination}
action := s.routeMapping.Lookup(session, func() RouteAction {
return s.router.RouteConnection(session, &systemTCPDirectPacketWriter4{s.tun, source})
})
switch actionType := action.(type) {
case *ActionBlock:
// TODO: send ICMP unreachable
return nil
case *ActionDirect:
return E.Append(nil, actionType.WritePacket(buf.As(packet).ToOwned()), func(err error) error {
return E.Cause(err, "route ipv4 tcp packet")
})
}
}
natPort := s.tcpNat.Lookup(source, destination)
packet.SetSourceIP(s.inet4Address)
header.SetSourcePort(natPort)
packet.SetDestinationIP(s.inet4ServerAddress)
header.SetDestinationPort(s.tcpPort)
}
header.ResetChecksum(packet.PseudoSum())
packet.ResetChecksum()
return common.Error(s.tun.Write(packet))
if !s.txChecksumOffload {
header.ResetChecksum(packet.PseudoSum())
packet.ResetChecksum()
} else {
header.OffloadChecksum()
packet.ResetChecksum()
}
return nil
}
func (s *System) processIPv6TCP(packet clashtcpip.IPv6Packet, header clashtcpip.TCPPacket) error {
source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort())
destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort())
if !destination.Addr().IsGlobalUnicast() {
return common.Error(s.tun.Write(packet))
return nil
} else if source.Addr() == s.inet6ServerAddress && source.Port() == s.tcpPort6 {
session := s.tcpNat.LookupBack(destination.Port())
if session == nil {
@@ -312,30 +381,18 @@ func (s *System) processIPv6TCP(packet clashtcpip.IPv6Packet, header clashtcpip.
packet.SetDestinationIP(session.Source.Addr())
header.SetDestinationPort(session.Source.Port())
} else {
if s.router != nil {
session := RouteSession{6, syscall.IPPROTO_TCP, source, destination}
action := s.routeMapping.Lookup(session, func() RouteAction {
return s.router.RouteConnection(session, &systemTCPDirectPacketWriter6{s.tun, source})
})
switch actionType := action.(type) {
case *ActionBlock:
// TODO: send RST
return nil
case *ActionDirect:
return E.Append(nil, actionType.WritePacket(buf.As(packet).ToOwned()), func(err error) error {
return E.Cause(err, "route ipv6 tcp packet")
})
}
}
natPort := s.tcpNat.Lookup(source, destination)
packet.SetSourceIP(s.inet6Address)
header.SetSourcePort(natPort)
packet.SetDestinationIP(s.inet6ServerAddress)
header.SetDestinationPort(s.tcpPort6)
}
header.ResetChecksum(packet.PseudoSum())
packet.ResetChecksum()
return common.Error(s.tun.Write(packet))
if !s.txChecksumOffload {
header.ResetChecksum(packet.PseudoSum())
} else {
header.OffloadChecksum()
}
return nil
}
func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip.UDPPacket) error {
@@ -345,25 +402,13 @@ func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip.
if packet.FragmentOffset() != 0 {
return E.New("ipv4: udp: fragment dropped")
}
if !header.Valid() {
return E.New("ipv4: udp: invalid packet")
}
source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort())
destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort())
if !destination.Addr().IsGlobalUnicast() {
return common.Error(s.tun.Write(packet))
}
if s.router != nil {
routeSession := RouteSession{4, syscall.IPPROTO_UDP, source, destination}
action := s.routeMapping.Lookup(routeSession, func() RouteAction {
return s.router.RouteConnection(routeSession, &systemUDPDirectPacketWriter4{s.tun, source})
})
switch actionType := action.(type) {
case *ActionBlock:
// TODO: send icmp unreachable
return nil
case *ActionDirect:
return E.Append(nil, actionType.WritePacket(buf.As(packet).ToOwned()), func(err error) error {
return E.Cause(err, "route ipv4 udp packet")
})
}
return nil
}
data := buf.As(header.Payload())
if data.Len() == 0 {
@@ -377,31 +422,25 @@ func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip.
headerLen := packet.HeaderLen() + clashtcpip.UDPHeaderSize
headerCopy := make([]byte, headerLen)
copy(headerCopy, packet[:headerLen])
return &systemUDPPacketWriter4{s.tun, headerCopy, source}
return &systemUDPPacketWriter4{
s.tun,
s.frontHeadroom + PacketOffset,
headerCopy,
source,
s.txChecksumOffload,
}
})
return nil
}
func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip.UDPPacket) error {
if !header.Valid() {
return E.New("ipv6: udp: invalid packet")
}
source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort())
destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort())
if !destination.Addr().IsGlobalUnicast() {
return common.Error(s.tun.Write(packet))
}
if s.router != nil {
routeSession := RouteSession{6, syscall.IPPROTO_UDP, source, destination}
action := s.routeMapping.Lookup(routeSession, func() RouteAction {
return s.router.RouteConnection(routeSession, &systemUDPDirectPacketWriter6{s.tun, source})
})
switch actionType := action.(type) {
case *ActionBlock:
// TODO: send icmp unreachable
return nil
case *ActionDirect:
return E.Append(nil, actionType.WritePacket(buf.As(packet).ToOwned()), func(err error) error {
return E.Cause(err, "route ipv6 udp packet")
})
}
return nil
}
data := buf.As(header.Payload())
if data.Len() == 0 {
@@ -415,27 +454,18 @@ func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip.
headerLen := len(packet) - int(header.Length()) + clashtcpip.UDPHeaderSize
headerCopy := make([]byte, headerLen)
copy(headerCopy, packet[:headerLen])
return &systemUDPPacketWriter6{s.tun, headerCopy, source}
return &systemUDPPacketWriter6{
s.tun,
s.frontHeadroom + PacketOffset,
headerCopy,
source,
s.txChecksumOffload,
}
})
return nil
}
func (s *System) processIPv4ICMP(packet clashtcpip.IPv4Packet, header clashtcpip.ICMPPacket) error {
if s.router != nil {
routeSession := RouteSession{4, clashtcpip.ICMP, netip.AddrPortFrom(packet.SourceIP(), 0), netip.AddrPortFrom(packet.DestinationIP(), 0)}
action := s.routeMapping.Lookup(routeSession, func() RouteAction {
return s.router.RouteConnection(routeSession, &systemICMPDirectPacketWriter4{s.tun, packet.SourceIP()})
})
switch actionType := action.(type) {
case *ActionBlock:
// TODO: send icmp unreachable
return nil
case *ActionDirect:
return E.Append(nil, actionType.WritePacket(buf.As(packet).ToOwned()), func(err error) error {
return E.Cause(err, "route ipv4 icmp packet")
})
}
}
if header.Type() != clashtcpip.ICMPTypePingRequest || header.Code() != 0 {
return nil
}
@@ -445,25 +475,10 @@ func (s *System) processIPv4ICMP(packet clashtcpip.IPv4Packet, header clashtcpip
packet.SetDestinationIP(sourceAddress)
header.ResetChecksum()
packet.ResetChecksum()
return common.Error(s.tun.Write(packet))
return nil
}
func (s *System) processIPv6ICMP(packet clashtcpip.IPv6Packet, header clashtcpip.ICMPv6Packet) error {
if s.router != nil {
routeSession := RouteSession{6, clashtcpip.ICMPv6, netip.AddrPortFrom(packet.SourceIP(), 0), netip.AddrPortFrom(packet.DestinationIP(), 0)}
action := s.routeMapping.Lookup(routeSession, func() RouteAction {
return s.router.RouteConnection(routeSession, &systemICMPDirectPacketWriter6{s.tun, packet.SourceIP()})
})
switch actionType := action.(type) {
case *ActionBlock:
// TODO: send icmp unreachable
return nil
case *ActionDirect:
return E.Append(nil, actionType.WritePacket(buf.As(packet).ToOwned()), func(err error) error {
return E.Cause(err, "route ipv6 icmp packet")
})
}
}
if header.Type() != clashtcpip.ICMPv6EchoRequest || header.Code() != 0 {
return nil
}
@@ -473,102 +488,21 @@ func (s *System) processIPv6ICMP(packet clashtcpip.IPv6Packet, header clashtcpip
packet.SetDestinationIP(sourceAddress)
header.ResetChecksum(packet.PseudoSum())
packet.ResetChecksum()
return common.Error(s.tun.Write(packet))
}
type systemTCPDirectPacketWriter4 struct {
tun Tun
source netip.AddrPort
}
func (w *systemTCPDirectPacketWriter4) WritePacket(p []byte) error {
packet := clashtcpip.IPv4Packet(p)
header := clashtcpip.TCPPacket(packet.Payload())
packet.SetDestinationIP(w.source.Addr())
header.SetDestinationPort(w.source.Port())
header.ResetChecksum(packet.PseudoSum())
packet.ResetChecksum()
return common.Error(w.tun.Write(packet))
}
type systemTCPDirectPacketWriter6 struct {
tun Tun
source netip.AddrPort
}
func (w *systemTCPDirectPacketWriter6) WritePacket(p []byte) error {
packet := clashtcpip.IPv6Packet(p)
header := clashtcpip.TCPPacket(packet.Payload())
packet.SetDestinationIP(w.source.Addr())
header.SetDestinationPort(w.source.Port())
header.ResetChecksum(packet.PseudoSum())
packet.ResetChecksum()
return common.Error(w.tun.Write(packet))
}
type systemUDPDirectPacketWriter4 struct {
tun Tun
source netip.AddrPort
}
func (w *systemUDPDirectPacketWriter4) WritePacket(p []byte) error {
packet := clashtcpip.IPv4Packet(p)
header := clashtcpip.UDPPacket(packet.Payload())
packet.SetDestinationIP(w.source.Addr())
header.SetDestinationPort(w.source.Port())
header.ResetChecksum(packet.PseudoSum())
packet.ResetChecksum()
return common.Error(w.tun.Write(packet))
}
type systemUDPDirectPacketWriter6 struct {
tun Tun
source netip.AddrPort
}
func (w *systemUDPDirectPacketWriter6) WritePacket(p []byte) error {
packet := clashtcpip.IPv6Packet(p)
header := clashtcpip.UDPPacket(packet.Payload())
packet.SetDestinationIP(w.source.Addr())
header.SetDestinationPort(w.source.Port())
header.ResetChecksum(packet.PseudoSum())
packet.ResetChecksum()
return common.Error(w.tun.Write(packet))
}
type systemICMPDirectPacketWriter4 struct {
tun Tun
source netip.Addr
}
func (w *systemICMPDirectPacketWriter4) WritePacket(p []byte) error {
packet := clashtcpip.IPv4Packet(p)
packet.SetDestinationIP(w.source)
packet.ResetChecksum()
return common.Error(w.tun.Write(packet))
}
type systemICMPDirectPacketWriter6 struct {
tun Tun
source netip.Addr
}
func (w *systemICMPDirectPacketWriter6) WritePacket(p []byte) error {
packet := clashtcpip.IPv6Packet(p)
packet.SetDestinationIP(w.source)
packet.ResetChecksum()
return common.Error(w.tun.Write(packet))
return nil
}
type systemUDPPacketWriter4 struct {
tun Tun
header []byte
source netip.AddrPort
tun Tun
frontHeadroom int
header []byte
source netip.AddrPort
txChecksumOffload bool
}
func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
newPacket := buf.StackNewSize(len(w.header) + buffer.Len())
newPacket := buf.NewSize(w.frontHeadroom + len(w.header) + buffer.Len())
defer newPacket.Release()
newPacket.Resize(w.frontHeadroom, 0)
newPacket.Write(w.header)
newPacket.Write(buffer.Bytes())
ipHdr := clashtcpip.IPv4Packet(newPacket.Bytes())
@@ -579,20 +513,33 @@ func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.S
udpHdr.SetDestinationPort(udpHdr.SourcePort())
udpHdr.SetSourcePort(destination.Port)
udpHdr.SetLength(uint16(buffer.Len() + clashtcpip.UDPHeaderSize))
udpHdr.ResetChecksum(ipHdr.PseudoSum())
ipHdr.ResetChecksum()
if !w.txChecksumOffload {
udpHdr.ResetChecksum(ipHdr.PseudoSum())
ipHdr.ResetChecksum()
} else {
udpHdr.OffloadChecksum()
ipHdr.ResetChecksum()
}
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET
} else {
newPacket.Advance(-w.frontHeadroom)
}
return common.Error(w.tun.Write(newPacket.Bytes()))
}
type systemUDPPacketWriter6 struct {
tun Tun
header []byte
source netip.AddrPort
tun Tun
frontHeadroom int
header []byte
source netip.AddrPort
txChecksumOffload bool
}
func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
newPacket := buf.StackNewSize(len(w.header) + buffer.Len())
newPacket := buf.NewSize(w.frontHeadroom + len(w.header) + buffer.Len())
defer newPacket.Release()
newPacket.Resize(w.frontHeadroom, 0)
newPacket.Write(w.header)
newPacket.Write(buffer.Bytes())
ipHdr := clashtcpip.IPv6Packet(newPacket.Bytes())
@@ -604,6 +551,15 @@ func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.S
udpHdr.SetDestinationPort(udpHdr.SourcePort())
udpHdr.SetSourcePort(destination.Port)
udpHdr.SetLength(udpLen)
udpHdr.ResetChecksum(ipHdr.PseudoSum())
if !w.txChecksumOffload {
udpHdr.ResetChecksum(ipHdr.PseudoSum())
} else {
udpHdr.OffloadChecksum()
}
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET6
} else {
newPacket.Advance(-w.frontHeadroom)
}
return common.Error(w.tun.Write(newPacket.Bytes()))
}

View File

@@ -0,0 +1,7 @@
//go:build !windows
package tun
func fixWindowsFirewall() error {
return nil
}

25
stack_system_windows.go Normal file
View File

@@ -0,0 +1,25 @@
package tun
import (
"os"
"path/filepath"
"github.com/sagernet/sing-tun/internal/winfw"
)
func fixWindowsFirewall() error {
absPath, err := filepath.Abs(os.Args[0])
if err != nil {
return err
}
rule := winfw.FWRule{
Name: "sing-tun (" + absPath + ")",
ApplicationName: absPath,
Enabled: true,
Protocol: winfw.NET_FW_IP_PROTOCOL_TCP,
Direction: winfw.NET_FW_RULE_DIR_IN,
Action: winfw.NET_FW_ACTION_ALLOW,
}
_, err = winfw.FirewallRuleAddAdvanced(rule)
return err
}

54
tun.go
View File

@@ -10,6 +10,7 @@ import (
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/common/logger"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ranges"
)
@@ -22,6 +23,7 @@ type Handler interface {
type Tun interface {
io.ReadWriter
N.VectorisedWriter
Close() error
}
@@ -30,23 +32,41 @@ type WinTun interface {
ReadPacket() ([]byte, func(), error)
}
type LinuxTUN interface {
Tun
N.FrontHeadroom
BatchSize() int
BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error)
BatchWrite(buffers [][]byte, offset int) error
TXChecksumOffload() bool
}
type Options struct {
Name string
Inet4Address []netip.Prefix
Inet6Address []netip.Prefix
MTU uint32
AutoRoute bool
StrictRoute bool
Inet4RouteAddress []netip.Prefix
Inet6RouteAddress []netip.Prefix
IncludeUID []ranges.Range[uint32]
ExcludeUID []ranges.Range[uint32]
IncludeAndroidUser []int
IncludePackage []string
ExcludePackage []string
InterfaceMonitor DefaultInterfaceMonitor
TableIndex int
FileDescriptor int
Name string
Inet4Address []netip.Prefix
Inet6Address []netip.Prefix
MTU uint32
GSO bool
AutoRoute bool
StrictRoute bool
Inet4RouteAddress []netip.Prefix
Inet6RouteAddress []netip.Prefix
Inet4RouteExcludeAddress []netip.Prefix
Inet6RouteExcludeAddress []netip.Prefix
IncludeInterface []string
ExcludeInterface []string
IncludeUID []ranges.Range[uint32]
ExcludeUID []ranges.Range[uint32]
IncludeAndroidUser []int
IncludePackage []string
ExcludePackage []string
InterfaceMonitor DefaultInterfaceMonitor
TableIndex int
FileDescriptor int
Logger logger.Logger
// No work for TCP, do not use.
_TXChecksumOffload bool
}
func CalculateInterfaceName(name string) (tunName string) {
@@ -65,7 +85,7 @@ func CalculateInterfaceName(name string) (tunName string) {
for _, netInterface := range interfaces {
if strings.HasPrefix(netInterface.Name, tunName) {
index, parseErr := strconv.ParseInt(netInterface.Name[len(tunName):], 10, 16)
if parseErr == nil {
if parseErr == nil && int(index) >= tunIndex {
tunIndex = int(index) + 1
}
}

View File

@@ -5,11 +5,11 @@ import (
"net"
"net/netip"
"os"
"runtime"
"syscall"
"unsafe"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
@@ -25,8 +25,8 @@ type NativeTun struct {
tunFile *os.File
tunWriter N.VectorisedWriter
mtu uint32
inet4Address string
inet6Address string
inet4Address [4]byte
inet6Address [16]byte
}
func New(options Options) (Tun, error) {
@@ -57,48 +57,40 @@ func New(options Options) (Tun, error) {
mtu: options.MTU,
}
if len(options.Inet4Address) > 0 {
nativeTun.inet4Address = string(options.Inet4Address[0].Addr().AsSlice())
nativeTun.inet4Address = options.Inet4Address[0].Addr().As4()
}
if len(options.Inet6Address) > 0 {
nativeTun.inet6Address = string(options.Inet6Address[0].Addr().AsSlice())
nativeTun.inet6Address = options.Inet6Address[0].Addr().As16()
}
var ok bool
nativeTun.tunWriter, ok = bufio.CreateVectorisedWriter(nativeTun.tunFile)
if !ok {
panic("create vectorised writer")
}
runtime.SetFinalizer(nativeTun.tunFile, nil)
return nativeTun, nil
}
func (t *NativeTun) Read(p []byte) (n int, err error) {
/*n, err = t.tunFile.Read(p)
if n < 4 {
return 0, err
}
copy(p[:], p[4:])
return n - 4, err*/
return t.tunFile.Read(p)
}
func (t *NativeTun) Write(p []byte) (n int, err error) {
return t.tunFile.Write(p)
}
var (
packetHeader4 = [4]byte{0x00, 0x00, 0x00, unix.AF_INET}
packetHeader6 = [4]byte{0x00, 0x00, 0x00, unix.AF_INET6}
)
func (t *NativeTun) Write(p []byte) (n int, err error) {
func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
var packetHeader []byte
if p[0]>>4 == 4 {
if buffers[0].Byte(0)>>4 == 4 {
packetHeader = packetHeader4[:]
} else {
packetHeader = packetHeader6[:]
}
_, err = bufio.WriteVectorised(t.tunWriter, [][]byte{packetHeader, p})
if err == nil {
n = len(p)
}
return
return t.tunWriter.WriteVectorised(append([]*buf.Buffer{buf.As(packetHeader)}, buffers...))
}
func (t *NativeTun) Close() error {
@@ -248,43 +240,16 @@ func configure(tunFd int, ifIndex int, name string, options Options) error {
}
}
if options.AutoRoute {
if len(options.Inet4Address) > 0 {
var routes []netip.Prefix
if len(options.Inet4RouteAddress) > 0 {
routes = append(options.Inet4RouteAddress, netip.PrefixFrom(options.Inet4Address[0].Addr().Next(), 32))
var routeRanges []netip.Prefix
routeRanges, err = options.BuildAutoRouteRanges(false)
for _, routeRange := range routeRanges {
if routeRange.Addr().Is4() {
err = addRoute(routeRange, options.Inet4Address[0].Addr())
} else {
routes = []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 0, 0, 0}), 8),
netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 0, 0, 0}), 7),
netip.PrefixFrom(netip.AddrFrom4([4]byte{4, 0, 0, 0}), 6),
netip.PrefixFrom(netip.AddrFrom4([4]byte{8, 0, 0, 0}), 5),
netip.PrefixFrom(netip.AddrFrom4([4]byte{16, 0, 0, 0}), 4),
netip.PrefixFrom(netip.AddrFrom4([4]byte{32, 0, 0, 0}), 3),
netip.PrefixFrom(netip.AddrFrom4([4]byte{64, 0, 0, 0}), 2),
netip.PrefixFrom(netip.AddrFrom4([4]byte{128, 0, 0, 0}), 1),
}
err = addRoute(routeRange, options.Inet6Address[0].Addr())
}
for _, subnet := range routes {
err = addRoute(subnet, options.Inet4Address[0].Addr())
if err != nil {
return E.Cause(err, "add ipv4 route ", subnet)
}
}
}
if len(options.Inet6Address) > 0 {
var routes []netip.Prefix
if len(options.Inet6RouteAddress) > 0 {
routes = append(options.Inet6RouteAddress, netip.PrefixFrom(options.Inet6Address[0].Addr().Next(), 128))
} else {
routes = []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom16([16]byte{32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}), 3),
}
}
for _, subnet := range routes {
err = addRoute(subnet, options.Inet6Address[0].Addr())
if err != nil {
return E.Cause(err, "add ipv6 route ", subnet)
}
if err != nil {
return E.Cause(err, "add route: ", routeRange)
}
}
flushDNSCache()

View File

@@ -3,14 +3,11 @@
package tun
import (
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/sing/common/bufio"
"gvisor.dev/gvisor/pkg/bufferv2"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
var _ GVisorTun = (*NativeTun)(nil)
@@ -39,7 +36,7 @@ func (e *DarwinEndpoint) LinkAddress() tcpip.LinkAddress {
}
func (e *DarwinEndpoint) Capabilities() stack.LinkEndpointCapabilities {
return stack.CapabilityNone
return stack.CapabilityRXChecksumOffload
}
func (e *DarwinEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
@@ -54,37 +51,33 @@ func (e *DarwinEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
}
func (e *DarwinEndpoint) dispatchLoop() {
_buffer := buf.StackNewSize(int(e.tun.mtu) + 4)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
data := buffer.FreeBytes()
packetBuffer := make([]byte, e.tun.mtu+PacketOffset)
for {
n, err := e.tun.tunFile.Read(data)
n, err := e.tun.tunFile.Read(packetBuffer)
if err != nil {
break
}
packet := data[4:n]
packet := packetBuffer[PacketOffset:n]
var networkProtocol tcpip.NetworkProtocolNumber
switch header.IPVersion(packet) {
case header.IPv4Version:
networkProtocol = header.IPv4ProtocolNumber
if header.IPv4(packet).DestinationAddress() == tcpip.Address(e.tun.inet4Address) {
e.tun.tunFile.Write(data[:n])
if header.IPv4(packet).DestinationAddress().As4() == e.tun.inet4Address {
e.tun.tunFile.Write(packetBuffer[:n])
continue
}
case header.IPv6Version:
networkProtocol = header.IPv6ProtocolNumber
if header.IPv6(packet).DestinationAddress() == tcpip.Address(e.tun.inet6Address) {
e.tun.tunFile.Write(data[:n])
if header.IPv6(packet).DestinationAddress().As16() == e.tun.inet6Address {
e.tun.tunFile.Write(packetBuffer[:n])
continue
}
default:
e.tun.tunFile.Write(data[:n])
e.tun.tunFile.Write(packetBuffer[:n])
continue
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: bufferv2.MakeWithData(data[4:n]),
Payload: buffer.MakeWithData(packetBuffer[4:n]),
IsForwardedPacket: true,
})
pkt.NetworkProtocolNumber = networkProtocol
@@ -112,17 +105,14 @@ func (e *DarwinEndpoint) ARPHardwareType() header.ARPHardwareType {
func (e *DarwinEndpoint) AddHeader(buffer *stack.PacketBuffer) {
}
func (e *DarwinEndpoint) ParseHeader(ptr *stack.PacketBuffer) bool {
return true
}
func (e *DarwinEndpoint) WritePackets(packetBufferList stack.PacketBufferList) (int, tcpip.Error) {
var n int
for _, packet := range packetBufferList.AsSlice() {
var packetHeader []byte
switch packet.NetworkProtocolNumber {
case header.IPv4ProtocolNumber:
packetHeader = packetHeader4[:]
case header.IPv6ProtocolNumber:
packetHeader = packetHeader6[:]
}
_, err := bufio.WriteVectorised(e.tun.tunWriter, append([][]byte{packetHeader}, packet.AsSlices()...))
_, err := bufio.WriteVectorised(e.tun, packet.AsSlices())
if err != nil {
return n, &tcpip.ErrAborted{}
}

View File

@@ -7,12 +7,16 @@ import (
"os"
"os/exec"
"runtime"
"sync"
"syscall"
"unsafe"
"github.com/sagernet/netlink"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/shell"
"github.com/sagernet/sing/common/x/list"
@@ -20,17 +24,29 @@ import (
"golang.org/x/sys/unix"
)
var _ LinuxTUN = (*NativeTun)(nil)
type NativeTun struct {
tunFd int
tunFile *os.File
tunWriter N.VectorisedWriter
interfaceCallback *list.Element[DefaultInterfaceUpdateCallback]
options Options
ruleIndex6 []int
gsoEnabled bool
gsoBuffer []byte
gsoToWrite []int
gsoReadAccess sync.Mutex
tcpGROAccess sync.Mutex
tcp4GROTable *tcpGROTable
tcp6GROTable *tcpGROTable
txChecksumOffload bool
}
func New(options Options) (Tun, error) {
var nativeTun *NativeTun
if options.FileDescriptor == 0 {
tunFd, err := open(options.Name)
tunFd, err := open(options.Name, options.GSO)
if err != nil {
return nil, err
}
@@ -38,36 +54,129 @@ func New(options Options) (Tun, error) {
if err != nil {
return nil, E.Errors(err, unix.Close(tunFd))
}
nativeTun := &NativeTun{
nativeTun = &NativeTun{
tunFd: tunFd,
tunFile: os.NewFile(uintptr(tunFd), "tun"),
options: options,
}
runtime.SetFinalizer(nativeTun.tunFile, nil)
err = nativeTun.configure(tunLink)
if err != nil {
return nil, E.Errors(err, unix.Close(tunFd))
}
return nativeTun, nil
} else {
nativeTun := &NativeTun{
nativeTun = &NativeTun{
tunFd: options.FileDescriptor,
tunFile: os.NewFile(uintptr(options.FileDescriptor), "tun"),
options: options,
}
runtime.SetFinalizer(nativeTun.tunFile, nil)
return nativeTun, nil
}
var ok bool
nativeTun.tunWriter, ok = bufio.CreateVectorisedWriter(nativeTun.tunFile)
if !ok {
panic("create vectorised writer")
}
return nativeTun, nil
}
func (t *NativeTun) FrontHeadroom() int {
if t.gsoEnabled {
return virtioNetHdrLen
}
return 0
}
func (t *NativeTun) Read(p []byte) (n int, err error) {
return t.tunFile.Read(p)
if t.gsoEnabled {
n, err = t.tunFile.Read(t.gsoBuffer)
if err != nil {
return
}
var sizes [1]int
n, err = handleVirtioRead(t.gsoBuffer[:n], [][]byte{p}, sizes[:], 0)
if err != nil {
return
}
if n == 0 {
return
}
n = sizes[0]
return
} else {
return t.tunFile.Read(p)
}
}
func (t *NativeTun) Write(p []byte) (n int, err error) {
if t.gsoEnabled {
err = t.BatchWrite([][]byte{p}, virtioNetHdrLen)
if err != nil {
return
}
n = len(p)
return
}
return t.tunFile.Write(p)
}
func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
if t.gsoEnabled {
n := buf.LenMulti(buffers)
buffer := buf.NewSize(virtioNetHdrLen + n)
buffer.Truncate(virtioNetHdrLen)
buf.CopyMulti(buffer.Extend(n), buffers)
_, err := t.tunFile.Write(buffer.Bytes())
buffer.Release()
return err
} else {
return t.tunWriter.WriteVectorised(buffers)
}
}
func (t *NativeTun) BatchSize() int {
if !t.gsoEnabled {
return 1
}
/* // Not works on some devices: https://github.com/SagerNet/sing-box/issues/1605
batchSize := int(gsoMaxSize/t.options.MTU) * 2
if batchSize > idealBatchSize {
batchSize = idealBatchSize
}
return batchSize*/
return idealBatchSize
}
func (t *NativeTun) BatchRead(buffers [][]byte, offset int, readN []int) (n int, err error) {
t.gsoReadAccess.Lock()
defer t.gsoReadAccess.Unlock()
n, err = t.tunFile.Read(t.gsoBuffer)
if err != nil {
return
}
return handleVirtioRead(t.gsoBuffer[:n], buffers, readN, offset)
}
func (t *NativeTun) BatchWrite(buffers [][]byte, offset int) error {
t.tcpGROAccess.Lock()
defer func() {
t.tcp4GROTable.reset()
t.tcp6GROTable.reset()
t.tcpGROAccess.Unlock()
}()
t.gsoToWrite = t.gsoToWrite[:0]
err := handleGRO(buffers, offset, t.tcp4GROTable, t.tcp6GROTable, &t.gsoToWrite)
if err != nil {
return err
}
offset -= virtioNetHdrLen
for _, bufferIndex := range t.gsoToWrite {
_, err = t.tunFile.Write(buffers[bufferIndex][offset:])
if err != nil {
return err
}
}
return nil
}
var controlPath string
func init() {
@@ -80,7 +189,7 @@ func init() {
}
}
func open(name string) (int, error) {
func open(name string, vnetHdr bool) (int, error) {
fd, err := unix.Open(controlPath, unix.O_RDWR, 0)
if err != nil {
return -1, err
@@ -94,6 +203,9 @@ func open(name string) (int, error) {
copy(ifr.name[:], name)
ifr.flags = unix.IFF_TUN | unix.IFF_NO_PI
if vnetHdr {
ifr.flags |= unix.IFF_VNET_HDR
}
_, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), unix.TUNSETIFF, uintptr(unsafe.Pointer(&ifr)))
if errno != 0 {
unix.Close(fd)
@@ -136,6 +248,46 @@ func (t *NativeTun) configure(tunLink netlink.Link) error {
}
}
if t.options.GSO {
var vnetHdrEnabled bool
vnetHdrEnabled, err = checkVNETHDREnabled(t.tunFd, t.options.Name)
if err != nil {
return E.Cause(err, "enable offload: check IFF_VNET_HDR enabled")
}
if !vnetHdrEnabled {
return E.Cause(err, "enable offload: IFF_VNET_HDR not enabled")
}
err = setTCPOffload(t.tunFd)
if err != nil {
return err
}
t.gsoEnabled = true
t.gsoBuffer = make([]byte, virtioNetHdrLen+int(gsoMaxSize))
t.tcp4GROTable = newTCPGROTable()
t.tcp6GROTable = newTCPGROTable()
}
var rxChecksumOffload bool
rxChecksumOffload, err = checkChecksumOffload(t.options.Name, unix.ETHTOOL_GRXCSUM)
if err == nil && !rxChecksumOffload {
_ = setChecksumOffload(t.options.Name, unix.ETHTOOL_SRXCSUM)
}
if t.options._TXChecksumOffload {
var txChecksumOffload bool
txChecksumOffload, err = checkChecksumOffload(t.options.Name, unix.ETHTOOL_GTXCSUM)
if err != nil {
return err
}
if err == nil && !txChecksumOffload {
err = setChecksumOffload(t.options.Name, unix.ETHTOOL_STXCSUM)
if err != nil {
return err
}
}
t.txChecksumOffload = true
}
err = netlink.LinkSetUp(tunLink)
if err != nil {
return err
@@ -167,7 +319,7 @@ func (t *NativeTun) configure(tunLink netlink.Link) error {
return err
}
setSearchDomainForSystemdResolved(t.options.Name)
t.setSearchDomainForSystemdResolved()
if t.options.AutoRoute && runtime.GOOS == "android" {
t.interfaceCallback = t.options.InterfaceMonitor.RegisterCallback(t.routeUpdate)
@@ -182,57 +334,29 @@ func (t *NativeTun) Close() error {
return E.Errors(t.unsetRoute(), t.unsetRules(), common.Close(common.PtrOrNil(t.tunFile)))
}
func (t *NativeTun) routes(tunLink netlink.Link) []netlink.Route {
var routes []netlink.Route
if len(t.options.Inet4Address) > 0 {
if t.options.AutoRoute {
if len(t.options.Inet4RouteAddress) > 0 {
for _, addr := range t.options.Inet4RouteAddress {
routes = append(routes, netlink.Route{
Dst: &net.IPNet{
IP: addr.Addr().AsSlice(),
Mask: net.CIDRMask(addr.Bits(), 32),
},
LinkIndex: tunLink.Attrs().Index,
Table: t.options.TableIndex,
})
}
} else {
routes = append(routes, netlink.Route{
Dst: &net.IPNet{
IP: net.IPv4zero,
Mask: net.CIDRMask(0, 32),
},
LinkIndex: tunLink.Attrs().Index,
Table: t.options.TableIndex,
})
}
}
func (t *NativeTun) TXChecksumOffload() bool {
return t.txChecksumOffload
}
func prefixToIPNet(prefix netip.Prefix) *net.IPNet {
return &net.IPNet{
IP: prefix.Addr().AsSlice(),
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()),
}
if len(t.options.Inet6Address) > 0 {
if len(t.options.Inet6RouteAddress) > 0 {
for _, addr := range t.options.Inet6RouteAddress {
routes = append(routes, netlink.Route{
Dst: &net.IPNet{
IP: addr.Addr().AsSlice(),
Mask: net.CIDRMask(addr.Bits(), 128),
},
LinkIndex: tunLink.Attrs().Index,
Table: t.options.TableIndex,
})
}
} else {
routes = append(routes, netlink.Route{
Dst: &net.IPNet{
IP: net.IPv6zero,
Mask: net.CIDRMask(0, 128),
},
LinkIndex: tunLink.Attrs().Index,
Table: t.options.TableIndex,
})
}
}
func (t *NativeTun) routes(tunLink netlink.Link) ([]netlink.Route, error) {
routeRanges, err := t.options.BuildAutoRouteRanges(false)
if err != nil {
return nil, err
}
return routes
return common.Map(routeRanges, func(it netip.Prefix) netlink.Route {
return netlink.Route{
Dst: prefixToIPNet(it),
LinkIndex: tunLink.Attrs().Index,
Table: t.options.TableIndex,
}
}), nil
}
const (
@@ -317,6 +441,78 @@ func (t *NativeTun) rules() []*netlink.Rule {
priority6++
}
}
if len(t.options.IncludeInterface) > 0 {
matchPriority := priority + 2*len(t.options.IncludeInterface) + 1
for _, includeInterface := range t.options.IncludeInterface {
if p4 {
it = netlink.NewRule()
it.Priority = priority
it.IifName = includeInterface
it.Goto = matchPriority
it.Family = unix.AF_INET
rules = append(rules, it)
priority++
}
if p6 {
it = netlink.NewRule()
it.Priority = priority6
it.IifName = includeInterface
it.Goto = matchPriority
it.Family = unix.AF_INET6
rules = append(rules, it)
priority6++
}
}
if p4 {
it = netlink.NewRule()
it.Priority = priority
it.Family = unix.AF_INET
it.Goto = nopPriority
rules = append(rules, it)
priority++
it = netlink.NewRule()
it.Priority = matchPriority
it.Family = unix.AF_INET
rules = append(rules, it)
priority++
}
if p6 {
it = netlink.NewRule()
it.Priority = priority6
it.Family = unix.AF_INET6
it.Goto = nopPriority
rules = append(rules, it)
priority6++
it = netlink.NewRule()
it.Priority = matchPriority
it.Family = unix.AF_INET6
rules = append(rules, it)
priority6++
}
} else if len(t.options.ExcludeInterface) > 0 {
for _, excludeInterface := range t.options.ExcludeInterface {
if p4 {
it = netlink.NewRule()
it.Priority = priority
it.IifName = excludeInterface
it.Goto = nopPriority
it.Family = unix.AF_INET
rules = append(rules, it)
priority++
}
if p6 {
it = netlink.NewRule()
it.Priority = priority6
it.IifName = excludeInterface
it.Goto = nopPriority
it.Family = unix.AF_INET6
rules = append(rules, it)
priority6++
}
}
}
if runtime.GOOS == "android" && t.options.InterfaceMonitor.AndroidVPNEnabled() {
const protectedFromVPN = 0x20000
@@ -376,16 +572,34 @@ func (t *NativeTun) rules() []*netlink.Rule {
rules = append(rules, it)
}
priority++
}
/*if p6 {
it = netlink.NewRule()
it.Priority = priority
it.Dst = t.options.Inet6Address.Masked()
it.Table = tunTableIndex
it.Table = t.options.TableIndex
it.SuppressPrefixlen = 0
it.Family = unix.AF_INET
rules = append(rules, it)
priority++
}
if p6 {
it = netlink.NewRule()
it.Priority = priority6
it.Table = t.options.TableIndex
it.SuppressPrefixlen = 0
it.Family = unix.AF_INET6
rules = append(rules, it)
}*/
if p4 {
priority6++
}
if p4 && !t.options.StrictRoute {
it = netlink.NewRule()
it.Priority = priority
it.Invert = true
it.Dport = netlink.NewRulePortRange(53, 53)
it.Table = unix.RT_TABLE_MAIN
it.SuppressPrefixlen = 0
it.Family = unix.AF_INET
rules = append(rules, it)
it = netlink.NewRule()
it.Priority = priority
it.IPProto = syscall.IPPROTO_ICMP
@@ -394,7 +608,16 @@ func (t *NativeTun) rules() []*netlink.Rule {
rules = append(rules, it)
priority++
}
if p6 {
if p6 && !t.options.StrictRoute {
it = netlink.NewRule()
it.Priority = priority6
it.Invert = true
it.Dport = netlink.NewRulePortRange(53, 53)
it.Table = unix.RT_TABLE_MAIN
it.SuppressPrefixlen = 0
it.Family = unix.AF_INET6
rules = append(rules, it)
it = netlink.NewRule()
it.Priority = priority6
it.IPProto = syscall.IPPROTO_ICMPV6
@@ -403,26 +626,6 @@ func (t *NativeTun) rules() []*netlink.Rule {
rules = append(rules, it)
priority6++
}
if p4 {
it = netlink.NewRule()
it.Priority = priority
it.Invert = true
it.Dport = netlink.NewRulePortRange(53, 53)
it.Table = unix.RT_TABLE_MAIN
it.SuppressPrefixlen = 0
it.Family = unix.AF_INET
rules = append(rules, it)
}
if p6 {
it = netlink.NewRule()
it.Priority = priority6
it.Invert = true
it.Dport = netlink.NewRulePortRange(53, 53)
it.Table = unix.RT_TABLE_MAIN
it.SuppressPrefixlen = 0
it.Family = unix.AF_INET6
rules = append(rules, it)
}
}
if p4 {
@@ -462,36 +665,42 @@ func (t *NativeTun) rules() []*netlink.Rule {
priority++
}
if p6 {
// FIXME: this match connections from public address
if !t.options.StrictRoute {
for _, address := range t.options.Inet6Address {
it = netlink.NewRule()
it.Priority = priority6
it.IifName = "lo"
it.Src = address.Masked()
it.Table = t.options.TableIndex
it.Family = unix.AF_INET6
rules = append(rules, it)
}
priority6++
it = netlink.NewRule()
it.Priority = priority6
it.IifName = "lo"
it.Src = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
it.Goto = nopPriority
it.Family = unix.AF_INET6
rules = append(rules, it)
it = netlink.NewRule()
it.Priority = priority6
it.IifName = "lo"
it.Src = netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 128}), 1)
it.Goto = nopPriority
it.Family = unix.AF_INET6
rules = append(rules, it)
priority6++
}
it = netlink.NewRule()
it.Priority = priority6
it.Table = t.options.TableIndex
it.Family = unix.AF_INET6
rules = append(rules, it)
/*it = netlink.NewRule()
it.Priority = priority
it.Invert = true
it.IifName = "lo"
it.Table = tunTableIndex
it.Family = unix.AF_INET6
rules = append(rules, it)
it = netlink.NewRule()
it.Priority = priority
it.IifName = "lo"
it.Src = netip.PrefixFrom(netip.IPv6Unspecified(), 128) // not working
it.Table = tunTableIndex
it.Family = unix.AF_INET6
rules = append(rules, it)
it = netlink.NewRule()
it.Priority = priority
it.IifName = "lo"
it.Src = t.options.Inet6Address.Masked()
it.Table = tunTableIndex
it.Family = unix.AF_INET6
rules = append(rules, it)*/
priority6++
}
if p4 {
@@ -510,7 +719,11 @@ func (t *NativeTun) rules() []*netlink.Rule {
}
func (t *NativeTun) setRoute(tunLink netlink.Link) error {
for i, route := range t.routes(tunLink) {
routes, err := t.routes(tunLink)
if err != nil {
return err
}
for i, route := range routes {
err := netlink.RouteAdd(&route)
if err != nil {
return E.Cause(err, "add route ", i)
@@ -541,8 +754,10 @@ func (t *NativeTun) unsetRoute() error {
}
func (t *NativeTun) unsetRoute0(tunLink netlink.Link) error {
for _, route := range t.routes(tunLink) {
_ = netlink.RouteDel(&route)
if routes, err := t.routes(tunLink); err == nil {
for _, route := range routes {
_ = netlink.RouteDel(&route)
}
}
return nil
}
@@ -588,21 +803,33 @@ func (t *NativeTun) resetRules() error {
return t.setRules()
}
func (t *NativeTun) routeUpdate(event int) error {
func (t *NativeTun) routeUpdate(event int) {
if event&EventAndroidVPNUpdate == 0 {
return nil
return
}
err := t.resetRules()
if err != nil {
return E.Cause(err, "reset route")
if t.options.Logger != nil {
t.options.Logger.Error(E.Cause(err, "reset route"))
}
}
return nil
}
func setSearchDomainForSystemdResolved(interfaceName string) {
func (t *NativeTun) setSearchDomainForSystemdResolved() {
ctlPath, err := exec.LookPath("resolvectl")
if err != nil {
return
}
shell.Exec(ctlPath, "domain", interfaceName, "~.").Run()
var dnsServer []netip.Addr
if len(t.options.Inet4Address) > 0 {
dnsServer = append(dnsServer, t.options.Inet4Address[0].Addr().Next())
}
if len(t.options.Inet6Address) > 0 {
dnsServer = append(dnsServer, t.options.Inet6Address[0].Addr().Next())
}
go shell.Exec(ctlPath, "domain", t.options.Name, "~.").Run()
if t.options.AutoRoute {
go shell.Exec(ctlPath, "default-route", t.options.Name, "true").Run()
go shell.Exec(ctlPath, append([]string{"dns", t.options.Name}, common.Map(dnsServer, netip.Addr.String)...)...).Run()
}
}

84
tun_linux_flags.go Normal file
View File

@@ -0,0 +1,84 @@
//go:build linux
package tun
import (
"os"
"syscall"
"unsafe"
E "github.com/sagernet/sing/common/exceptions"
"golang.org/x/sys/unix"
)
func checkVNETHDREnabled(fd int, name string) (bool, error) {
ifr, err := unix.NewIfreq(name)
if err != nil {
return false, err
}
err = unix.IoctlIfreq(fd, unix.TUNGETIFF, ifr)
if err != nil {
return false, os.NewSyscallError("TUNGETIFF", err)
}
return ifr.Uint16()&unix.IFF_VNET_HDR != 0, nil
}
func setTCPOffload(fd int) error {
const (
// TODO: support TSO with ECN bits
tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
)
err := unix.IoctlSetInt(fd, unix.TUNSETOFFLOAD, tunOffloads)
if err != nil {
return E.Cause(os.NewSyscallError("TUNSETOFFLOAD", err), "enable offload")
}
return nil
}
type ifreqData struct {
ifrName [unix.IFNAMSIZ]byte
ifrData uintptr
}
type ethtoolValue struct {
cmd uint32
data uint32
}
//go:linkname ioctlPtr golang.org/x/sys/unix.ioctlPtr
func ioctlPtr(fd int, req uint, arg unsafe.Pointer) (err error)
func checkChecksumOffload(name string, cmd uint32) (bool, error) {
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return false, err
}
defer syscall.Close(fd)
ifr := ifreqData{}
copy(ifr.ifrName[:], name)
data := ethtoolValue{cmd: cmd}
ifr.ifrData = uintptr(unsafe.Pointer(&data))
err = ioctlPtr(fd, unix.SIOCETHTOOL, unsafe.Pointer(&ifr))
if err != nil {
return false, os.NewSyscallError("SIOCETHTOOL ETHTOOL_GTXCSUM", err)
}
return data.data == 0, nil
}
func setChecksumOffload(name string, cmd uint32) error {
fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_IP)
if err != nil {
return err
}
defer syscall.Close(fd)
ifr := ifreqData{}
copy(ifr.ifrName[:], name)
data := ethtoolValue{cmd: cmd, data: 0}
ifr.ifrData = uintptr(unsafe.Pointer(&data))
err = ioctlPtr(fd, unix.SIOCETHTOOL, unsafe.Pointer(&ifr))
if err != nil {
return os.NewSyscallError("SIOCETHTOOL ETHTOOL_STXCSUM", err)
}
return nil
}

View File

@@ -3,15 +3,26 @@
package tun
import (
"gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/link/fdbased"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
)
var _ GVisorTun = (*NativeTun)(nil)
func (t *NativeTun) NewEndpoint() (stack.LinkEndpoint, error) {
if t.gsoEnabled {
return fdbased.New(&fdbased.Options{
FDs: []int{t.tunFd},
MTU: t.options.MTU,
GSOMaxSize: gsoMaxSize,
RXChecksumOffload: true,
TXChecksumOffload: t.txChecksumOffload,
})
}
return fdbased.New(&fdbased.Options{
FDs: []int{t.tunFd},
MTU: t.options.MTU,
FDs: []int{t.tunFd},
MTU: t.options.MTU,
RXChecksumOffload: true,
TXChecksumOffload: t.txChecksumOffload,
})
}

768
tun_linux_offload.go Normal file
View File

@@ -0,0 +1,768 @@
//go:build linux
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"unsafe"
"github.com/sagernet/sing-tun/internal/clashtcpip"
E "github.com/sagernet/sing/common/exceptions"
"golang.org/x/sys/unix"
)
const (
gsoMaxSize = 65536
tcpFlagsOffset = 13
idealBatchSize = 128
)
const (
tcpFlagFIN uint8 = 0x01
tcpFlagPSH uint8 = 0x08
tcpFlagACK uint8 = 0x10
)
// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The
// kernel symbol is virtio_net_hdr.
type virtioNetHdr struct {
flags uint8
gsoType uint8
hdrLen uint16
gsoSize uint16
csumStart uint16
csumOffset uint16
}
func (v *virtioNetHdr) decode(b []byte) error {
if len(b) < virtioNetHdrLen {
return io.ErrShortBuffer
}
copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen])
return nil
}
func (v *virtioNetHdr) encode(b []byte) error {
if len(b) < virtioNetHdrLen {
return io.ErrShortBuffer
}
copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen))
return nil
}
const (
// virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the
// shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr).
virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{}))
)
// flowKey represents the key for a flow.
type flowKey struct {
srcAddr, dstAddr [16]byte
srcPort, dstPort uint16
rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows.
}
// tcpGROTable holds flow and coalescing information for the purposes of GRO.
type tcpGROTable struct {
itemsByFlow map[flowKey][]tcpGROItem
itemsPool [][]tcpGROItem
}
func newTCPGROTable() *tcpGROTable {
t := &tcpGROTable{
itemsByFlow: make(map[flowKey][]tcpGROItem, idealBatchSize),
itemsPool: make([][]tcpGROItem, idealBatchSize),
}
for i := range t.itemsPool {
t.itemsPool[i] = make([]tcpGROItem, 0, idealBatchSize)
}
return t
}
func newFlowKey(pkt []byte, srcAddr, dstAddr, tcphOffset int) flowKey {
key := flowKey{}
addrSize := dstAddr - srcAddr
copy(key.srcAddr[:], pkt[srcAddr:dstAddr])
copy(key.dstAddr[:], pkt[dstAddr:dstAddr+addrSize])
key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:])
key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:])
key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:])
return key
}
// lookupOrInsert looks up a flow for the provided packet and metadata,
// returning the packets found for the flow, or inserting a new one if none
// is found.
func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) {
key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
items, ok := t.itemsByFlow[key]
if ok {
return items, ok
}
// TODO: insert() performs another map lookup. This could be rearranged to avoid.
t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex)
return nil, false
}
// insert an item in the table for the provided packet and packet metadata.
func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) {
key := newFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset)
item := tcpGROItem{
key: key,
bufsIndex: uint16(bufsIndex),
gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])),
iphLen: uint8(tcphOffset),
tcphLen: uint8(tcphLen),
sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]),
pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0,
}
items, ok := t.itemsByFlow[key]
if !ok {
items = t.newItems()
}
items = append(items, item)
t.itemsByFlow[key] = items
}
func (t *tcpGROTable) updateAt(item tcpGROItem, i int) {
items, _ := t.itemsByFlow[item.key]
items[i] = item
}
func (t *tcpGROTable) deleteAt(key flowKey, i int) {
items, _ := t.itemsByFlow[key]
items = append(items[:i], items[i+1:]...)
t.itemsByFlow[key] = items
}
// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime
// of a GRO evaluation across a vector of packets.
type tcpGROItem struct {
key flowKey
sentSeq uint32 // the sequence number
bufsIndex uint16 // the index into the original bufs slice
numMerged uint16 // the number of packets merged into this item
gsoSize uint16 // payload size
iphLen uint8 // ip header len
tcphLen uint8 // tcp header len
pshSet bool // psh flag is set
}
func (t *tcpGROTable) newItems() []tcpGROItem {
var items []tcpGROItem
items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1]
return items
}
func (t *tcpGROTable) reset() {
for k, items := range t.itemsByFlow {
items = items[:0]
t.itemsPool = append(t.itemsPool, items)
delete(t.itemsByFlow, k)
}
}
// canCoalesce represents the outcome of checking if two TCP packets are
// candidates for coalescing.
type canCoalesce int
const (
coalescePrepend canCoalesce = -1
coalesceUnavailable canCoalesce = 0
coalesceAppend canCoalesce = 1
)
// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet
// described by item. This function makes considerations that match the kernel's
// GRO self tests, which can be found in tools/testing/selftests/net/gro.c.
func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce {
pktTarget := bufs[item.bufsIndex][bufsOffset:]
if tcphLen != item.tcphLen {
// cannot coalesce with unequal tcp options len
return coalesceUnavailable
}
if tcphLen > 20 {
if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) {
// cannot coalesce with unequal tcp options
return coalesceUnavailable
}
}
if pkt[0]>>4 == 6 {
if pkt[0] != pktTarget[0] || pkt[1]>>4 != pktTarget[1]>>4 {
// cannot coalesce with unequal Traffic class values
return coalesceUnavailable
}
if pkt[7] != pktTarget[7] {
// cannot coalesce with unequal Hop limit values
return coalesceUnavailable
}
} else {
if pkt[1] != pktTarget[1] {
// cannot coalesce with unequal ToS values
return coalesceUnavailable
}
if pkt[6]>>5 != pktTarget[6]>>5 {
// cannot coalesce with unequal DF or reserved bits. MF is checked
// further up the stack.
return coalesceUnavailable
}
if pkt[8] != pktTarget[8] {
// cannot coalesce with unequal TTL values
return coalesceUnavailable
}
}
// seq adjacency
lhsLen := item.gsoSize
lhsLen += item.numMerged * item.gsoSize
if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective
if item.pshSet {
// We cannot append to a segment that has the PSH flag set, PSH
// can only be set on the final segment in a reassembled group.
return coalesceUnavailable
}
if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 {
// A smaller than gsoSize packet has been appended previously.
// Nothing can come after a smaller packet on the end.
return coalesceUnavailable
}
if gsoSize > item.gsoSize {
// We cannot have a larger packet following a smaller one.
return coalesceUnavailable
}
return coalesceAppend
} else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective
if pshSet {
// We cannot prepend with a segment that has the PSH flag set, PSH
// can only be set on the final segment in a reassembled group.
return coalesceUnavailable
}
if gsoSize < item.gsoSize {
// We cannot have a larger packet following a smaller one.
return coalesceUnavailable
}
if gsoSize > item.gsoSize && item.numMerged > 0 {
// There's at least one previous merge, and we're larger than all
// previous. This would put multiple smaller packets on the end.
return coalesceUnavailable
}
return coalescePrepend
}
return coalesceUnavailable
}
func tcpChecksumValid(pkt []byte, iphLen uint8, isV6 bool) bool {
srcAddrAt := ipv4SrcAddrOffset
addrSize := 4
if isV6 {
srcAddrAt = ipv6SrcAddrOffset
addrSize = 16
}
tcpTotalLen := uint16(len(pkt) - int(iphLen))
tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], tcpTotalLen)
return ^checksumFold(pkt[iphLen:], tcpCSumNoFold) == 0
}
// coalesceResult represents the result of attempting to coalesce two TCP
// packets.
type coalesceResult int
const (
coalesceInsufficientCap coalesceResult = iota
coalescePSHEnding
coalesceItemInvalidCSum
coalescePktInvalidCSum
coalesceSuccess
)
// coalesceTCPPackets attempts to coalesce pkt with the packet described by
// item, returning the outcome. This function may swap bufs elements in the
// event of a prepend as item's bufs index is already being tracked for writing
// to a Device.
func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult {
var pktHead []byte // the packet that will end up at the front
headersLen := item.iphLen + item.tcphLen
coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen)
// Copy data
if mode == coalescePrepend {
pktHead = pkt
if cap(pkt)-bufsOffset < coalescedLen {
// We don't want to allocate a new underlying array if capacity is
// too small.
return coalesceInsufficientCap
}
if pshSet {
return coalescePSHEnding
}
if item.numMerged == 0 {
if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
return coalesceItemInvalidCSum
}
}
if !tcpChecksumValid(pkt, item.iphLen, isV6) {
return coalescePktInvalidCSum
}
item.sentSeq = seq
extendBy := coalescedLen - len(pktHead)
bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...)
copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):])
// Flip the slice headers in bufs as part of prepend. The index of item
// is already being tracked for writing.
bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex]
} else {
pktHead = bufs[item.bufsIndex][bufsOffset:]
if cap(pktHead)-bufsOffset < coalescedLen {
// We don't want to allocate a new underlying array if capacity is
// too small.
return coalesceInsufficientCap
}
if item.numMerged == 0 {
if !tcpChecksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, isV6) {
return coalesceItemInvalidCSum
}
}
if !tcpChecksumValid(pkt, item.iphLen, isV6) {
return coalescePktInvalidCSum
}
if pshSet {
// We are appending a segment with PSH set.
item.pshSet = pshSet
pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH
}
extendBy := len(pkt) - int(headersLen)
bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...)
copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:])
}
if gsoSize > item.gsoSize {
item.gsoSize = gsoSize
}
item.numMerged++
return coalesceSuccess
}
const (
ipv4FlagMoreFragments uint8 = 0x20
)
const (
ipv4SrcAddrOffset = 12
ipv6SrcAddrOffset = 8
maxUint16 = 1<<16 - 1
)
type tcpGROResult int
const (
tcpGROResultNoop tcpGROResult = iota
tcpGROResultTableInsert
tcpGROResultCoalesced
)
// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with
// existing packets tracked in table. It returns a tcpGROResultNoop when no
// action was taken, tcpGROResultTableInsert when the evaluated packet was
// inserted into table, and tcpGROResultCoalesced when the evaluated packet was
// coalesced with another packet in table.
func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) tcpGROResult {
pkt := bufs[pktI][offset:]
if len(pkt) > maxUint16 {
// A valid IPv4 or IPv6 packet will never exceed this.
return tcpGROResultNoop
}
iphLen := int((pkt[0] & 0x0F) * 4)
if isV6 {
iphLen = 40
ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:]))
if ipv6HPayloadLen != len(pkt)-iphLen {
return tcpGROResultNoop
}
} else {
totalLen := int(binary.BigEndian.Uint16(pkt[2:]))
if totalLen != len(pkt) {
return tcpGROResultNoop
}
}
if len(pkt) < iphLen {
return tcpGROResultNoop
}
tcphLen := int((pkt[iphLen+12] >> 4) * 4)
if tcphLen < 20 || tcphLen > 60 {
return tcpGROResultNoop
}
if len(pkt) < iphLen+tcphLen {
return tcpGROResultNoop
}
if !isV6 {
if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 {
// no GRO support for fragmented segments for now
return tcpGROResultNoop
}
}
tcpFlags := pkt[iphLen+tcpFlagsOffset]
var pshSet bool
// not a candidate if any non-ACK flags (except PSH+ACK) are set
if tcpFlags != tcpFlagACK {
if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH {
return tcpGROResultNoop
}
pshSet = true
}
gsoSize := uint16(len(pkt) - tcphLen - iphLen)
// not a candidate if payload len is 0
if gsoSize < 1 {
return tcpGROResultNoop
}
seq := binary.BigEndian.Uint32(pkt[iphLen+4:])
srcAddrOffset := ipv4SrcAddrOffset
addrLen := 4
if isV6 {
srcAddrOffset = ipv6SrcAddrOffset
addrLen = 16
}
items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
if !existing {
return tcpGROResultNoop
}
for i := len(items) - 1; i >= 0; i-- {
// In the best case of packets arriving in order iterating in reverse is
// more efficient if there are multiple items for a given flow. This
// also enables a natural table.deleteAt() in the
// coalesceItemInvalidCSum case without the need for index tracking.
// This algorithm makes a best effort to coalesce in the event of
// unordered packets, where pkt may land anywhere in items from a
// sequence number perspective, however once an item is inserted into
// the table it is never compared across other items later.
item := items[i]
can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset)
if can != coalesceUnavailable {
result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6)
switch result {
case coalesceSuccess:
table.updateAt(item, i)
return tcpGROResultCoalesced
case coalesceItemInvalidCSum:
// delete the item with an invalid csum
table.deleteAt(item.key, i)
case coalescePktInvalidCSum:
// no point in inserting an item that we can't coalesce
return tcpGROResultNoop
default:
}
}
}
// failed to coalesce with any other packets; store the item in the flow
table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI)
return tcpGROResultTableInsert
}
func isTCP4NoIPOptions(b []byte) bool {
if len(b) < 40 {
return false
}
if b[0]>>4 != 4 {
return false
}
if b[0]&0x0F != 5 {
return false
}
if b[9] != unix.IPPROTO_TCP {
return false
}
return true
}
func isTCP6NoEH(b []byte) bool {
if len(b) < 60 {
return false
}
if b[0]>>4 != 6 {
return false
}
if b[6] != unix.IPPROTO_TCP {
return false
}
return true
}
// applyCoalesceAccounting updates bufs to account for coalescing based on the
// metadata found in table.
func applyCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable, isV6 bool) error {
for _, items := range table.itemsByFlow {
for _, item := range items {
if item.numMerged > 0 {
hdr := virtioNetHdr{
flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb
hdrLen: uint16(item.iphLen + item.tcphLen),
gsoSize: item.gsoSize,
csumStart: uint16(item.iphLen),
csumOffset: 16,
}
pkt := bufs[item.bufsIndex][offset:]
// Recalculate the total len (IPv4) or payload len (IPv6).
// Recalculate the (IPv4) header checksum.
if isV6 {
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6
binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len
} else {
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4
pkt[10], pkt[11] = 0, 0
binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length
iphCSum := ^checksumFold(pkt[:item.iphLen], 0) // compute IPv4 header checksum
binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field
}
err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
if err != nil {
return err
}
// Calculate the pseudo header checksum and place it at the TCP
// checksum offset. Downstream checksum offloading will combine
// this with computation of the tcp header and payload checksum.
addrLen := 4
addrOffset := ipv4SrcAddrOffset
if isV6 {
addrLen = 16
addrOffset = ipv6SrcAddrOffset
}
srcAddrAt := offset + addrOffset
srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen]
dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2]
psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen)))
binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksumFold([]byte{}, psum))
} else {
hdr := virtioNetHdr{}
err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:])
if err != nil {
return err
}
}
}
}
return nil
}
// handleGRO evaluates bufs for GRO, and writes the indices of the resulting
// packets into toWrite. toWrite, tcp4Table, and tcp6Table should initially be
// empty (but non-nil), and are passed in to save allocs as the caller may reset
// and recycle them across vectors of packets.
func handleGRO(bufs [][]byte, offset int, tcp4Table, tcp6Table *tcpGROTable, toWrite *[]int) error {
for i := range bufs {
if offset < virtioNetHdrLen || offset > len(bufs[i])-1 {
return errors.New("invalid offset")
}
var result tcpGROResult
switch {
case isTCP4NoIPOptions(bufs[i][offset:]): // ipv4 packets w/IP options do not coalesce
result = tcpGRO(bufs, offset, i, tcp4Table, false)
case isTCP6NoEH(bufs[i][offset:]): // ipv6 packets w/extension headers do not coalesce
result = tcpGRO(bufs, offset, i, tcp6Table, true)
}
switch result {
case tcpGROResultNoop:
hdr := virtioNetHdr{}
err := hdr.encode(bufs[i][offset-virtioNetHdrLen:])
if err != nil {
return err
}
fallthrough
case tcpGROResultTableInsert:
*toWrite = append(*toWrite, i)
}
}
err4 := applyCoalesceAccounting(bufs, offset, tcp4Table, false)
err6 := applyCoalesceAccounting(bufs, offset, tcp6Table, true)
return E.Errors(err4, err6)
}
// tcpTSO splits packets from in into outBuffs, writing the size of each
// element into sizes. It returns the number of buffers populated, and/or an
// error.
func tcpTSO(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int) (int, error) {
iphLen := int(hdr.csumStart)
srcAddrOffset := ipv6SrcAddrOffset
addrLen := 16
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
in[10], in[11] = 0, 0 // clear ipv4 header checksum
srcAddrOffset = ipv4SrcAddrOffset
addrLen = 4
}
tcpCSumAt := int(hdr.csumStart + hdr.csumOffset)
in[tcpCSumAt], in[tcpCSumAt+1] = 0, 0 // clear tcp checksum
firstTCPSeqNum := binary.BigEndian.Uint32(in[hdr.csumStart+4:])
nextSegmentDataAt := int(hdr.hdrLen)
i := 0
for ; nextSegmentDataAt < len(in); i++ {
if i == len(outBuffs) {
return i - 1, ErrTooManySegments
}
nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize)
if nextSegmentEnd > len(in) {
nextSegmentEnd = len(in)
}
segmentDataLen := nextSegmentEnd - nextSegmentDataAt
totalLen := int(hdr.hdrLen) + segmentDataLen
sizes[i] = totalLen
out := outBuffs[i][outOffset:]
copy(out, in[:iphLen])
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 {
// For IPv4 we are responsible for incrementing the ID field,
// updating the total len field, and recalculating the header
// checksum.
if i > 0 {
id := binary.BigEndian.Uint16(out[4:])
id += uint16(i)
binary.BigEndian.PutUint16(out[4:], id)
}
binary.BigEndian.PutUint16(out[2:], uint16(totalLen))
ipv4CSum := ^checksumFold(out[:iphLen], 0)
binary.BigEndian.PutUint16(out[10:], ipv4CSum)
} else {
// For IPv6 we are responsible for updating the payload length field.
binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen))
}
// TCP header
copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen])
tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i))
binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq)
if nextSegmentEnd != len(in) {
// FIN and PSH should only be set on last segment
clearFlags := tcpFlagFIN | tcpFlagPSH
out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags
}
// payload
copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd])
// TCP checksum
tcpHLen := int(hdr.hdrLen - hdr.csumStart)
tcpLenForPseudo := uint16(tcpHLen + segmentDataLen)
tcpCSumNoFold := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], tcpLenForPseudo)
tcpCSum := ^checksumFold(out[hdr.csumStart:totalLen], tcpCSumNoFold)
binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], tcpCSum)
nextSegmentDataAt += int(hdr.gsoSize)
}
return i, nil
}
func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error {
cSumAt := cSumStart + cSumOffset
// The initial value at the checksum offset should be summed with the
// checksum we compute. This is typically the pseudo-header checksum.
initial := binary.BigEndian.Uint16(in[cSumAt:])
in[cSumAt], in[cSumAt+1] = 0, 0
binary.BigEndian.PutUint16(in[cSumAt:], ^checksumFold(in[cSumStart:], uint64(initial)))
return nil
}
// handleVirtioRead splits in into bufs, leaving offset bytes at the front of
// each buffer. It mutates sizes to reflect the size of each element of bufs,
// and returns the number of packets read.
func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) {
var hdr virtioNetHdr
err := hdr.decode(in)
if err != nil {
return 0, err
}
in = in[virtioNetHdrLen:]
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE {
if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
// This means CHECKSUM_PARTIAL in skb context. We are responsible
// for computing the checksum starting at hdr.csumStart and placing
// at hdr.csumOffset.
err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset)
if err != nil {
return 0, err
}
}
if len(in) > len(bufs[0][offset:]) {
return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:]))
}
n := copy(bufs[0][offset:], in)
sizes[0] = n
return 1, nil
}
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
}
ipVersion := in[0] >> 4
switch ipVersion {
case 4:
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 {
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
}
case 6:
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
}
default:
return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
}
if len(in) <= int(hdr.csumStart+12) {
return 0, errors.New("packet is too short")
}
// Don't trust hdr.hdrLen from the kernel as it can be equal to the length
// of the entire first packet when the kernel is handling it as part of a
// FORWARD path. Instead, parse the TCP header length and add it onto
// csumStart, which is synonymous for IP header length.
tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
if tcpHLen < 20 || tcpHLen > 60 {
// A TCP header must be between 20 and 60 bytes in length.
return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
}
hdr.hdrLen = hdr.csumStart + tcpHLen
if len(in) < int(hdr.hdrLen) {
return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
}
if hdr.hdrLen < hdr.csumStart {
return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart)
}
cSumAt := int(hdr.csumStart + hdr.csumOffset)
if cSumAt+1 >= len(in) {
return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
}
return tcpTSO(in, hdr, bufs, sizes, offset)
}
func checksumNoFold(b []byte, initial uint64) uint64 {
return initial + uint64(clashtcpip.Sum(b))
}
func checksumFold(b []byte, initial uint64) uint16 {
ac := checksumNoFold(b, initial)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
ac = (ac >> 16) + (ac & 0xffff)
return uint16(ac)
}
func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
sum := checksumNoFold(srcAddr, 0)
sum = checksumNoFold(dstAddr, sum)
sum = checksumNoFold([]byte{0, protocol}, sum)
tmp := make([]byte, 2)
binary.BigEndian.PutUint16(tmp, totalLen)
return checksumNoFold(tmp, sum)
}

View File

@@ -0,0 +1,5 @@
package tun
import E "github.com/sagernet/sing/common/exceptions"
var ErrTooManySegments = E.New("too many segments")

5
tun_nonlinux.go Normal file
View File

@@ -0,0 +1,5 @@
//go:build !linux
package tun
const OffloadOffset = 0

View File

@@ -2,13 +2,17 @@ package tun
import (
"context"
"net/netip"
"os"
"runtime"
"sort"
"strconv"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/ranges"
"go4.org/netipx"
)
const (
@@ -96,3 +100,94 @@ func buildExcludedRanges(includeRanges []ranges.Range[uint32], excludeRanges []r
}
return ranges.Merge(uidRanges)
}
const autoRouteUseSubRanges = runtime.GOOS == "darwin"
func (o *Options) BuildAutoRouteRanges(underNetworkExtension bool) ([]netip.Prefix, error) {
var routeRanges []netip.Prefix
if o.AutoRoute && len(o.Inet4Address) > 0 {
var inet4Ranges []netip.Prefix
if len(o.Inet4RouteAddress) > 0 {
inet4Ranges = o.Inet4RouteAddress
if runtime.GOOS == "darwin" {
for _, address := range o.Inet4Address {
if address.Bits() < 32 {
inet4Ranges = append(inet4Ranges, address.Masked())
}
}
}
} else if autoRouteUseSubRanges && !underNetworkExtension {
inet4Ranges = []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 1}), 8),
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 2}), 7),
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 4}), 6),
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 8}), 5),
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 16}), 4),
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 32}), 3),
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 64}), 2),
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 128}), 1),
}
} else {
inet4Ranges = []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
}
if len(o.Inet4RouteExcludeAddress) == 0 {
routeRanges = append(routeRanges, inet4Ranges...)
} else {
var builder netipx.IPSetBuilder
for _, inet4Range := range inet4Ranges {
builder.AddPrefix(inet4Range)
}
for _, prefix := range o.Inet4RouteExcludeAddress {
builder.RemovePrefix(prefix)
}
resultSet, err := builder.IPSet()
if err != nil {
return nil, E.Cause(err, "build IPv4 route address")
}
routeRanges = append(routeRanges, resultSet.Prefixes()...)
}
}
if len(o.Inet6Address) > 0 {
var inet6Ranges []netip.Prefix
if len(o.Inet6RouteAddress) > 0 {
inet6Ranges = o.Inet6RouteAddress
if runtime.GOOS == "darwin" {
for _, address := range o.Inet6Address {
if address.Bits() < 32 {
inet6Ranges = append(inet6Ranges, address.Masked())
}
}
}
} else if autoRouteUseSubRanges && !underNetworkExtension {
inet6Ranges = []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 1}), 8),
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 2}), 7),
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 4}), 6),
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 8}), 5),
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 16}), 4),
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 32}), 3),
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 64}), 2),
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 128}), 1),
}
} else {
inet6Ranges = []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
}
if len(o.Inet6RouteExcludeAddress) == 0 {
routeRanges = append(routeRanges, inet6Ranges...)
} else {
var builder netipx.IPSetBuilder
for _, inet6Range := range inet6Ranges {
builder.AddPrefix(inet6Range)
}
for _, prefix := range o.Inet6RouteExcludeAddress {
builder.RemovePrefix(prefix)
}
resultSet, err := builder.IPSet()
if err != nil {
return nil, E.Cause(err, "build IPv6 route address")
}
routeRanges = append(routeRanges, resultSet.Prefixes()...)
}
}
return routeRanges, nil
}

View File

@@ -9,13 +9,15 @@ import (
"net/netip"
"os"
"sync"
"sync/atomic"
"time"
"unsafe"
"github.com/sagernet/sing-tun/internal/winipcfg"
"github.com/sagernet/sing-tun/internal/winsys"
"github.com/sagernet/sing-tun/internal/wintun"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/windnsapi"
@@ -32,7 +34,7 @@ type NativeTun struct {
rate rateJuggler
running sync.WaitGroup
closeOnce sync.Once
close int32
close atomic.Int32
fwpmSession uintptr
}
@@ -85,38 +87,22 @@ func (t *NativeTun) configure() error {
return E.Cause(err, "set ipv6 dns")
}
}
if len(t.options.Inet4Address) > 0 || len(t.options.Inet6Address) > 0 {
_ = luid.DisableDNSRegistration()
}
if t.options.AutoRoute {
if len(t.options.Inet4Address) > 0 {
if len(t.options.Inet4RouteAddress) > 0 {
for _, addr := range t.options.Inet4RouteAddress {
err := luid.AddRoute(addr, netip.IPv4Unspecified(), 0)
if err != nil {
return E.Cause(err, "add ipv4 route: ", addr)
}
}
routeRanges, err := t.options.BuildAutoRouteRanges(false)
if err != nil {
return err
}
for _, routeRange := range routeRanges {
if routeRange.Addr().Is4() {
err = luid.AddRoute(routeRange, netip.IPv4Unspecified(), 0)
} else {
err := luid.AddRoute(netip.PrefixFrom(netip.IPv4Unspecified(), 0), netip.IPv4Unspecified(), 0)
if err != nil {
return E.Cause(err, "set ipv4 route")
}
err = luid.AddRoute(routeRange, netip.IPv6Unspecified(), 0)
}
}
if len(t.options.Inet6Address) > 0 {
if len(t.options.Inet6RouteAddress) > 0 {
for _, addr := range t.options.Inet6RouteAddress {
err := luid.AddRoute(addr, netip.IPv6Unspecified(), 0)
if err != nil {
return E.Cause(err, "add ipv6 route: ", addr)
}
}
} else {
err := luid.AddRoute(netip.PrefixFrom(netip.IPv6Unspecified(), 0), netip.IPv6Unspecified(), 0)
if err != nil {
return E.Cause(err, "set ipv6 route")
}
}
}
err := windnsapi.FlushResolverCache()
err = windnsapi.FlushResolverCache()
if err != nil {
return err
}
@@ -348,13 +334,13 @@ func (t *NativeTun) ReadPacket() ([]byte, func(), error) {
t.running.Add(1)
defer t.running.Done()
retry:
if atomic.LoadInt32(&t.close) == 1 {
if t.close.Load() == 1 {
return nil, nil, os.ErrClosed
}
start := nanotime()
shouldSpin := atomic.LoadUint64(&t.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&t.rate.nextStartTime)) <= rateMeasurementGranularity*2
shouldSpin := t.rate.current.Load() >= spinloopRateThreshold && uint64(start-t.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
for {
if atomic.LoadInt32(&t.close) == 1 {
if t.close.Load() == 1 {
return nil, nil, os.ErrClosed
}
packet, err := t.session.ReceivePacket()
@@ -383,13 +369,13 @@ func (t *NativeTun) ReadFunc(block func(b []byte)) error {
t.running.Add(1)
defer t.running.Done()
retry:
if atomic.LoadInt32(&t.close) == 1 {
if t.close.Load() == 1 {
return os.ErrClosed
}
start := nanotime()
shouldSpin := atomic.LoadUint64(&t.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&t.rate.nextStartTime)) <= rateMeasurementGranularity*2
shouldSpin := t.rate.current.Load() >= spinloopRateThreshold && uint64(start-t.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
for {
if atomic.LoadInt32(&t.close) == 1 {
if t.close.Load() == 1 {
return os.ErrClosed
}
packet, err := t.session.ReceivePacket()
@@ -419,7 +405,7 @@ retry:
func (t *NativeTun) Write(p []byte) (n int, err error) {
t.running.Add(1)
defer t.running.Done()
if atomic.LoadInt32(&t.close) == 1 {
if t.close.Load() == 1 {
return 0, os.ErrClosed
}
t.rate.update(uint64(len(p)))
@@ -441,7 +427,7 @@ func (t *NativeTun) Write(p []byte) (n int, err error) {
func (t *NativeTun) write(packetElementList [][]byte) (n int, err error) {
t.running.Add(1)
defer t.running.Done()
if atomic.LoadInt32(&t.close) == 1 {
if t.close.Load() == 1 {
return 0, os.ErrClosed
}
var packetSize int
@@ -467,10 +453,15 @@ func (t *NativeTun) write(packetElementList [][]byte) (n int, err error) {
return 0, fmt.Errorf("write failed: %w", err)
}
func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
defer buf.ReleaseMulti(buffers)
return common.Error(t.write(buf.ToSliceMulti(buffers)))
}
func (t *NativeTun) Close() error {
var err error
t.closeOnce.Do(func() {
atomic.StoreInt32(&t.close, 1)
t.close.Store(1)
windows.SetEvent(t.readWait)
t.running.Wait()
t.session.End()
@@ -500,24 +491,24 @@ func procyield(cycles uint32)
func nanotime() int64
type rateJuggler struct {
current uint64
nextByteCount uint64
nextStartTime int64
changing int32
current atomic.Uint64
nextByteCount atomic.Uint64
nextStartTime atomic.Int64
changing atomic.Int32
}
func (rate *rateJuggler) update(packetLen uint64) {
now := nanotime()
total := atomic.AddUint64(&rate.nextByteCount, packetLen)
period := uint64(now - atomic.LoadInt64(&rate.nextStartTime))
total := rate.nextByteCount.Add(packetLen)
period := uint64(now - rate.nextStartTime.Load())
if period >= rateMeasurementGranularity {
if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) {
if !rate.changing.CompareAndSwap(0, 1) {
return
}
atomic.StoreInt64(&rate.nextStartTime, now)
atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period)
atomic.StoreUint64(&rate.nextByteCount, 0)
atomic.StoreInt32(&rate.changing, 0)
rate.nextStartTime.Store(now)
rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period)
rate.nextByteCount.Store(0)
rate.changing.Store(0)
}
}

View File

@@ -3,10 +3,10 @@
package tun
import (
"gvisor.dev/gvisor/pkg/bufferv2"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
)
var _ GVisorTun = (*NativeTun)(nil)
@@ -35,7 +35,7 @@ func (e *WintunEndpoint) LinkAddress() tcpip.LinkAddress {
}
func (e *WintunEndpoint) Capabilities() stack.LinkEndpointCapabilities {
return stack.CapabilityNone
return stack.CapabilityRXChecksumOffload
}
func (e *WintunEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
@@ -51,16 +51,16 @@ func (e *WintunEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
func (e *WintunEndpoint) dispatchLoop() {
for {
var buffer bufferv2.Buffer
var packetBuffer buffer.Buffer
err := e.tun.ReadFunc(func(b []byte) {
buffer = bufferv2.MakeWithData(b)
packetBuffer = buffer.MakeWithData(b)
})
if err != nil {
break
}
ihl, ok := buffer.PullUp(0, 1)
ihl, ok := packetBuffer.PullUp(0, 1)
if !ok {
buffer.Release()
packetBuffer.Release()
continue
}
var networkProtocol tcpip.NetworkProtocolNumber
@@ -70,12 +70,12 @@ func (e *WintunEndpoint) dispatchLoop() {
case header.IPv6Version:
networkProtocol = header.IPv6ProtocolNumber
default:
e.tun.Write(buffer.Flatten())
buffer.Release()
e.tun.Write(packetBuffer.Flatten())
packetBuffer.Release()
continue
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer,
Payload: packetBuffer,
IsForwardedPacket: true,
})
dispatcher := e.dispatcher
@@ -102,6 +102,10 @@ func (e *WintunEndpoint) ARPHardwareType() header.ARPHardwareType {
func (e *WintunEndpoint) AddHeader(buffer *stack.PacketBuffer) {
}
func (e *WintunEndpoint) ParseHeader(ptr *stack.PacketBuffer) bool {
return true
}
func (e *WintunEndpoint) WritePackets(packetBufferList stack.PacketBufferList) (int, tcpip.Error) {
var n int
for _, packet := range packetBufferList.AsSlice() {