Compare commits

..

49 Commits

Author SHA1 Message Date
世界
69c3b72eec FIx error handling for netlink banned in Android 2023-12-12 13:52:55 +08:00
世界
62f2d98190 Fix auto-route IPv6 on darwin 2023-12-11 21:37:40 +08:00
世界
d275b4a0fd Update dependencies 2023-12-11 21:37:40 +08:00
世界
46adeb9b5d Update dependencies 2023-12-04 21:00:08 +08:00
世界
f6ea97c5af Update gVisor to 20231113.0 2023-11-19 11:55:13 +08:00
世界
3fa4ee409a Add route exclude support 2023-11-16 18:27:36 +08:00
世界
958d6a25a4 Fix "Fix Linux IPv6 auto route rules" 2023-11-16 18:21:14 +08:00
世界
86322a3fe1 Update dependency 2023-11-14 19:28:42 +08:00
世界
78e0dfa18f Fix broadcast filter not applied to mixed stack 2023-11-14 17:16:24 +08:00
世界
ce9c864d89 system: Check UDP invalid packet 2023-11-14 17:16:24 +08:00
世界
da350ecc72 Add broadcast filter 2023-11-06 21:33:35 +08:00
世界
1a00992d06 Update dependencies 2023-11-05 16:08:15 +08:00
世界
150b116231 Add multicast filter 2023-11-04 08:01:41 +08:00
世界
b93db9639d Remove network requirement on start 2023-11-03 10:59:51 +08:00
i40e
efd9884154 Disable Windows DNS registration 2023-10-26 14:08:25 +08:00
世界
1a85bd3ef4 Add DefaultInterface func 2023-10-26 14:08:21 +08:00
世界
56a9b85cf5 Update gVisor to 20230814.0 2023-10-26 14:08:19 +08:00
世界
660222a0dd Fix Linux IPv6 auto route rules 2023-10-26 11:59:14 +08:00
世界
fee2614ae3 Update dependencies 2023-10-26 11:58:57 +08:00
世界
2b625a47c0 Update dependencies 2023-10-06 17:07:20 +08:00
世界
dcf7d50379 Fix gVisor UDP 6to4 check 2023-10-06 17:07:04 +08:00
世界
4979f75513 Update dependencies 2023-09-30 22:33:39 +08:00
世界
2a0a0ab228 android: Fix netlink check 2023-09-26 17:39:31 +08:00
世界
8adce0ea02 android: Check netlink available on monitor create 2023-09-25 17:15:15 +08:00
世界
b6d323004e Remove use of Write Unreachable as SendRejectionError panics when passing invalid packet 2023-09-22 11:50:04 +08:00
世界
e212724bac Update dependencies 2023-09-20 22:18:32 +08:00
世界
9c933ea553 Remove defaultInterfaceName when no route 2023-09-20 14:08:16 +08:00
世界
7545dc2d56 Fix darwin monitor socket leak 2023-08-21 14:55:22 +08:00
世界
db70908d61 Update dependencies 2023-08-20 17:19:22 +08:00
世界
824b903ebd Add [include/exclude]_interface iproute2 options 2023-08-20 17:18:29 +08:00
世界
10d98f2679 Add mixed stack 2023-08-12 19:38:06 +08:00
世界
aa8760b454 Add handshake interface support for gVisor UDP 2023-08-07 20:32:32 +08:00
世界
0a68b9f1d8 Fix monitor 2023-08-07 20:31:52 +08:00
世界
59b86002c4 Add no route event 2023-08-07 19:51:07 +08:00
世界
688d4da4b7 Fix gVisor UDP 2023-07-25 08:12:56 +08:00
世界
28db424ae8 Update dependencies 2023-07-23 14:16:08 +08:00
世界
bbf542f01a Improve gVisor UDP 2023-07-23 14:01:36 +08:00
世界
fd850d00e5 Fix buffer usage 2023-07-03 21:44:24 +08:00
世界
3b558f113c Update gVisor to 20230621.0 2023-06-27 11:12:05 +08:00
世界
d51abeb6c7 Prevent panic when write packet with bad address type 2023-06-21 13:27:17 +08:00
世界
323b9564f0 Update dependencies 2023-06-17 12:11:41 +08:00
世界
4bc8dc7f27 Configure default-route for systemd-resolved
Even though the documentation says this parameter doesn't matter, some people have reported that not configuring it can cause problems.
2023-06-12 19:28:29 +08:00
世界
41b2639e13 Update gVisor to 20230605.0-33-g8ec8dbe7e 2023-06-11 22:06:33 +08:00
世界
605266e65e Update gci usage 2023-06-10 08:50:20 +08:00
世界
e881f21013 Update gVisor to release-20230605.0-21-g457c1c36d 2023-06-10 08:45:55 +08:00
dyhkwong
b02f252916 Use api to create windows firewall rules 2023-05-20 12:11:00 +08:00
世界
91df97aee2 Fix macos monitor 2023-05-09 18:20:26 +08:00
世界
6999634511 Fix windows firewall for system stack 2023-05-09 12:12:00 +08:00
世界
209ec123ca Update gVisor to 20230417.0 2023-04-22 20:14:32 +08:00
43 changed files with 1532 additions and 1021 deletions

View File

@@ -1,6 +1,5 @@
#!/usr/bin/env bash
PROJECTS=$(dirname "$0")/../..
go get -x github.com/sagernet/sing@$(git -C $PROJECTS/sing rev-parse HEAD)
go get -x github.com/sagernet/$1@$(git -C $PROJECTS/$1 rev-parse HEAD)
go mod tidy

View File

@@ -1,7 +1,16 @@
build:
GOOS=darwin GOARCH=arm64 go build -v -tags with_gvisor .
GOOS=ios GOARCH=arm64 go build -v -tags with_gvisor .
GOOS=linux GOARCH=amd64 go build -v -tags with_gvisor .
GOOS=linux GOARCH=arm64 go build -v -tags with_gvisor .
GOOS=linux GOARCH=386 go build -v -tags with_gvisor .
GOOS=linux GOARCH=arm go build -v -tags with_gvisor .
GOOS=windows GOARCH=amd64 go build -v -tags with_gvisor .
fmt:
@gofumpt -l -w .
@gofmt -s -w .
@gci write --custom-order -s "standard,prefix(github.com/sagernet/),default" .
@gci write --custom-order -s standard -s "prefix(github.com/sagernet/)" -s "default" .
fmt_install:
go install -v mvdan.cc/gofumpt@latest

17
go.mod
View File

@@ -3,17 +3,20 @@ module github.com/sagernet/sing-tun
go 1.18
require (
github.com/fsnotify/fsnotify v1.6.0
github.com/fsnotify/fsnotify v1.7.0
github.com/go-ole/go-ole v1.3.0
github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61
github.com/sagernet/gvisor v0.0.0-20231119034329-07cfb6aaf930
github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97
github.com/sagernet/sing v0.2.4
golang.org/x/net v0.9.0
golang.org/x/sys v0.7.0
gvisor.dev/gvisor v0.0.0-20220901235040-6ca97ef2ce1c
github.com/sagernet/sing v0.2.19
github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
golang.org/x/net v0.19.0
golang.org/x/sys v0.15.0
)
require (
github.com/google/btree v1.0.1 // indirect
github.com/google/btree v1.1.2 // indirect
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 // indirect
golang.org/x/time v0.4.0 // indirect
)

36
go.sum
View File

@@ -1,24 +1,30 @@
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4=
github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78=
github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61 h1:5+m7c6AkmAylhauulqN/c5dnh8/KssrE9c93TQrXldA=
github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61/go.mod h1:QUQ4RRHD6hGGHdFMEtR8T2P6GS6R3D/CXKdaYHKKXms=
github.com/sagernet/gvisor v0.0.0-20231119034329-07cfb6aaf930 h1:dSPgjIw0CT6ISLeEh8Q20dZMBMFCcEceo23+LncRcNQ=
github.com/sagernet/gvisor v0.0.0-20231119034329-07cfb6aaf930/go.mod h1:JpKHkOYgh4wLwrX2BhH3ZIvCvazCkTnPeEcmigZJfHY=
github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97 h1:iL5gZI3uFp0X6EslacyapiRz7LLSJyr4RajF/BhMVyE=
github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM=
github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY=
github.com/sagernet/sing v0.2.4 h1:gC8BR5sglbJZX23RtMyFa8EETP9YEUADhfbEzU1yVbo=
github.com/sagernet/sing v0.2.4/go.mod h1:Ta8nHnDLAwqySzKhGoKk4ZIB+vJ3GTKj7UPrWYvM+4w=
github.com/sagernet/sing v0.2.19 h1:Mdj/YJ5TtEyG+eIZaAlvX8j2cHxMN6eW4RF6Xh9iWyg=
github.com/sagernet/sing v0.2.19/go.mod h1:Ce5LNojQOgOiWhiD8pPD6E9H7e2KgtOe3Zxx4Ou5u80=
github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9 h1:rc/CcqLH3lh8n+csdOuDfP+NuykE0U6AeYSJJHKDgSg=
github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9/go.mod h1:a/83NAfUXvEuLpmxDssAXxgUgrEy12MId3Wd7OTs76s=
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 h1:gga7acRE695APm9hlsSMoOoE65U4/TcqNj90mc69Rlg=
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM=
golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns=
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M=
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y=
golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c=
golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U=
golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU=
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
gvisor.dev/gvisor v0.0.0-20220901235040-6ca97ef2ce1c h1:m5lcgWnL3OElQNVyp3qcncItJ2c0sQlSGjYK2+nJTA4=
gvisor.dev/gvisor v0.0.0-20220901235040-6ca97ef2ce1c/go.mod h1:TIvkJD0sxe8pIob3p6T8IzxXunlp6yfgktvTNp+DGNM=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/time v0.4.0 h1:Z81tqI5ddIoXDPvVQ7/7CC9TnLM7ubaFG2qXYd5BbYY=
golang.org/x/time v0.4.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=

View File

@@ -1,119 +0,0 @@
//go:build with_gvisor
package tun
import (
"context"
"math"
"net"
"net/netip"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/udpnat"
"gvisor.dev/gvisor/pkg/bufferv2"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
type UDPForwarder struct {
ctx context.Context
stack *stack.Stack
udpNat *udpnat.Service[netip.AddrPort]
}
func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, udpTimeout int64) *UDPForwarder {
return &UDPForwarder{
ctx: ctx,
stack: stack,
udpNat: udpnat.New[netip.AddrPort](udpTimeout, handler),
}
}
func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
var upstreamMetadata M.Metadata
upstreamMetadata.Source = M.SocksaddrFrom(M.AddrFromIP(net.IP(id.RemoteAddress)), id.RemotePort)
upstreamMetadata.Destination = M.SocksaddrFrom(M.AddrFromIP(net.IP(id.LocalAddress)), id.LocalPort)
var netProto tcpip.NetworkProtocolNumber
if upstreamMetadata.Source.IsIPv4() {
netProto = header.IPv4ProtocolNumber
} else {
netProto = header.IPv6ProtocolNumber
}
f.udpNat.NewPacket(
f.ctx,
upstreamMetadata.Source.AddrPort(),
buf.As(pkt.Data().AsRange().ToSlice()),
upstreamMetadata,
func(natConn N.PacketConn) N.PacketWriter {
return &UDPBackWriter{f.stack, id.RemoteAddress, id.RemotePort, netProto}
},
)
return true
}
type UDPBackWriter struct {
stack *stack.Stack
source tcpip.Address
sourcePort uint16
sourceNetwork tcpip.NetworkProtocolNumber
}
func (w *UDPBackWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
defer buffer.Release()
route, err := w.stack.FindRoute(
defaultNIC,
tcpip.Address(destination.Addr.AsSlice()),
w.source,
w.sourceNetwork,
false,
)
if err != nil {
return wrapStackError(err)
}
defer route.Release()
packet := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: header.UDPMinimumSize + int(route.MaxHeaderLength()),
Payload: bufferv2.MakeWithData(buffer.Bytes()),
})
defer packet.DecRef()
packet.TransportProtocolNumber = header.UDPProtocolNumber
udpHdr := header.UDP(packet.TransportHeader().Push(header.UDPMinimumSize))
pLen := uint16(packet.Size())
udpHdr.Encode(&header.UDPFields{
SrcPort: destination.Port,
DstPort: w.sourcePort,
Length: pLen,
})
if route.RequiresTXTransportChecksum() && w.sourceNetwork == header.IPv6ProtocolNumber {
xsum := udpHdr.CalculateChecksum(header.ChecksumCombine(
route.PseudoHeaderChecksum(header.UDPProtocolNumber, pLen),
packet.Data().AsRange().Checksum(),
))
if xsum != math.MaxUint16 {
xsum = ^xsum
}
udpHdr.SetChecksum(xsum)
}
err = route.WritePacket(stack.NetworkHeaderParams{
Protocol: header.UDPProtocolNumber,
TTL: route.DefaultTTL(),
TOS: 0,
}, packet)
if err != nil {
route.Stats().UDP.PacketSendErrors.Increment()
return wrapStackError(err)
}
route.Stats().UDP.PacketsSent.Increment()
return nil
}

274
internal/winfw/winfw.go Normal file
View File

@@ -0,0 +1,274 @@
// Copyright (c) 2018 Samuel Melrose
// SPDX-License-Identifier: MIT
// https://github.com/iamacarpet/go-win64api/blob/ef6dbdd6db97301ae08a55eedea773476985a602/firewall.go
//go:build windows
package winfw
import (
"fmt"
"runtime"
"github.com/go-ole/go-ole"
"github.com/go-ole/go-ole/oleutil"
"github.com/scjalliance/comshim"
)
// Firewall related API constants.
const (
NET_FW_IP_PROTOCOL_TCP = 6
NET_FW_IP_PROTOCOL_UDP = 17
NET_FW_IP_PROTOCOL_ICMPv4 = 1
NET_FW_IP_PROTOCOL_ICMPv6 = 58
NET_FW_IP_PROTOCOL_ANY = 256
NET_FW_RULE_DIR_IN = 1
NET_FW_RULE_DIR_OUT = 2
NET_FW_ACTION_BLOCK = 0
NET_FW_ACTION_ALLOW = 1
// NET_FW_PROFILE2_CURRENT is not real API constant, just helper used in FW functions.
// It can mean one profile or multiple (even all) profiles. It depends on which profiles
// are currently in use. Every active interface can have it's own profile. F.e.: Public for Wifi,
// Domain for VPN, and Private for LAN. All at the same time.
NET_FW_PROFILE2_CURRENT = 0
NET_FW_PROFILE2_DOMAIN = 1
NET_FW_PROFILE2_PRIVATE = 2
NET_FW_PROFILE2_PUBLIC = 4
NET_FW_PROFILE2_ALL = 2147483647
)
// Firewall Rule Groups
// Use this magical strings instead of group names. It will work on all language Windows versions.
// You can find more string locations here:
// https://windows10dll.nirsoft.net/firewallapi_dll.html
const (
NET_FW_FILE_AND_PRINTER_SHARING = "@FirewallAPI.dll,-28502"
NET_FW_REMOTE_DESKTOP = "@FirewallAPI.dll,-28752"
)
// FWRule represents Firewall Rule.
type FWRule struct {
Name, Description, ApplicationName, ServiceName string
LocalPorts, RemotePorts string
// LocalAddresses, RemoteAddresses are always returned with netmask, f.e.:
// `10.10.1.1/255.255.255.0`
LocalAddresses, RemoteAddresses string
// ICMPTypesAndCodes is string. You can find define multiple codes separated by ":" (colon).
// Types are listed here:
// https://www.iana.org/assignments/icmp-parameters/icmp-parameters.xhtml
// So to allow ping set it to:
// "0"
ICMPTypesAndCodes string
Grouping string
// InterfaceTypes can be:
// "LAN", "Wireless", "RemoteAccess", "All"
// You can add multiple deviding with comma:
// "LAN, Wireless"
InterfaceTypes string
Protocol, Direction, Action, Profiles int32
Enabled, EdgeTraversal bool
}
// FirewallRuleAddAdvanced allows to modify almost all available FW Rule parameters.
// You probably do not want to use this, as function allows to create any rule, even opening all ports
// in given profile. So use with caution.
func FirewallRuleAddAdvanced(rule FWRule) (bool, error) {
return firewallRuleAdd(rule.Name, rule.Description, rule.Grouping, rule.ApplicationName, rule.ServiceName,
rule.LocalPorts, rule.RemotePorts, rule.LocalAddresses, rule.RemoteAddresses, rule.ICMPTypesAndCodes,
rule.Protocol, rule.Direction, rule.Action, rule.Profiles, rule.Enabled, rule.EdgeTraversal)
}
// firewallRuleAdd is universal function to add all kinds of rules.
func firewallRuleAdd(name, description, group, appPath, serviceName, ports, remotePorts, localAddresses, remoteAddresses, icmpTypes string, protocol, direction, action, profile int32, enabled, edgeTraversal bool) (bool, error) {
if name == "" {
return false, fmt.Errorf("empty FW Rule name, name is mandatory")
}
runtime.LockOSThread()
defer runtime.UnlockOSThread()
u, fwPolicy, err := firewallAPIInit()
if err != nil {
return false, err
}
defer firewallAPIRelease(u, fwPolicy)
if profile == NET_FW_PROFILE2_CURRENT {
currentProfiles, err := oleutil.GetProperty(fwPolicy, "CurrentProfileTypes")
if err != nil {
return false, fmt.Errorf("Failed to get CurrentProfiles: %s", err)
}
profile = currentProfiles.Value().(int32)
}
unknownRules, err := oleutil.GetProperty(fwPolicy, "Rules")
if err != nil {
return false, fmt.Errorf("Failed to get Rules: %s", err)
}
rules := unknownRules.ToIDispatch()
if ok, err := FirewallRuleExistsByName(rules, name); err != nil {
return false, fmt.Errorf("Error while checking rules for duplicate: %s", err)
} else if ok {
return false, nil
}
unknown2, err := oleutil.CreateObject("HNetCfg.FWRule")
if err != nil {
return false, fmt.Errorf("Error creating Rule object: %s", err)
}
defer unknown2.Release()
fwRule, err := unknown2.QueryInterface(ole.IID_IDispatch)
if err != nil {
return false, fmt.Errorf("Error creating Rule object (2): %s", err)
}
defer fwRule.Release()
if _, err := oleutil.PutProperty(fwRule, "Name", name); err != nil {
return false, fmt.Errorf("Error setting property (Name) of Rule: %s", err)
}
if _, err := oleutil.PutProperty(fwRule, "Description", description); err != nil {
return false, fmt.Errorf("Error setting property (Description) of Rule: %s", err)
}
if appPath != "" {
if _, err := oleutil.PutProperty(fwRule, "Applicationname", appPath); err != nil {
return false, fmt.Errorf("Error setting property (Applicationname) of Rule: %s", err)
}
}
if serviceName != "" {
if _, err := oleutil.PutProperty(fwRule, "ServiceName", serviceName); err != nil {
return false, fmt.Errorf("Error setting property (ServiceName) of Rule: %s", err)
}
}
if protocol != 0 {
if _, err := oleutil.PutProperty(fwRule, "Protocol", protocol); err != nil {
return false, fmt.Errorf("Error setting property (Protocol) of Rule: %s", err)
}
}
if icmpTypes != "" {
if _, err := oleutil.PutProperty(fwRule, "IcmpTypesAndCodes", icmpTypes); err != nil {
return false, fmt.Errorf("Error setting property (IcmpTypesAndCodes) of Rule: %s", err)
}
}
if ports != "" {
if _, err := oleutil.PutProperty(fwRule, "LocalPorts", ports); err != nil {
return false, fmt.Errorf("Error setting property (LocalPorts) of Rule: %s", err)
}
}
if remotePorts != "" {
if _, err := oleutil.PutProperty(fwRule, "RemotePorts", remotePorts); err != nil {
return false, fmt.Errorf("Error setting property (RemotePorts) of Rule: %s", err)
}
}
if localAddresses != "" {
if _, err := oleutil.PutProperty(fwRule, "LocalAddresses", localAddresses); err != nil {
return false, fmt.Errorf("Error setting property (LocalAddresses) of Rule: %s", err)
}
}
if remoteAddresses != "" {
if _, err := oleutil.PutProperty(fwRule, "RemoteAddresses", remoteAddresses); err != nil {
return false, fmt.Errorf("Error setting property (RemoteAddresses) of Rule: %s", err)
}
}
if direction != 0 {
if _, err := oleutil.PutProperty(fwRule, "Direction", direction); err != nil {
return false, fmt.Errorf("Error setting property (Direction) of Rule: %s", err)
}
}
if _, err := oleutil.PutProperty(fwRule, "Enabled", enabled); err != nil {
return false, fmt.Errorf("Error setting property (Enabled) of Rule: %s", err)
}
if _, err := oleutil.PutProperty(fwRule, "Grouping", group); err != nil {
return false, fmt.Errorf("Error setting property (Grouping) of Rule: %s", err)
}
if _, err := oleutil.PutProperty(fwRule, "Profiles", profile); err != nil {
return false, fmt.Errorf("Error setting property (Profiles) of Rule: %s", err)
}
if _, err := oleutil.PutProperty(fwRule, "Action", action); err != nil {
return false, fmt.Errorf("Error setting property (Action) of Rule: %s", err)
}
if edgeTraversal {
if _, err := oleutil.PutProperty(fwRule, "EdgeTraversal", edgeTraversal); err != nil {
return false, fmt.Errorf("Error setting property (EdgeTraversal) of Rule: %s", err)
}
}
if _, err := oleutil.CallMethod(rules, "Add", fwRule); err != nil {
return false, fmt.Errorf("Error adding Rule: %s", err)
}
return true, nil
}
func FirewallRuleExistsByName(rules *ole.IDispatch, name string) (bool, error) {
enumProperty, err := rules.GetProperty("_NewEnum")
if err != nil {
return false, fmt.Errorf("Failed to get enumeration property on Rules: %s", err)
}
defer enumProperty.Clear()
enum, err := enumProperty.ToIUnknown().IEnumVARIANT(ole.IID_IEnumVariant)
if err != nil {
return false, fmt.Errorf("Failed to cast enum to correct type: %s", err)
}
if enum == nil {
return false, fmt.Errorf("can't get IEnumVARIANT, enum is nil")
}
for itemRaw, length, err := enum.Next(1); length > 0; itemRaw, length, err = enum.Next(1) {
if err != nil {
return false, fmt.Errorf("Failed to seek next Rule item: %s", err)
}
t, err := func() (bool, error) {
item := itemRaw.ToIDispatch()
defer item.Release()
if item, err := oleutil.GetProperty(item, "Name"); err != nil {
return false, fmt.Errorf("Failed to get Property (Name) of Rule")
} else if item.ToString() == name {
return true, nil
}
return false, nil
}()
if err != nil {
return false, err
} else if t {
return true, nil
}
}
return false, nil
}
// firewallAPIInit initialize common fw api.
// then:
// dispatch firewallAPIRelease(u, fwp)
func firewallAPIInit() (*ole.IUnknown, *ole.IDispatch, error) {
comshim.Add(1)
unknown, err := oleutil.CreateObject("HNetCfg.FwPolicy2")
if err != nil {
return nil, nil, fmt.Errorf("Failed to create FwPolicy Object: %s", err)
}
fwPolicy, err := unknown.QueryInterface(ole.IID_IDispatch)
if err != nil {
unknown.Release()
return nil, nil, fmt.Errorf("Failed to create FwPolicy Object (2): %s", err)
}
return unknown, fwPolicy, nil
}
// firewallAPIRelease cleans memory.
func firewallAPIRelease(u *ole.IUnknown, fwp *ole.IDispatch) {
fwp.Release()
u.Release()
comshim.Done()
}

View File

@@ -385,3 +385,25 @@ func (luid LUID) SetDNS(family AddressFamily, servers []netip.Addr, domains []st
func (luid LUID) FlushDNS(family AddressFamily) error {
return luid.SetDNS(family, nil, nil)
}
func (luid LUID) DisableDNSRegistration() error {
guid, err := luid.GUID()
if err != nil {
return err
}
dnsInterfaceSettings := &DnsInterfaceSettings{
Version: DnsInterfaceSettingsVersion1,
Flags: DnsInterfaceSettingsFlagRegistrationEnabled,
RegistrationEnabled: 0,
}
// For >= Windows 10 1809
err = SetInterfaceDnsSettings(*guid, dnsInterfaceSettings)
if err == nil || !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) {
return err
}
// For < Windows 10 1809
return luid.fallbackDisableDNSRegistration()
}

View File

@@ -51,10 +51,11 @@ func runNetsh(cmds []string) error {
}
const (
netshCmdTemplateFlush4 = "interface ipv4 set dnsservers name=%d source=static address=none validate=no register=both"
netshCmdTemplateFlush6 = "interface ipv6 set dnsservers name=%d source=static address=none validate=no register=both"
netshCmdTemplateAdd4 = "interface ipv4 add dnsservers name=%d address=%s validate=no"
netshCmdTemplateAdd6 = "interface ipv6 add dnsservers name=%d address=%s validate=no"
netshCmdTemplateFlush4 = "interface ipv4 set dnsservers name=%d source=static address=none validate=no"
netshCmdTemplateFlush6 = "interface ipv6 set dnsservers name=%d source=static address=none validate=no"
netshCmdTemplateAdd4 = "interface ipv4 add dnsservers name=%d address=%s validate=no"
netshCmdTemplateAdd6 = "interface ipv6 add dnsservers name=%d address=%s validate=no"
netshCmdTemplateDisableRegistration = "interface ipv6 set dnsservers name=%d register=none"
)
func (luid LUID) fallbackSetDNSForFamily(family AddressFamily, dnses []netip.Addr) error {
@@ -106,3 +107,13 @@ func (luid LUID) fallbackSetDNSDomain(domain string) error {
key.Close()
return err
}
func (luid LUID) fallbackDisableDNSRegistration() error {
// the DNS registration setting is shared for both IPv4 and IPv6
ipif, err := luid.IPInterface(windows.AF_INET)
if err != nil {
return err
}
cmd := fmt.Sprintf(netshCmdTemplateDisableRegistration, ipif.InterfaceIndex)
return runNetsh([]string{cmd})
}

View File

@@ -10,13 +10,14 @@ import (
var ErrNoRoute = E.New("no route to internet")
type (
NetworkUpdateCallback = func() error
DefaultInterfaceUpdateCallback = func(event int) error
NetworkUpdateCallback = func()
DefaultInterfaceUpdateCallback = func(event int)
)
const (
EventInterfaceUpdate = 1
EventAndroidVPNUpdate = 2
EventNoRoute = 4
)
type NetworkUpdateMonitor interface {
@@ -24,7 +25,6 @@ type NetworkUpdateMonitor interface {
Close() error
RegisterCallback(callback NetworkUpdateCallback) *list.Element[NetworkUpdateCallback]
UnregisterCallback(element *list.Element[NetworkUpdateCallback])
E.Handler
}
type DefaultInterfaceMonitor interface {
@@ -32,6 +32,7 @@ type DefaultInterfaceMonitor interface {
Close() error
DefaultInterfaceName(destination netip.Addr) string
DefaultInterfaceIndex(destination netip.Addr) int
DefaultInterface(destination netip.Addr) (string, int)
OverrideAndroidVPN() bool
AndroidVPNEnabled() bool
RegisterCallback(callback DefaultInterfaceUpdateCallback) *list.Element[DefaultInterfaceUpdateCallback]
@@ -39,5 +40,6 @@ type DefaultInterfaceMonitor interface {
}
type DefaultInterfaceMonitorOptions struct {
OverrideAndroidVPN bool
OverrideAndroidVPN bool
UnderNetworkExtension bool
}

View File

@@ -1,16 +1,15 @@
package tun
import (
"context"
"net"
"net/netip"
"os"
"sync"
"syscall"
"time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/common/x/list"
"golang.org/x/net/route"
@@ -18,61 +17,78 @@ import (
)
type networkUpdateMonitor struct {
errorHandler E.Handler
access sync.Mutex
callbacks list.List[NetworkUpdateCallback]
routeSocket *os.File
access sync.Mutex
callbacks list.List[NetworkUpdateCallback]
routeSocketFile *os.File
closeOnce sync.Once
done chan struct{}
logger logger.Logger
}
func NewNetworkUpdateMonitor(errorHandler E.Handler) (NetworkUpdateMonitor, error) {
func NewNetworkUpdateMonitor(logger logger.Logger) (NetworkUpdateMonitor, error) {
return &networkUpdateMonitor{
errorHandler: errorHandler,
logger: logger,
done: make(chan struct{}),
}, nil
}
func (m *networkUpdateMonitor) Start() error {
routeSocket, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, 0)
if err != nil {
return err
}
err = unix.SetNonblock(routeSocket, true)
if err != nil {
return err
}
m.routeSocket = os.NewFile(uintptr(routeSocket), "route")
go m.loopUpdate()
return nil
}
func (m *networkUpdateMonitor) loopUpdate() {
rawConn, err := m.routeSocket.SyscallConn()
for {
select {
case <-m.done:
return
case <-time.After(time.Second):
}
err := m.loopUpdate0()
if err != nil {
m.logger.Error("listen network update: ", err)
return
}
}
}
func (m *networkUpdateMonitor) loopUpdate0() error {
routeSocket, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, 0)
if err != nil {
return err
}
routeSocketFile := os.NewFile(uintptr(routeSocket), "route")
m.routeSocketFile = routeSocketFile
m.loopUpdate1(routeSocketFile)
return nil
}
func (m *networkUpdateMonitor) loopUpdate1(routeSocketFile *os.File) {
defer routeSocketFile.Close()
buffer := buf.NewPacket()
defer buffer.Release()
n, err := routeSocketFile.Read(buffer.FreeBytes())
if err != nil {
m.errorHandler.NewError(context.Background(), E.Cause(err, "create raw route connection"))
return
}
for {
var innerErr error
err = rawConn.Read(func(fd uintptr) (done bool) {
var msg [2048]byte
_, innerErr = unix.Read(int(fd), msg[:])
return innerErr != unix.EWOULDBLOCK
})
if innerErr != nil {
err = innerErr
}
if err != nil {
break
}
m.emit()
buffer.Truncate(n)
messages, err := route.ParseRIB(route.RIBTypeRoute, buffer.Bytes())
if err != nil {
return
}
if err != syscall.EAGAIN {
m.errorHandler.NewError(context.Background(), E.Cause(err, "read route message"))
for _, message := range messages {
if _, isRouteMessage := message.(*route.RouteMessage); isRouteMessage {
m.emit()
return
}
}
}
func (m *networkUpdateMonitor) Close() error {
return common.Close(common.PtrOrNil(m.routeSocket))
m.closeOnce.Do(func() {
close(m.done)
})
return nil
}
func (m *defaultInterfaceMonitor) checkUpdate() error {
@@ -116,17 +132,22 @@ func (m *defaultInterfaceMonitor) checkUpdate() error {
continue
}
if routeMessage.Flags&unix.RTF_IFSCOPE != 0 {
continue
// continue
}
defaultInterface = routeInterface
break
}
if defaultInterface == nil {
defaultInterface, err = getDefaultInterfaceBySocket()
if err != nil {
return err
if m.options.UnderNetworkExtension {
defaultInterface, err = getDefaultInterfaceBySocket()
if err != nil {
return err
}
}
}
if defaultInterface == nil {
return ErrNoRoute
}
oldInterface := m.defaultInterfaceName
oldIndex := m.defaultInterfaceIndex
m.defaultInterfaceIndex = defaultInterface.Index

View File

@@ -2,30 +2,55 @@ package tun
import (
"os"
"runtime"
"sync"
"github.com/sagernet/netlink"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/common/x/list"
"golang.org/x/sys/unix"
)
type networkUpdateMonitor struct {
routeUpdate chan netlink.RouteUpdate
linkUpdate chan netlink.LinkUpdate
close chan struct{}
errorHandler E.Handler
routeUpdate chan netlink.RouteUpdate
linkUpdate chan netlink.LinkUpdate
close chan struct{}
access sync.Mutex
callbacks list.List[NetworkUpdateCallback]
logger logger.Logger
}
func NewNetworkUpdateMonitor(errorHandler E.Handler) (NetworkUpdateMonitor, error) {
return &networkUpdateMonitor{
routeUpdate: make(chan netlink.RouteUpdate, 2),
linkUpdate: make(chan netlink.LinkUpdate, 2),
close: make(chan struct{}),
errorHandler: errorHandler,
}, nil
var ErrNetlinkBanned = E.New(
"netlink socket in Android is banned by Google, " +
"use the root or system (ADB) user to run sing-box, " +
"or switch to the sing-box Adnroid graphical interface client",
)
func NewNetworkUpdateMonitor(logger logger.Logger) (NetworkUpdateMonitor, error) {
monitor := &networkUpdateMonitor{
routeUpdate: make(chan netlink.RouteUpdate, 2),
linkUpdate: make(chan netlink.LinkUpdate, 2),
close: make(chan struct{}),
logger: logger,
}
// check is netlink banned by google
if runtime.GOOS == "android" {
netlinkSocket, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_DGRAM, unix.NETLINK_ROUTE)
if err != nil {
return nil, ErrNetlinkBanned
}
err = unix.Bind(netlinkSocket, &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
})
unix.Close(netlinkSocket)
if err != nil {
return nil, ErrNetlinkBanned
}
}
return monitor, nil
}
func (m *networkUpdateMonitor) Start() error {

View File

@@ -4,7 +4,6 @@ package tun
import (
"github.com/sagernet/netlink"
E "github.com/sagernet/sing/common/exceptions"
"golang.org/x/sys/unix"
)
@@ -37,5 +36,5 @@ func (m *defaultInterfaceMonitor) checkUpdate() error {
m.emit(EventInterfaceUpdate)
return nil
}
return E.New("no route to internet")
return ErrNoRoute
}

View File

@@ -5,13 +5,13 @@ package tun
import (
"os"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
)
func NewNetworkUpdateMonitor(errorHandler E.Handler) (NetworkUpdateMonitor, error) {
func NewNetworkUpdateMonitor(logger logger.Logger) (NetworkUpdateMonitor, error) {
return nil, os.ErrInvalid
}
func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor, options DefaultInterfaceMonitorOptions) (DefaultInterfaceMonitor, error) {
func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor, logger logger.Logger, options DefaultInterfaceMonitorOptions) (DefaultInterfaceMonitor, error) {
return nil, os.ErrInvalid
}

View File

@@ -3,14 +3,14 @@
package tun
import (
"context"
"errors"
"net"
"net/netip"
"sync"
"time"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/x/list"
)
@@ -32,17 +32,10 @@ func (m *networkUpdateMonitor) emit() {
callbacks := m.callbacks.Array()
m.access.Unlock()
for _, callback := range callbacks {
err := callback()
if err != nil {
m.NewError(context.Background(), err)
}
callback()
}
}
func (m *networkUpdateMonitor) NewError(ctx context.Context, err error) {
m.errorHandler.NewError(ctx, err)
}
type defaultInterfaceMonitor struct {
options DefaultInterfaceMonitorOptions
networkAddresses []networkAddress
@@ -53,6 +46,7 @@ type defaultInterfaceMonitor struct {
element *list.Element[NetworkUpdateCallback]
access sync.Mutex
callbacks list.List[DefaultInterfaceUpdateCallback]
logger logger.Logger
}
type networkAddress struct {
@@ -61,30 +55,33 @@ type networkAddress struct {
addresses []netip.Prefix
}
func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor, options DefaultInterfaceMonitorOptions) (DefaultInterfaceMonitor, error) {
func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor, logger logger.Logger, options DefaultInterfaceMonitorOptions) (DefaultInterfaceMonitor, error) {
return &defaultInterfaceMonitor{
options: options,
networkMonitor: networkMonitor,
defaultInterfaceIndex: -1,
logger: logger,
}, nil
}
func (m *defaultInterfaceMonitor) Start() error {
err := m.checkUpdate()
if err != nil {
m.networkMonitor.NewError(context.Background(), err)
}
_ = m.checkUpdate()
m.element = m.networkMonitor.RegisterCallback(m.delayCheckUpdate)
return nil
}
func (m *defaultInterfaceMonitor) delayCheckUpdate() error {
func (m *defaultInterfaceMonitor) delayCheckUpdate() {
time.Sleep(time.Second)
err := m.updateInterfaces()
if err != nil {
m.networkMonitor.NewError(context.Background(), E.Cause(err, "update interfaces"))
m.logger.Error("update interfaces: ", err)
}
err = m.checkUpdate()
if errors.Is(err, ErrNoRoute) {
m.defaultInterfaceName = ""
m.defaultInterfaceIndex = -1
m.emit(EventNoRoute)
}
return m.checkUpdate()
}
func (m *defaultInterfaceMonitor) updateInterfaces() error {
@@ -150,6 +147,20 @@ func (m *defaultInterfaceMonitor) DefaultInterfaceIndex(destination netip.Addr)
return m.defaultInterfaceIndex
}
func (m *defaultInterfaceMonitor) DefaultInterface(destination netip.Addr) (string, int) {
for _, address := range m.networkAddresses {
for _, prefix := range address.addresses {
if prefix.Contains(destination) {
return address.interfaceName, address.interfaceIndex
}
}
}
if m.defaultInterfaceIndex == -1 {
m.checkUpdate()
}
return m.defaultInterfaceName, m.defaultInterfaceIndex
}
func (m *defaultInterfaceMonitor) OverrideAndroidVPN() bool {
return m.options.OverrideAndroidVPN
}
@@ -175,9 +186,6 @@ func (m *defaultInterfaceMonitor) emit(event int) {
callbacks := m.callbacks.Array()
m.access.Unlock()
for _, callback := range callbacks {
err := callback(event)
if err != nil {
m.networkMonitor.NewError(context.Background(), err)
}
callback(event)
}
}

View File

@@ -5,6 +5,7 @@ import (
"github.com/sagernet/sing-tun/internal/winipcfg"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/common/x/list"
"golang.org/x/sys/windows"
@@ -17,11 +18,12 @@ type networkUpdateMonitor struct {
access sync.Mutex
callbacks list.List[NetworkUpdateCallback]
logger logger.Logger
}
func NewNetworkUpdateMonitor(errorHandler E.Handler) (NetworkUpdateMonitor, error) {
func NewNetworkUpdateMonitor(logger logger.Logger) (NetworkUpdateMonitor, error) {
return &networkUpdateMonitor{
errorHandler: errorHandler,
logger: logger,
}, nil
}

View File

@@ -1,92 +0,0 @@
package tun
import (
"net/netip"
E "github.com/sagernet/sing/common/exceptions"
)
type ActionType = uint8
const (
ActionTypeUnknown ActionType = iota
ActionTypeReturn
ActionTypeBlock
ActionTypeDirect
)
func ParseActionType(action string) (ActionType, error) {
switch action {
case "return":
return ActionTypeReturn, nil
case "block":
return ActionTypeBlock, nil
case "direct":
return ActionTypeDirect, nil
default:
return 0, E.New("unknown action: ", action)
}
}
func ActionTypeName(actionType ActionType) (string, error) {
switch actionType {
case ActionTypeUnknown:
return "", nil
case ActionTypeReturn:
return "return", nil
case ActionTypeBlock:
return "block", nil
case ActionTypeDirect:
return "direct", nil
default:
return "", E.New("unknown action: ", actionType)
}
}
type RouteSession struct {
IPVersion uint8
Network uint8
Source netip.AddrPort
Destination netip.AddrPort
}
type RouteContext interface {
WritePacket(packet []byte) error
}
type Router interface {
RouteConnection(session RouteSession, context RouteContext) RouteAction
}
type RouteAction interface {
ActionType() ActionType
Timeout() bool
}
type ActionReturn struct{}
func (r *ActionReturn) ActionType() ActionType {
return ActionTypeReturn
}
func (r *ActionReturn) Timeout() bool {
return false
}
type ActionBlock struct{}
func (r *ActionBlock) ActionType() ActionType {
return ActionTypeBlock
}
func (r *ActionBlock) Timeout() bool {
return false
}
type ActionDirect struct {
DirectDestination
}
func (r *ActionDirect) ActionType() ActionType {
return ActionTypeDirect
}

View File

@@ -1,16 +0,0 @@
//go:build with_gvisor
package tun
import (
"github.com/sagernet/sing/common/buf"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
type DirectDestination interface {
WritePacket(buffer *buf.Buffer) error
WritePacketBuffer(buffer *stack.PacketBuffer) error
Close() error
Timeout() bool
}

View File

@@ -1,32 +0,0 @@
package tun
import (
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/cache"
)
type RouteMapping struct {
status *cache.LruCache[RouteSession, RouteAction]
}
func NewRouteMapping(maxAge int64) *RouteMapping {
return &RouteMapping{
status: cache.New(
cache.WithAge[RouteSession, RouteAction](maxAge),
cache.WithUpdateAgeOnGet[RouteSession, RouteAction](),
cache.WithEvict[RouteSession, RouteAction](func(key RouteSession, conn RouteAction) {
common.Close(conn)
}),
),
}
}
func (m *RouteMapping) Lookup(session RouteSession, constructor func() RouteAction) RouteAction {
action, _ := m.status.LoadOrStore(session, constructor)
if action.Timeout() {
common.Close(action)
action = constructor()
m.status.Store(session, action)
}
return action
}

View File

@@ -1,119 +0,0 @@
package tun
import (
"net/netip"
"sync"
"github.com/sagernet/sing-tun/internal/clashtcpip"
)
type NatMapping struct {
access sync.RWMutex
sessions map[RouteSession]RouteContext
ipRewrite bool
}
func NewNatMapping(ipRewrite bool) *NatMapping {
return &NatMapping{
sessions: make(map[RouteSession]RouteContext),
ipRewrite: ipRewrite,
}
}
func (m *NatMapping) CreateSession(session RouteSession, context RouteContext) {
if m.ipRewrite {
session.Source = netip.AddrPort{}
}
m.access.Lock()
m.sessions[session] = context
m.access.Unlock()
}
func (m *NatMapping) DeleteSession(session RouteSession) {
if m.ipRewrite {
session.Source = netip.AddrPort{}
}
m.access.Lock()
delete(m.sessions, session)
m.access.Unlock()
}
func (m *NatMapping) WritePacket(packet []byte) (bool, error) {
var routeSession RouteSession
var ipHdr clashtcpip.IP
switch ipVersion := packet[0] >> 4; ipVersion {
case 4:
routeSession.IPVersion = 4
ipHdr = clashtcpip.IPv4Packet(packet)
case 6:
routeSession.IPVersion = 6
ipHdr = clashtcpip.IPv6Packet(packet)
default:
return false, nil
}
routeSession.Network = ipHdr.Protocol()
switch routeSession.Network {
case clashtcpip.TCP:
tcpHdr := clashtcpip.TCPPacket(ipHdr.Payload())
routeSession.Destination = netip.AddrPortFrom(ipHdr.SourceIP(), tcpHdr.SourcePort())
if !m.ipRewrite {
routeSession.Source = netip.AddrPortFrom(ipHdr.DestinationIP(), tcpHdr.DestinationPort())
}
case clashtcpip.UDP:
udpHdr := clashtcpip.UDPPacket(ipHdr.Payload())
routeSession.Destination = netip.AddrPortFrom(ipHdr.SourceIP(), udpHdr.SourcePort())
if !m.ipRewrite {
routeSession.Source = netip.AddrPortFrom(ipHdr.DestinationIP(), udpHdr.DestinationPort())
}
default:
routeSession.Destination = netip.AddrPortFrom(ipHdr.SourceIP(), 0)
if !m.ipRewrite {
routeSession.Source = netip.AddrPortFrom(ipHdr.DestinationIP(), 0)
}
}
m.access.RLock()
context, loaded := m.sessions[routeSession]
m.access.RUnlock()
if !loaded {
return false, nil
}
return true, context.WritePacket(packet)
}
type NatWriter struct {
inet4Address netip.Addr
inet6Address netip.Addr
}
func NewNatWriter(inet4Address netip.Addr, inet6Address netip.Addr) *NatWriter {
return &NatWriter{
inet4Address: inet4Address,
inet6Address: inet6Address,
}
}
func (w *NatWriter) RewritePacket(packet []byte) {
var ipHdr clashtcpip.IP
var bindAddr netip.Addr
switch ipVersion := packet[0] >> 4; ipVersion {
case 4:
ipHdr = clashtcpip.IPv4Packet(packet)
bindAddr = w.inet4Address
case 6:
ipHdr = clashtcpip.IPv6Packet(packet)
bindAddr = w.inet6Address
default:
return
}
ipHdr.SetSourceIP(bindAddr)
switch ipHdr.Protocol() {
case clashtcpip.TCP:
tcpHdr := clashtcpip.TCPPacket(ipHdr.Payload())
tcpHdr.ResetChecksum(ipHdr.PseudoSum())
case clashtcpip.UDP:
udpHdr := clashtcpip.UDPPacket(ipHdr.Payload())
udpHdr.ResetChecksum(ipHdr.PseudoSum())
default:
}
ipHdr.ResetChecksum()
}

View File

@@ -1,41 +0,0 @@
//go:build with_gvisor
package tun
import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
func (w *NatWriter) RewritePacketBuffer(packetBuffer *stack.PacketBuffer) {
var bindAddr tcpip.Address
if packetBuffer.NetworkProtocolNumber == header.IPv4ProtocolNumber {
bindAddr = tcpip.Address(w.inet4Address.AsSlice())
} else {
bindAddr = tcpip.Address(w.inet6Address.AsSlice())
}
var ipHdr header.Network
switch packetBuffer.NetworkProtocolNumber {
case header.IPv4ProtocolNumber:
ipHdr = header.IPv4(packetBuffer.NetworkHeader().Slice())
case header.IPv6ProtocolNumber:
ipHdr = header.IPv6(packetBuffer.NetworkHeader().Slice())
default:
return
}
oldAddr := ipHdr.SourceAddress()
if checksumHdr, needChecksum := ipHdr.(header.ChecksummableNetwork); needChecksum {
checksumHdr.SetSourceAddressWithChecksumUpdate(bindAddr)
} else {
ipHdr.SetSourceAddress(bindAddr)
}
switch packetBuffer.TransportProtocolNumber {
case header.TCPProtocolNumber:
tcpHdr := header.TCP(packetBuffer.TransportHeader().Slice())
tcpHdr.UpdateChecksumPseudoHeaderAddress(oldAddr, bindAddr, true)
case header.UDPProtocolNumber:
udpHdr := header.UDP(packetBuffer.TransportHeader().Slice())
udpHdr.UpdateChecksumPseudoHeaderAddress(oldAddr, bindAddr, true)
}
}

View File

@@ -1,11 +0,0 @@
//go:build !with_gvisor
package tun
import "github.com/sagernet/sing/common/buf"
type DirectDestination interface {
WritePacket(buffer *buf.Buffer) error
Close() error
Timeout() bool
}

View File

@@ -2,6 +2,8 @@ package tun
import (
"context"
"encoding/binary"
"net"
"net/netip"
"github.com/sagernet/sing/common/control"
@@ -23,7 +25,6 @@ type StackOptions struct {
Inet6Address []netip.Prefix
EndpointIndependentNat bool
UDPTimeout int64
Router Router
Handler Handler
Logger logger.Logger
ForwarderBindInterface bool
@@ -36,9 +37,15 @@ func NewStack(
) (Stack, error) {
switch stack {
case "":
return NewSystem(options)
if WithGVisor {
return NewMixed(options)
} else {
return NewSystem(options)
}
case "gvisor":
return NewGVisor(options)
case "mixed":
return NewMixed(options)
case "system":
return NewSystem(options)
case "lwip":
@@ -47,3 +54,13 @@ func NewStack(
return nil, E.New("unknown stack: ", stack)
}
}
func BroadcastAddr(inet4Address []netip.Prefix) netip.Addr {
if len(inet4Address) == 0 {
return netip.Addr{}
}
prefix := inet4Address[0]
var broadcastAddr [4]byte
binary.BigEndian.PutUint32(broadcastAddr[:], binary.BigEndian.Uint32(prefix.Masked().Addr().AsSlice())|^binary.BigEndian.Uint32(net.CIDRMask(prefix.Bits(), 32)))
return netip.AddrFrom4(broadcastAddr)
}

View File

@@ -4,26 +4,24 @@ package tun
import (
"context"
"net"
"syscall"
"net/netip"
"time"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
"github.com/sagernet/gvisor/pkg/waiter"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/canceler"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
const WithGVisor = true
@@ -36,12 +34,11 @@ type GVisor struct {
tunMtu uint32
endpointIndependentNat bool
udpTimeout int64
router Router
broadcastAddr netip.Addr
handler Handler
logger logger.Logger
stack *stack.Stack
endpoint stack.LinkEndpoint
routeMapping *RouteMapping
}
type GVisorTun interface {
@@ -63,13 +60,10 @@ func NewGVisor(
tunMtu: options.MTU,
endpointIndependentNat: options.EndpointIndependentNat,
udpTimeout: options.UDPTimeout,
router: options.Router,
broadcastAddr: BroadcastAddr(options.Inet4Address),
handler: options.Handler,
logger: options.Logger,
}
if gStack.router != nil {
gStack.routeMapping = NewRouteMapping(options.UDPTimeout)
}
return gStack, nil
}
@@ -78,44 +72,11 @@ func (t *GVisor) Start() error {
if err != nil {
return err
}
ipStack := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,
ipv6.NewProtocol,
},
TransportProtocols: []stack.TransportProtocolFactory{
tcp.NewProtocol,
udp.NewProtocol,
icmp.NewProtocol4,
icmp.NewProtocol6,
},
})
tErr := ipStack.CreateNIC(defaultNIC, linkEndpoint)
if tErr != nil {
return E.New("create nic: ", wrapStackError(tErr))
linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, t.tun.CreateVectorisedWriter()}
ipStack, err := newGVisorStack(linkEndpoint)
if err != nil {
return err
}
ipStack.SetRouteTable([]tcpip.Route{
{Destination: header.IPv4EmptySubnet, NIC: defaultNIC},
{Destination: header.IPv6EmptySubnet, NIC: defaultNIC},
})
ipStack.SetSpoofing(defaultNIC, true)
ipStack.SetPromiscuousMode(defaultNIC, true)
bufSize := 20 * 1024
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPReceiveBufferSizeRangeOption{
Min: 1,
Default: bufSize,
Max: bufSize,
})
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPSendBufferSizeRangeOption{
Min: 1,
Default: bufSize,
Max: bufSize,
})
sOpt := tcpip.TCPSACKEnabled(true)
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
mOpt := tcpip.TCPModerateReceiveBufferOption(true)
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &mOpt)
tcpForwarder := tcp.NewForwarder(ipStack, 0, 1024, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue
handshakeCtx, cancel := context.WithCancel(context.Background())
@@ -155,44 +116,7 @@ func (t *GVisor) Start() error {
}
}()
})
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, func(id stack.TransportEndpointID, buffer *stack.PacketBuffer) bool {
if t.router != nil {
var routeSession RouteSession
routeSession.Network = syscall.IPPROTO_TCP
var ipHdr header.Network
if buffer.NetworkProtocolNumber == header.IPv4ProtocolNumber {
routeSession.IPVersion = 4
ipHdr = header.IPv4(buffer.NetworkHeader().Slice())
} else {
routeSession.IPVersion = 6
ipHdr = header.IPv6(buffer.NetworkHeader().Slice())
}
tcpHdr := header.TCP(buffer.TransportHeader().Slice())
routeSession.Source = M.AddrPortFrom(net.IP(ipHdr.SourceAddress()), tcpHdr.SourcePort())
routeSession.Destination = M.AddrPortFrom(net.IP(ipHdr.DestinationAddress()), tcpHdr.DestinationPort())
action := t.routeMapping.Lookup(routeSession, func() RouteAction {
if routeSession.IPVersion == 4 {
return t.router.RouteConnection(routeSession, &systemTCPDirectPacketWriter4{t.tun, routeSession.Source})
} else {
return t.router.RouteConnection(routeSession, &systemTCPDirectPacketWriter6{t.tun, routeSession.Source})
}
})
switch actionType := action.(type) {
case *ActionBlock:
// TODO: send icmp unreachable
return true
case *ActionDirect:
buffer.IncRef()
err = actionType.WritePacketBuffer(buffer)
if err != nil {
t.logger.Trace("route gvisor tcp packet: ", err)
}
return true
}
}
return tcpForwarder.HandlePacket(id, buffer)
})
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
if !t.endpointIndependentNat {
udpForwarder := udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) {
var wq waiter.Queue
@@ -207,54 +131,19 @@ func (t *GVisor) Start() error {
endpoint.Abort()
return
}
gConn := &gUDPConn{UDPConn: udpConn}
go func() {
var metadata M.Metadata
metadata.Source = M.SocksaddrFromNet(lAddr)
metadata.Destination = M.SocksaddrFromNet(rAddr)
ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewPacketConn(&bufio.UnbindPacketConn{ExtendedConn: bufio.NewExtendedConn(&gUDPConn{udpConn}), Addr: M.SocksaddrFromNet(rAddr)}), time.Duration(t.udpTimeout)*time.Second)
ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewPacketConn(&bufio.UnbindPacketConn{ExtendedConn: bufio.NewExtendedConn(gConn), Addr: M.SocksaddrFromNet(rAddr)}), time.Duration(t.udpTimeout)*time.Second)
hErr := t.handler.NewPacketConnection(ctx, conn, metadata)
if hErr != nil {
endpoint.Abort()
}
}()
})
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, func(id stack.TransportEndpointID, buffer *stack.PacketBuffer) bool {
if t.router != nil {
var routeSession RouteSession
routeSession.Network = syscall.IPPROTO_UDP
var ipHdr header.Network
if buffer.NetworkProtocolNumber == header.IPv4ProtocolNumber {
routeSession.IPVersion = 4
ipHdr = header.IPv4(buffer.NetworkHeader().Slice())
} else {
routeSession.IPVersion = 6
ipHdr = header.IPv6(buffer.NetworkHeader().Slice())
}
udpHdr := header.UDP(buffer.TransportHeader().Slice())
routeSession.Source = M.AddrPortFrom(net.IP(ipHdr.SourceAddress()), udpHdr.SourcePort())
routeSession.Destination = M.AddrPortFrom(net.IP(ipHdr.DestinationAddress()), udpHdr.DestinationPort())
action := t.routeMapping.Lookup(routeSession, func() RouteAction {
if routeSession.IPVersion == 4 {
return t.router.RouteConnection(routeSession, &systemUDPDirectPacketWriter4{t.tun, routeSession.Source})
} else {
return t.router.RouteConnection(routeSession, &systemUDPDirectPacketWriter6{t.tun, routeSession.Source})
}
})
switch actionType := action.(type) {
case *ActionBlock:
// TODO: send icmp unreachable
return true
case *ActionDirect:
buffer.IncRef()
err = actionType.WritePacketBuffer(buffer)
if err != nil {
t.logger.Trace("route gvisor udp packet: ", err)
}
return true
}
}
return udpForwarder.HandlePacket(id, buffer)
})
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
} else {
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket)
}
@@ -272,3 +161,60 @@ func (t *GVisor) Close() error {
}
return nil
}
func AddressFromAddr(destination netip.Addr) tcpip.Address {
if destination.Is6() {
return tcpip.AddrFrom16(destination.As16())
} else {
return tcpip.AddrFrom4(destination.As4())
}
}
func AddrFromAddress(address tcpip.Address) netip.Addr {
if address.Len() == 16 {
return netip.AddrFrom16(address.As16())
} else {
return netip.AddrFrom4(address.As4())
}
}
func newGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
ipStack := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,
ipv6.NewProtocol,
},
TransportProtocols: []stack.TransportProtocolFactory{
tcp.NewProtocol,
udp.NewProtocol,
icmp.NewProtocol4,
icmp.NewProtocol6,
},
})
tErr := ipStack.CreateNIC(defaultNIC, ep)
if tErr != nil {
return nil, E.New("create nic: ", wrapStackError(tErr))
}
ipStack.SetRouteTable([]tcpip.Route{
{Destination: header.IPv4EmptySubnet, NIC: defaultNIC},
{Destination: header.IPv6EmptySubnet, NIC: defaultNIC},
})
ipStack.SetSpoofing(defaultNIC, true)
ipStack.SetPromiscuousMode(defaultNIC, true)
bufSize := 20 * 1024
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPReceiveBufferSizeRangeOption{
Min: 1,
Default: bufSize,
Max: bufSize,
})
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPSendBufferSizeRangeOption{
Min: 1,
Default: bufSize,
Max: bufSize,
})
sOpt := tcpip.TCPSACKEnabled(true)
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
mOpt := tcpip.TCPModerateReceiveBufferOption(true)
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &mOpt)
return ipStack, nil
}

View File

@@ -5,10 +5,9 @@ package tun
import (
"net"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
E "github.com/sagernet/sing/common/exceptions"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
)
type gTCPConn struct {
@@ -28,28 +27,6 @@ func (c *gTCPConn) Write(b []byte) (n int, err error) {
return
}
type gUDPConn struct {
*gonet.UDPConn
}
func (c *gUDPConn) Read(b []byte) (n int, err error) {
n, err = c.UDPConn.Read(b)
if err == nil {
return
}
err = wrapError(err)
return
}
func (c *gUDPConn) Write(b []byte) (n int, err error) {
n, err = c.UDPConn.Write(b)
if err == nil {
return
}
err = wrapError(err)
return
}
func wrapStackError(err tcpip.Error) error {
switch err.(type) {
case *tcpip.ErrClosedForSend,

56
stack_gvisor_filter.go Normal file
View File

@@ -0,0 +1,56 @@
//go:build with_gvisor
package tun
import (
"net/netip"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/sing/common/bufio"
N "github.com/sagernet/sing/common/network"
)
var _ stack.LinkEndpoint = (*LinkEndpointFilter)(nil)
type LinkEndpointFilter struct {
stack.LinkEndpoint
BroadcastAddress netip.Addr
Writer N.VectorisedWriter
}
func (w *LinkEndpointFilter) Attach(dispatcher stack.NetworkDispatcher) {
w.LinkEndpoint.Attach(&networkDispatcherFilter{dispatcher, w.BroadcastAddress, w.Writer})
}
var _ stack.NetworkDispatcher = (*networkDispatcherFilter)(nil)
type networkDispatcherFilter struct {
stack.NetworkDispatcher
broadcastAddress netip.Addr
writer N.VectorisedWriter
}
func (w *networkDispatcherFilter) DeliverNetworkPacket(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBufferPtr) {
var network header.Network
if protocol == header.IPv4ProtocolNumber {
if headerPackets, loaded := pkt.Data().PullUp(header.IPv4MinimumSize); loaded {
network = header.IPv4(headerPackets)
}
} else {
if headerPackets, loaded := pkt.Data().PullUp(header.IPv6MinimumSize); loaded {
network = header.IPv6(headerPackets)
}
}
if network == nil {
w.NetworkDispatcher.DeliverNetworkPacket(protocol, pkt)
return
}
destination := AddrFromAddress(network.DestinationAddress())
if destination == w.broadcastAddress || !destination.IsGlobalUnicast() {
_, _ = bufio.WriteVectorised(w.writer, pkt.AsSlices())
return
}
w.NetworkDispatcher.DeliverNetworkPacket(protocol, pkt)
}

View File

@@ -5,7 +5,7 @@ package tun
import (
"time"
gLog "gvisor.dev/gvisor/pkg/log"
gLog "github.com/sagernet/gvisor/pkg/log"
)
func init() {

View File

@@ -13,3 +13,9 @@ func NewGVisor(
) (Stack, error) {
return nil, ErrGVisorNotIncluded
}
func NewMixed(
options StackOptions,
) (Stack, error) {
return nil, ErrGVisorNotIncluded
}

221
stack_gvisor_udp.go Normal file
View File

@@ -0,0 +1,221 @@
//go:build with_gvisor
package tun
import (
"context"
"errors"
"math"
"net/netip"
"os"
"sync"
"syscall"
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
"github.com/sagernet/gvisor/pkg/tcpip/checksum"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/udpnat"
)
type UDPForwarder struct {
ctx context.Context
stack *stack.Stack
udpNat *udpnat.Service[netip.AddrPort]
// cache
cacheProto tcpip.NetworkProtocolNumber
cacheID stack.TransportEndpointID
}
func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, udpTimeout int64) *UDPForwarder {
return &UDPForwarder{
ctx: ctx,
stack: stack,
udpNat: udpnat.New[netip.AddrPort](udpTimeout, handler),
}
}
func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
var upstreamMetadata M.Metadata
upstreamMetadata.Source = M.SocksaddrFrom(AddrFromAddress(id.RemoteAddress), id.RemotePort)
upstreamMetadata.Destination = M.SocksaddrFrom(AddrFromAddress(id.LocalAddress), id.LocalPort)
if upstreamMetadata.Source.IsIPv4() {
f.cacheProto = header.IPv4ProtocolNumber
} else {
f.cacheProto = header.IPv6ProtocolNumber
}
gBuffer := pkt.Data().ToBuffer()
sBuffer := buf.NewSize(int(gBuffer.Size()))
gBuffer.Apply(func(view *buffer.View) {
sBuffer.Write(view.AsSlice())
})
f.cacheID = id
f.udpNat.NewPacket(
f.ctx,
upstreamMetadata.Source.AddrPort(),
sBuffer,
upstreamMetadata,
f.newUDPConn,
)
return true
}
func (f *UDPForwarder) newUDPConn(natConn N.PacketConn) N.PacketWriter {
return &UDPBackWriter{
stack: f.stack,
source: f.cacheID.RemoteAddress,
sourcePort: f.cacheID.RemotePort,
sourceNetwork: f.cacheProto,
}
}
type UDPBackWriter struct {
access sync.Mutex
stack *stack.Stack
source tcpip.Address
sourcePort uint16
sourceNetwork tcpip.NetworkProtocolNumber
packet stack.PacketBufferPtr
}
func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Socksaddr) error {
if !destination.IsIP() {
return E.Cause(os.ErrInvalid, "invalid destination")
} else if destination.IsIPv4() && w.sourceNetwork == header.IPv6ProtocolNumber {
destination = M.SocksaddrFrom(netip.AddrFrom16(destination.Addr.As16()), destination.Port)
} else if destination.IsIPv6() && (w.sourceNetwork == header.IPv4ProtocolNumber) {
return E.New("send IPv6 packet to IPv4 connection")
}
defer packetBuffer.Release()
route, err := w.stack.FindRoute(
defaultNIC,
AddressFromAddr(destination.Addr),
w.source,
w.sourceNetwork,
false,
)
if err != nil {
return wrapStackError(err)
}
defer route.Release()
packet := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: header.UDPMinimumSize + int(route.MaxHeaderLength()),
Payload: buffer.MakeWithData(packetBuffer.Bytes()),
})
defer packet.DecRef()
packet.TransportProtocolNumber = header.UDPProtocolNumber
udpHdr := header.UDP(packet.TransportHeader().Push(header.UDPMinimumSize))
pLen := uint16(packet.Size())
udpHdr.Encode(&header.UDPFields{
SrcPort: destination.Port,
DstPort: w.sourcePort,
Length: pLen,
})
if route.RequiresTXTransportChecksum() && w.sourceNetwork == header.IPv6ProtocolNumber {
xsum := udpHdr.CalculateChecksum(checksum.Combine(
route.PseudoHeaderChecksum(header.UDPProtocolNumber, pLen),
packet.Data().Checksum(),
))
if xsum != math.MaxUint16 {
xsum = ^xsum
}
udpHdr.SetChecksum(xsum)
}
err = route.WritePacket(stack.NetworkHeaderParams{
Protocol: header.UDPProtocolNumber,
TTL: route.DefaultTTL(),
TOS: 0,
}, packet)
if err != nil {
route.Stats().UDP.PacketSendErrors.Increment()
return wrapStackError(err)
}
route.Stats().UDP.PacketsSent.Increment()
return nil
}
type gRequest struct {
stack *stack.Stack
id stack.TransportEndpointID
pkt stack.PacketBufferPtr
}
type gUDPConn struct {
*gonet.UDPConn
}
func (c *gUDPConn) Read(b []byte) (n int, err error) {
n, err = c.UDPConn.Read(b)
if err == nil {
return
}
err = wrapError(err)
return
}
func (c *gUDPConn) Write(b []byte) (n int, err error) {
n, err = c.UDPConn.Write(b)
if err == nil {
return
}
err = wrapError(err)
return
}
func (c *gUDPConn) Close() error {
return c.UDPConn.Close()
}
func gWriteUnreachable(gStack *stack.Stack, packet stack.PacketBufferPtr, err error) (retErr error) {
if errors.Is(err, syscall.ENETUNREACH) {
if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPNetUnreachable)
} else {
return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute)
}
} else if errors.Is(err, syscall.EHOSTUNREACH) {
if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPHostUnreachable)
} else {
return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPNoRoute)
}
} else if errors.Is(err, syscall.ECONNREFUSED) {
if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
return gWriteUnreachable4(gStack, packet, stack.RejectIPv4WithICMPPortUnreachable)
} else {
return gWriteUnreachable6(gStack, packet, stack.RejectIPv6WithICMPPortUnreachable)
}
}
return nil
}
func gWriteUnreachable4(gStack *stack.Stack, packet stack.PacketBufferPtr, icmpCode stack.RejectIPv4WithICMPType) error {
err := gStack.NetworkProtocolInstance(header.IPv4ProtocolNumber).(stack.RejectIPv4WithHandler).SendRejectionError(packet, icmpCode, true)
if err != nil {
return wrapStackError(err)
}
return nil
}
func gWriteUnreachable6(gStack *stack.Stack, packet stack.PacketBufferPtr, icmpCode stack.RejectIPv6WithICMPType) error {
err := gStack.NetworkProtocolInstance(header.IPv6ProtocolNumber).(stack.RejectIPv6WithHandler).SendRejectionError(packet, icmpCode, true)
if err != nil {
return wrapStackError(err)
}
return nil
}

View File

@@ -52,18 +52,13 @@ func (l *LWIP) loopIn() {
l.loopInWintun(winTun)
return
}
mtu := int(l.tunMtu) + PacketOffset
_buffer := buf.StackNewSize(mtu)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
data := buffer.FreeBytes()
buffer := make([]byte, int(l.tunMtu)+PacketOffset)
for {
n, err := l.tun.Read(data)
n, err := l.tun.Read(buffer)
if err != nil {
return
}
_, err = l.stack.Write(data[PacketOffset:n])
_, err = l.stack.Write(buffer[PacketOffset:n])
if err != nil {
if err.Error() == "stack closed" {
return

208
stack_mixed.go Normal file
View File

@@ -0,0 +1,208 @@
//go:build with_gvisor
package tun
import (
"time"
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/link/channel"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
"github.com/sagernet/gvisor/pkg/waiter"
"github.com/sagernet/sing-tun/internal/clashtcpip"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/canceler"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type Mixed struct {
*System
writer N.VectorisedWriter
endpointIndependentNat bool
stack *stack.Stack
endpoint *channel.Endpoint
}
func NewMixed(
options StackOptions,
) (Stack, error) {
system, err := NewSystem(options)
if err != nil {
return nil, err
}
return &Mixed{
System: system.(*System),
writer: options.Tun.CreateVectorisedWriter(),
endpointIndependentNat: options.EndpointIndependentNat,
}, nil
}
func (m *Mixed) Start() error {
err := m.System.start()
if err != nil {
return err
}
endpoint := channel.New(1024, m.mtu, "")
ipStack, err := newGVisorStack(endpoint)
if err != nil {
return err
}
if !m.endpointIndependentNat {
udpForwarder := udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) {
var wq waiter.Queue
endpoint, err := request.CreateEndpoint(&wq)
if err != nil {
return
}
udpConn := gonet.NewUDPConn(ipStack, &wq, endpoint)
lAddr := udpConn.RemoteAddr()
rAddr := udpConn.LocalAddr()
if lAddr == nil || rAddr == nil {
endpoint.Abort()
return
}
gConn := &gUDPConn{UDPConn: udpConn}
go func() {
var metadata M.Metadata
metadata.Source = M.SocksaddrFromNet(lAddr)
metadata.Destination = M.SocksaddrFromNet(rAddr)
ctx, conn := canceler.NewPacketConn(m.ctx, bufio.NewPacketConn(&bufio.UnbindPacketConn{ExtendedConn: bufio.NewExtendedConn(gConn), Addr: M.SocksaddrFromNet(rAddr)}), time.Duration(m.udpTimeout)*time.Second)
hErr := m.handler.NewPacketConnection(ctx, conn, metadata)
if hErr != nil {
endpoint.Abort()
}
}()
})
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
} else {
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(m.ctx, ipStack, m.handler, m.udpTimeout).HandlePacket)
}
m.stack = ipStack
m.endpoint = endpoint
go m.tunLoop()
go m.packetLoop()
return nil
}
func (m *Mixed) tunLoop() {
if winTun, isWinTun := m.tun.(WinTun); isWinTun {
m.wintunLoop(winTun)
return
}
packetBuffer := make([]byte, m.mtu+PacketOffset)
for {
n, err := m.tun.Read(packetBuffer)
if err != nil {
return
}
if n < clashtcpip.IPv4PacketMinLength {
continue
}
packet := packetBuffer[PacketOffset:n]
switch ipVersion := packet[0] >> 4; ipVersion {
case 4:
err = m.processIPv4(packet)
case 6:
err = m.processIPv6(packet)
default:
err = E.New("ip: unknown version: ", ipVersion)
}
if err != nil {
m.logger.Trace(err)
}
}
}
func (m *Mixed) wintunLoop(winTun WinTun) {
for {
packet, release, err := winTun.ReadPacket()
if err != nil {
return
}
if len(packet) < clashtcpip.IPv4PacketMinLength {
release()
continue
}
switch ipVersion := packet[0] >> 4; ipVersion {
case 4:
err = m.processIPv4(packet)
case 6:
err = m.processIPv6(packet)
default:
err = E.New("ip: unknown version: ", ipVersion)
}
if err != nil {
m.logger.Trace(err)
}
release()
}
}
func (m *Mixed) processIPv4(packet clashtcpip.IPv4Packet) error {
destination := packet.DestinationIP()
if destination == m.broadcastAddr || !destination.IsGlobalUnicast() {
return common.Error(m.tun.Write(packet))
}
switch packet.Protocol() {
case clashtcpip.TCP:
return m.processIPv4TCP(packet, packet.Payload())
case clashtcpip.UDP:
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packet),
})
m.endpoint.InjectInbound(header.IPv4ProtocolNumber, pkt)
pkt.DecRef()
return nil
case clashtcpip.ICMP:
return m.processIPv4ICMP(packet, packet.Payload())
default:
return common.Error(m.tun.Write(packet))
}
}
func (m *Mixed) processIPv6(packet clashtcpip.IPv6Packet) error {
if !packet.DestinationIP().IsGlobalUnicast() {
return common.Error(m.tun.Write(packet))
}
switch packet.Protocol() {
case clashtcpip.TCP:
return m.processIPv6TCP(packet, packet.Payload())
case clashtcpip.UDP:
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packet),
})
m.endpoint.InjectInbound(header.IPv6ProtocolNumber, pkt)
pkt.DecRef()
return nil
case clashtcpip.ICMPv6:
return m.processIPv6ICMP(packet, packet.Payload())
default:
return common.Error(m.tun.Write(packet))
}
}
func (m *Mixed) packetLoop() {
for {
packet := m.endpoint.ReadContext(m.ctx)
if packet == nil {
break
}
bufio.WriteVectorised(m.writer, packet.AsSlices())
packet.DecRef()
}
}
func (m *Mixed) Close() error {
m.endpoint.Attach(nil)
m.stack.Close()
for _, endpoint := range m.stack.CleanupEndpoints() {
endpoint.Abort()
}
return m.System.Close()
}

View File

@@ -23,7 +23,6 @@ type System struct {
tun Tun
tunName string
mtu uint32
router Router
handler Handler
logger logger.Logger
inet4Prefixes []netip.Prefix
@@ -32,6 +31,7 @@ type System struct {
inet4Address netip.Addr
inet6ServerAddress netip.Addr
inet6Address netip.Addr
broadcastAddr netip.Addr
udpTimeout int64
tcpListener net.Listener
tcpListener6 net.Listener
@@ -39,7 +39,6 @@ type System struct {
tcpPort6 uint16
tcpNat *TCPNat
udpNat *udpnat.Service[netip.AddrPort]
routeMapping *RouteMapping
bindInterface bool
interfaceFinder control.InterfaceFinder
}
@@ -58,17 +57,14 @@ func NewSystem(options StackOptions) (Stack, error) {
tunName: options.Name,
mtu: options.MTU,
udpTimeout: options.UDPTimeout,
router: options.Router,
handler: options.Handler,
logger: options.Logger,
inet4Prefixes: options.Inet4Address,
inet6Prefixes: options.Inet6Address,
broadcastAddr: BroadcastAddr(options.Inet4Address),
bindInterface: options.ForwarderBindInterface,
interfaceFinder: options.InterfaceFinder,
}
if stack.router != nil {
stack.routeMapping = NewRouteMapping(options.UDPTimeout)
}
if len(options.Inet4Address) > 0 {
if options.Inet4Address[0].Bits() == 32 {
return nil, E.New("need one more IPv4 address in first prefix for system stack")
@@ -97,6 +93,19 @@ func (s *System) Close() error {
}
func (s *System) Start() error {
err := s.start()
if err != nil {
return err
}
go s.tunLoop()
return nil
}
func (s *System) start() error {
err := fixWindowsFirewall()
if err != nil {
return E.Cause(err, "fix windows firewall for system stack")
}
var listener net.ListenConfig
if s.bindInterface {
listener.Control = control.Append(listener.Control, func(network, address string, conn syscall.RawConn) error {
@@ -127,7 +136,6 @@ func (s *System) Start() error {
}
s.tcpNat = NewNat(s.ctx, time.Second*time.Duration(s.udpTimeout))
s.udpNat = udpnat.New[netip.AddrPort](s.udpTimeout, s.handler)
go s.tunLoop()
return nil
}
@@ -136,20 +144,16 @@ func (s *System) tunLoop() {
s.wintunLoop(winTun)
return
}
_packetBuffer := buf.StackNewSize(int(s.mtu))
defer common.KeepAlive(_packetBuffer)
packetBuffer := common.Dup(_packetBuffer)
defer packetBuffer.Release()
packetSlice := packetBuffer.Slice()
packetBuffer := make([]byte, s.mtu+PacketOffset)
for {
n, err := s.tun.Read(packetSlice)
n, err := s.tun.Read(packetBuffer)
if err != nil {
return
}
if n < clashtcpip.IPv4PacketMinLength {
continue
}
packet := packetSlice[PacketOffset:n]
packet := packetBuffer[PacketOffset:n]
switch ipVersion := packet[0] >> 4; ipVersion {
case 4:
err = s.processIPv4(packet)
@@ -231,6 +235,10 @@ func (s *System) acceptLoop(listener net.Listener) {
}
func (s *System) processIPv4(packet clashtcpip.IPv4Packet) error {
destination := packet.DestinationIP()
if destination == s.broadcastAddr || !destination.IsGlobalUnicast() {
return common.Error(s.tun.Write(packet))
}
switch packet.Protocol() {
case clashtcpip.TCP:
return s.processIPv4TCP(packet, packet.Payload())
@@ -244,6 +252,9 @@ func (s *System) processIPv4(packet clashtcpip.IPv4Packet) error {
}
func (s *System) processIPv6(packet clashtcpip.IPv6Packet) error {
if !packet.DestinationIP().IsGlobalUnicast() {
return common.Error(s.tun.Write(packet))
}
switch packet.Protocol() {
case clashtcpip.TCP:
return s.processIPv6TCP(packet, packet.Payload())
@@ -271,21 +282,6 @@ func (s *System) processIPv4TCP(packet clashtcpip.IPv4Packet, header clashtcpip.
packet.SetDestinationIP(session.Source.Addr())
header.SetDestinationPort(session.Source.Port())
} else {
if s.router != nil {
session := RouteSession{4, syscall.IPPROTO_TCP, source, destination}
action := s.routeMapping.Lookup(session, func() RouteAction {
return s.router.RouteConnection(session, &systemTCPDirectPacketWriter4{s.tun, source})
})
switch actionType := action.(type) {
case *ActionBlock:
// TODO: send ICMP unreachable
return nil
case *ActionDirect:
return E.Append(nil, actionType.WritePacket(buf.As(packet).ToOwned()), func(err error) error {
return E.Cause(err, "route ipv4 tcp packet")
})
}
}
natPort := s.tcpNat.Lookup(source, destination)
packet.SetSourceIP(s.inet4Address)
header.SetSourcePort(natPort)
@@ -312,21 +308,6 @@ func (s *System) processIPv6TCP(packet clashtcpip.IPv6Packet, header clashtcpip.
packet.SetDestinationIP(session.Source.Addr())
header.SetDestinationPort(session.Source.Port())
} else {
if s.router != nil {
session := RouteSession{6, syscall.IPPROTO_TCP, source, destination}
action := s.routeMapping.Lookup(session, func() RouteAction {
return s.router.RouteConnection(session, &systemTCPDirectPacketWriter6{s.tun, source})
})
switch actionType := action.(type) {
case *ActionBlock:
// TODO: send RST
return nil
case *ActionDirect:
return E.Append(nil, actionType.WritePacket(buf.As(packet).ToOwned()), func(err error) error {
return E.Cause(err, "route ipv6 tcp packet")
})
}
}
natPort := s.tcpNat.Lookup(source, destination)
packet.SetSourceIP(s.inet6Address)
header.SetSourcePort(natPort)
@@ -345,26 +326,14 @@ func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip.
if packet.FragmentOffset() != 0 {
return E.New("ipv4: udp: fragment dropped")
}
if !header.Valid() {
return E.New("ipv4: udp: invalid packet")
}
source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort())
destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort())
if !destination.Addr().IsGlobalUnicast() {
return common.Error(s.tun.Write(packet))
}
if s.router != nil {
routeSession := RouteSession{4, syscall.IPPROTO_UDP, source, destination}
action := s.routeMapping.Lookup(routeSession, func() RouteAction {
return s.router.RouteConnection(routeSession, &systemUDPDirectPacketWriter4{s.tun, source})
})
switch actionType := action.(type) {
case *ActionBlock:
// TODO: send icmp unreachable
return nil
case *ActionDirect:
return E.Append(nil, actionType.WritePacket(buf.As(packet).ToOwned()), func(err error) error {
return E.Cause(err, "route ipv4 udp packet")
})
}
}
data := buf.As(header.Payload())
if data.Len() == 0 {
return nil
@@ -383,26 +352,14 @@ func (s *System) processIPv4UDP(packet clashtcpip.IPv4Packet, header clashtcpip.
}
func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip.UDPPacket) error {
if !header.Valid() {
return E.New("ipv6: udp: invalid packet")
}
source := netip.AddrPortFrom(packet.SourceIP(), header.SourcePort())
destination := netip.AddrPortFrom(packet.DestinationIP(), header.DestinationPort())
if !destination.Addr().IsGlobalUnicast() {
return common.Error(s.tun.Write(packet))
}
if s.router != nil {
routeSession := RouteSession{6, syscall.IPPROTO_UDP, source, destination}
action := s.routeMapping.Lookup(routeSession, func() RouteAction {
return s.router.RouteConnection(routeSession, &systemUDPDirectPacketWriter6{s.tun, source})
})
switch actionType := action.(type) {
case *ActionBlock:
// TODO: send icmp unreachable
return nil
case *ActionDirect:
return E.Append(nil, actionType.WritePacket(buf.As(packet).ToOwned()), func(err error) error {
return E.Cause(err, "route ipv6 udp packet")
})
}
}
data := buf.As(header.Payload())
if data.Len() == 0 {
return nil
@@ -421,21 +378,6 @@ func (s *System) processIPv6UDP(packet clashtcpip.IPv6Packet, header clashtcpip.
}
func (s *System) processIPv4ICMP(packet clashtcpip.IPv4Packet, header clashtcpip.ICMPPacket) error {
if s.router != nil {
routeSession := RouteSession{4, clashtcpip.ICMP, netip.AddrPortFrom(packet.SourceIP(), 0), netip.AddrPortFrom(packet.DestinationIP(), 0)}
action := s.routeMapping.Lookup(routeSession, func() RouteAction {
return s.router.RouteConnection(routeSession, &systemICMPDirectPacketWriter4{s.tun, packet.SourceIP()})
})
switch actionType := action.(type) {
case *ActionBlock:
// TODO: send icmp unreachable
return nil
case *ActionDirect:
return E.Append(nil, actionType.WritePacket(buf.As(packet).ToOwned()), func(err error) error {
return E.Cause(err, "route ipv4 icmp packet")
})
}
}
if header.Type() != clashtcpip.ICMPTypePingRequest || header.Code() != 0 {
return nil
}
@@ -449,21 +391,6 @@ func (s *System) processIPv4ICMP(packet clashtcpip.IPv4Packet, header clashtcpip
}
func (s *System) processIPv6ICMP(packet clashtcpip.IPv6Packet, header clashtcpip.ICMPv6Packet) error {
if s.router != nil {
routeSession := RouteSession{6, clashtcpip.ICMPv6, netip.AddrPortFrom(packet.SourceIP(), 0), netip.AddrPortFrom(packet.DestinationIP(), 0)}
action := s.routeMapping.Lookup(routeSession, func() RouteAction {
return s.router.RouteConnection(routeSession, &systemICMPDirectPacketWriter6{s.tun, packet.SourceIP()})
})
switch actionType := action.(type) {
case *ActionBlock:
// TODO: send icmp unreachable
return nil
case *ActionDirect:
return E.Append(nil, actionType.WritePacket(buf.As(packet).ToOwned()), func(err error) error {
return E.Cause(err, "route ipv6 icmp packet")
})
}
}
if header.Type() != clashtcpip.ICMPv6EchoRequest || header.Code() != 0 {
return nil
}
@@ -567,7 +494,7 @@ type systemUDPPacketWriter4 struct {
}
func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
newPacket := buf.StackNewSize(len(w.header) + buffer.Len())
newPacket := buf.NewSize(len(w.header) + buffer.Len())
defer newPacket.Release()
newPacket.Write(w.header)
newPacket.Write(buffer.Bytes())
@@ -591,7 +518,7 @@ type systemUDPPacketWriter6 struct {
}
func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
newPacket := buf.StackNewSize(len(w.header) + buffer.Len())
newPacket := buf.NewSize(len(w.header) + buffer.Len())
defer newPacket.Release()
newPacket.Write(w.header)
newPacket.Write(buffer.Bytes())

View File

@@ -0,0 +1,7 @@
//go:build !windows
package tun
func fixWindowsFirewall() error {
return nil
}

25
stack_system_windows.go Normal file
View File

@@ -0,0 +1,25 @@
package tun
import (
"os"
"path/filepath"
"github.com/sagernet/sing-tun/internal/winfw"
)
func fixWindowsFirewall() error {
absPath, err := filepath.Abs(os.Args[0])
if err != nil {
return err
}
rule := winfw.FWRule{
Name: "sing-tun (" + absPath + ")",
ApplicationName: absPath,
Enabled: true,
Protocol: winfw.NET_FW_IP_PROTOCOL_TCP,
Direction: winfw.NET_FW_RULE_DIR_IN,
Action: winfw.NET_FW_ACTION_ALLOW,
}
_, err = winfw.FirewallRuleAddAdvanced(rule)
return err
}

39
tun.go
View File

@@ -10,6 +10,7 @@ import (
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/common/logger"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ranges"
)
@@ -22,6 +23,7 @@ type Handler interface {
type Tun interface {
io.ReadWriter
CreateVectorisedWriter() N.VectorisedWriter
Close() error
}
@@ -31,22 +33,27 @@ type WinTun interface {
}
type Options struct {
Name string
Inet4Address []netip.Prefix
Inet6Address []netip.Prefix
MTU uint32
AutoRoute bool
StrictRoute bool
Inet4RouteAddress []netip.Prefix
Inet6RouteAddress []netip.Prefix
IncludeUID []ranges.Range[uint32]
ExcludeUID []ranges.Range[uint32]
IncludeAndroidUser []int
IncludePackage []string
ExcludePackage []string
InterfaceMonitor DefaultInterfaceMonitor
TableIndex int
FileDescriptor int
Name string
Inet4Address []netip.Prefix
Inet6Address []netip.Prefix
MTU uint32
AutoRoute bool
StrictRoute bool
Inet4RouteAddress []netip.Prefix
Inet6RouteAddress []netip.Prefix
Inet4RouteExcludeAddress []netip.Prefix
Inet6RouteExcludeAddress []netip.Prefix
IncludeInterface []string
ExcludeInterface []string
IncludeUID []ranges.Range[uint32]
ExcludeUID []ranges.Range[uint32]
IncludeAndroidUser []int
IncludePackage []string
ExcludePackage []string
InterfaceMonitor DefaultInterfaceMonitor
TableIndex int
FileDescriptor int
Logger logger.Logger
}
func CalculateInterfaceName(name string) (tunName string) {

View File

@@ -10,6 +10,7 @@ import (
"unsafe"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
@@ -25,8 +26,8 @@ type NativeTun struct {
tunFile *os.File
tunWriter N.VectorisedWriter
mtu uint32
inet4Address string
inet6Address string
inet4Address [4]byte
inet6Address [16]byte
}
func New(options Options) (Tun, error) {
@@ -57,10 +58,10 @@ func New(options Options) (Tun, error) {
mtu: options.MTU,
}
if len(options.Inet4Address) > 0 {
nativeTun.inet4Address = string(options.Inet4Address[0].Addr().AsSlice())
nativeTun.inet4Address = options.Inet4Address[0].Addr().As4()
}
if len(options.Inet6Address) > 0 {
nativeTun.inet6Address = string(options.Inet6Address[0].Addr().AsSlice())
nativeTun.inet6Address = options.Inet6Address[0].Addr().As16()
}
var ok bool
nativeTun.tunWriter, ok = bufio.CreateVectorisedWriter(nativeTun.tunFile)
@@ -101,6 +102,20 @@ func (t *NativeTun) Write(p []byte) (n int, err error) {
return
}
func (t *NativeTun) CreateVectorisedWriter() N.VectorisedWriter {
return t
}
func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
var packetHeader []byte
if buffers[0].Byte(0)>>4 == 4 {
packetHeader = packetHeader4[:]
} else {
packetHeader = packetHeader6[:]
}
return t.tunWriter.WriteVectorised(append([]*buf.Buffer{buf.As(packetHeader)}, buffers...))
}
func (t *NativeTun) Close() error {
flushDNSCache()
return t.tunFile.Close()
@@ -248,43 +263,16 @@ func configure(tunFd int, ifIndex int, name string, options Options) error {
}
}
if options.AutoRoute {
if len(options.Inet4Address) > 0 {
var routes []netip.Prefix
if len(options.Inet4RouteAddress) > 0 {
routes = append(options.Inet4RouteAddress, netip.PrefixFrom(options.Inet4Address[0].Addr().Next(), 32))
var routeRanges []netip.Prefix
routeRanges, err = options.BuildAutoRouteRanges(false)
for _, routeRange := range routeRanges {
if routeRange.Addr().Is4() {
err = addRoute(routeRange, options.Inet4Address[0].Addr())
} else {
routes = []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 0, 0, 0}), 8),
netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 0, 0, 0}), 7),
netip.PrefixFrom(netip.AddrFrom4([4]byte{4, 0, 0, 0}), 6),
netip.PrefixFrom(netip.AddrFrom4([4]byte{8, 0, 0, 0}), 5),
netip.PrefixFrom(netip.AddrFrom4([4]byte{16, 0, 0, 0}), 4),
netip.PrefixFrom(netip.AddrFrom4([4]byte{32, 0, 0, 0}), 3),
netip.PrefixFrom(netip.AddrFrom4([4]byte{64, 0, 0, 0}), 2),
netip.PrefixFrom(netip.AddrFrom4([4]byte{128, 0, 0, 0}), 1),
}
err = addRoute(routeRange, options.Inet6Address[0].Addr())
}
for _, subnet := range routes {
err = addRoute(subnet, options.Inet4Address[0].Addr())
if err != nil {
return E.Cause(err, "add ipv4 route ", subnet)
}
}
}
if len(options.Inet6Address) > 0 {
var routes []netip.Prefix
if len(options.Inet6RouteAddress) > 0 {
routes = append(options.Inet6RouteAddress, netip.PrefixFrom(options.Inet6Address[0].Addr().Next(), 128))
} else {
routes = []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom16([16]byte{32, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}), 3),
}
}
for _, subnet := range routes {
err = addRoute(subnet, options.Inet6Address[0].Addr())
if err != nil {
return E.Cause(err, "add ipv6 route ", subnet)
}
if err != nil {
return E.Cause(err, "add route: ", routeRange)
}
}
flushDNSCache()

View File

@@ -3,14 +3,11 @@
package tun
import (
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/sing/common/bufio"
"gvisor.dev/gvisor/pkg/bufferv2"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
var _ GVisorTun = (*NativeTun)(nil)
@@ -54,37 +51,33 @@ func (e *DarwinEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
}
func (e *DarwinEndpoint) dispatchLoop() {
_buffer := buf.StackNewSize(int(e.tun.mtu) + 4)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
data := buffer.FreeBytes()
packetBuffer := make([]byte, e.tun.mtu+4)
for {
n, err := e.tun.tunFile.Read(data)
n, err := e.tun.tunFile.Read(packetBuffer)
if err != nil {
break
}
packet := data[4:n]
packet := packetBuffer[4:n]
var networkProtocol tcpip.NetworkProtocolNumber
switch header.IPVersion(packet) {
case header.IPv4Version:
networkProtocol = header.IPv4ProtocolNumber
if header.IPv4(packet).DestinationAddress() == tcpip.Address(e.tun.inet4Address) {
e.tun.tunFile.Write(data[:n])
if header.IPv4(packet).DestinationAddress().As4() == e.tun.inet4Address {
e.tun.tunFile.Write(packetBuffer[:n])
continue
}
case header.IPv6Version:
networkProtocol = header.IPv6ProtocolNumber
if header.IPv6(packet).DestinationAddress() == tcpip.Address(e.tun.inet6Address) {
e.tun.tunFile.Write(data[:n])
if header.IPv6(packet).DestinationAddress().As16() == e.tun.inet6Address {
e.tun.tunFile.Write(packetBuffer[:n])
continue
}
default:
e.tun.tunFile.Write(data[:n])
e.tun.tunFile.Write(packetBuffer[:n])
continue
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: bufferv2.MakeWithData(data[4:n]),
Payload: buffer.MakeWithData(packetBuffer[4:n]),
IsForwardedPacket: true,
})
pkt.NetworkProtocolNumber = networkProtocol
@@ -109,7 +102,11 @@ func (e *DarwinEndpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareNone
}
func (e *DarwinEndpoint) AddHeader(buffer *stack.PacketBuffer) {
func (e *DarwinEndpoint) AddHeader(buffer stack.PacketBufferPtr) {
}
func (e *DarwinEndpoint) ParseHeader(ptr stack.PacketBufferPtr) bool {
return true
}
func (e *DarwinEndpoint) WritePackets(packetBufferList stack.PacketBufferList) (int, tcpip.Error) {

View File

@@ -12,7 +12,9 @@ import (
"github.com/sagernet/netlink"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
"github.com/sagernet/sing/common/shell"
"github.com/sagernet/sing/common/x/list"
@@ -68,6 +70,10 @@ func (t *NativeTun) Write(p []byte) (n int, err error) {
return t.tunFile.Write(p)
}
func (t *NativeTun) CreateVectorisedWriter() N.VectorisedWriter {
return bufio.NewVectorisedWriter(t.tunFile)
}
var controlPath string
func init() {
@@ -167,7 +173,7 @@ func (t *NativeTun) configure(tunLink netlink.Link) error {
return err
}
setSearchDomainForSystemdResolved(t.options.Name)
t.setSearchDomainForSystemdResolved()
if t.options.AutoRoute && runtime.GOOS == "android" {
t.interfaceCallback = t.options.InterfaceMonitor.RegisterCallback(t.routeUpdate)
@@ -182,57 +188,25 @@ func (t *NativeTun) Close() error {
return E.Errors(t.unsetRoute(), t.unsetRules(), common.Close(common.PtrOrNil(t.tunFile)))
}
func (t *NativeTun) routes(tunLink netlink.Link) []netlink.Route {
var routes []netlink.Route
if len(t.options.Inet4Address) > 0 {
if t.options.AutoRoute {
if len(t.options.Inet4RouteAddress) > 0 {
for _, addr := range t.options.Inet4RouteAddress {
routes = append(routes, netlink.Route{
Dst: &net.IPNet{
IP: addr.Addr().AsSlice(),
Mask: net.CIDRMask(addr.Bits(), 32),
},
LinkIndex: tunLink.Attrs().Index,
Table: t.options.TableIndex,
})
}
} else {
routes = append(routes, netlink.Route{
Dst: &net.IPNet{
IP: net.IPv4zero,
Mask: net.CIDRMask(0, 32),
},
LinkIndex: tunLink.Attrs().Index,
Table: t.options.TableIndex,
})
}
}
func prefixToIPNet(prefix netip.Prefix) *net.IPNet {
return &net.IPNet{
IP: prefix.Addr().AsSlice(),
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()),
}
if len(t.options.Inet6Address) > 0 {
if len(t.options.Inet6RouteAddress) > 0 {
for _, addr := range t.options.Inet6RouteAddress {
routes = append(routes, netlink.Route{
Dst: &net.IPNet{
IP: addr.Addr().AsSlice(),
Mask: net.CIDRMask(addr.Bits(), 128),
},
LinkIndex: tunLink.Attrs().Index,
Table: t.options.TableIndex,
})
}
} else {
routes = append(routes, netlink.Route{
Dst: &net.IPNet{
IP: net.IPv6zero,
Mask: net.CIDRMask(0, 128),
},
LinkIndex: tunLink.Attrs().Index,
Table: t.options.TableIndex,
})
}
}
func (t *NativeTun) routes(tunLink netlink.Link) ([]netlink.Route, error) {
routeRanges, err := t.options.BuildAutoRouteRanges(false)
if err != nil {
return nil, err
}
return routes
return common.Map(routeRanges, func(it netip.Prefix) netlink.Route {
return netlink.Route{
Dst: prefixToIPNet(it),
LinkIndex: tunLink.Attrs().Index,
Table: t.options.TableIndex,
}
}), nil
}
const (
@@ -317,6 +291,110 @@ func (t *NativeTun) rules() []*netlink.Rule {
priority6++
}
}
if len(t.options.IncludeInterface) > 0 {
matchPriority := priority + 2*len(t.options.IncludeInterface) + 1
for _, includeInterface := range t.options.IncludeInterface {
if p4 {
it = netlink.NewRule()
it.Priority = priority
it.IifName = includeInterface
it.Goto = matchPriority
it.Family = unix.AF_INET
rules = append(rules, it)
priority++
it = netlink.NewRule()
it.Priority = priority
it.OifName = includeInterface
it.Goto = matchPriority
it.Family = unix.AF_INET
rules = append(rules, it)
priority++
}
if p6 {
it = netlink.NewRule()
it.Priority = priority6
it.IifName = includeInterface
it.Goto = matchPriority
it.Family = unix.AF_INET6
rules = append(rules, it)
priority6++
it = netlink.NewRule()
it.Priority = priority6
it.OifName = includeInterface
it.Goto = matchPriority
it.Family = unix.AF_INET6
rules = append(rules, it)
priority6++
}
}
if p4 {
it = netlink.NewRule()
it.Priority = priority
it.Family = unix.AF_INET
it.Goto = nopPriority
rules = append(rules, it)
priority++
it = netlink.NewRule()
it.Priority = matchPriority
it.Family = unix.AF_INET
rules = append(rules, it)
priority++
}
if p6 {
it = netlink.NewRule()
it.Priority = priority6
it.Family = unix.AF_INET6
it.Goto = nopPriority
rules = append(rules, it)
priority6++
it = netlink.NewRule()
it.Priority = matchPriority
it.Family = unix.AF_INET6
rules = append(rules, it)
priority6++
}
} else if len(t.options.ExcludeInterface) > 0 {
for _, excludeInterface := range t.options.ExcludeInterface {
if p4 {
it = netlink.NewRule()
it.Priority = priority
it.IifName = excludeInterface
it.Goto = nopPriority
it.Family = unix.AF_INET
rules = append(rules, it)
priority++
it = netlink.NewRule()
it.Priority = priority
it.OifName = excludeInterface
it.Goto = nopPriority
it.Family = unix.AF_INET
rules = append(rules, it)
priority++
}
if p6 {
it = netlink.NewRule()
it.Priority = priority6
it.IifName = excludeInterface
it.Goto = nopPriority
it.Family = unix.AF_INET6
rules = append(rules, it)
priority6++
it = netlink.NewRule()
it.Priority = priority6
it.OifName = excludeInterface
it.Goto = nopPriority
it.Family = unix.AF_INET6
rules = append(rules, it)
priority6++
}
}
}
if runtime.GOOS == "android" && t.options.InterfaceMonitor.AndroidVPNEnabled() {
const protectedFromVPN = 0x20000
@@ -462,36 +540,42 @@ func (t *NativeTun) rules() []*netlink.Rule {
priority++
}
if p6 {
// FIXME: this match connections from public address
if !t.options.StrictRoute {
for _, address := range t.options.Inet6Address {
it = netlink.NewRule()
it.Priority = priority6
it.IifName = "lo"
it.Src = address.Masked()
it.Table = t.options.TableIndex
it.Family = unix.AF_INET6
rules = append(rules, it)
}
priority6++
it = netlink.NewRule()
it.Priority = priority6
it.IifName = "lo"
it.Src = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
it.Goto = nopPriority
it.Family = unix.AF_INET6
rules = append(rules, it)
it = netlink.NewRule()
it.Priority = priority6
it.IifName = "lo"
it.Src = netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 128}), 1)
it.Goto = nopPriority
it.Family = unix.AF_INET6
rules = append(rules, it)
priority6++
}
it = netlink.NewRule()
it.Priority = priority6
it.Table = t.options.TableIndex
it.Family = unix.AF_INET6
rules = append(rules, it)
/*it = netlink.NewRule()
it.Priority = priority
it.Invert = true
it.IifName = "lo"
it.Table = tunTableIndex
it.Family = unix.AF_INET6
rules = append(rules, it)
it = netlink.NewRule()
it.Priority = priority
it.IifName = "lo"
it.Src = netip.PrefixFrom(netip.IPv6Unspecified(), 128) // not working
it.Table = tunTableIndex
it.Family = unix.AF_INET6
rules = append(rules, it)
it = netlink.NewRule()
it.Priority = priority
it.IifName = "lo"
it.Src = t.options.Inet6Address.Masked()
it.Table = tunTableIndex
it.Family = unix.AF_INET6
rules = append(rules, it)*/
priority6++
}
if p4 {
@@ -510,7 +594,11 @@ func (t *NativeTun) rules() []*netlink.Rule {
}
func (t *NativeTun) setRoute(tunLink netlink.Link) error {
for i, route := range t.routes(tunLink) {
routes, err := t.routes(tunLink)
if err != nil {
return err
}
for i, route := range routes {
err := netlink.RouteAdd(&route)
if err != nil {
return E.Cause(err, "add route ", i)
@@ -541,8 +629,10 @@ func (t *NativeTun) unsetRoute() error {
}
func (t *NativeTun) unsetRoute0(tunLink netlink.Link) error {
for _, route := range t.routes(tunLink) {
_ = netlink.RouteDel(&route)
if routes, err := t.routes(tunLink); err == nil {
for _, route := range routes {
_ = netlink.RouteDel(&route)
}
}
return nil
}
@@ -588,21 +678,33 @@ func (t *NativeTun) resetRules() error {
return t.setRules()
}
func (t *NativeTun) routeUpdate(event int) error {
func (t *NativeTun) routeUpdate(event int) {
if event&EventAndroidVPNUpdate == 0 {
return nil
return
}
err := t.resetRules()
if err != nil {
return E.Cause(err, "reset route")
if t.options.Logger != nil {
t.options.Logger.Error(E.Cause(err, "reset route"))
}
}
return nil
}
func setSearchDomainForSystemdResolved(interfaceName string) {
func (t *NativeTun) setSearchDomainForSystemdResolved() {
ctlPath, err := exec.LookPath("resolvectl")
if err != nil {
return
}
shell.Exec(ctlPath, "domain", interfaceName, "~.").Run()
var dnsServer []netip.Addr
if len(t.options.Inet4Address) > 0 {
dnsServer = append(dnsServer, t.options.Inet4Address[0].Addr().Next())
}
if len(t.options.Inet6Address) > 0 {
dnsServer = append(dnsServer, t.options.Inet6Address[0].Addr().Next())
}
shell.Exec(ctlPath, "domain", t.options.Name, "~.").Start()
if t.options.AutoRoute {
shell.Exec(ctlPath, "default-route", t.options.Name, "true").Start()
shell.Exec(ctlPath, append([]string{"dns", t.options.Name}, common.Map(dnsServer, netip.Addr.String)...)...).Start()
}
}

View File

@@ -3,8 +3,8 @@
package tun
import (
"gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/link/fdbased"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
)
var _ GVisorTun = (*NativeTun)(nil)

View File

@@ -2,13 +2,17 @@ package tun
import (
"context"
"net/netip"
"os"
"runtime"
"sort"
"strconv"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/ranges"
"go4.org/netipx"
)
const (
@@ -96,3 +100,80 @@ func buildExcludedRanges(includeRanges []ranges.Range[uint32], excludeRanges []r
}
return ranges.Merge(uidRanges)
}
const autoRouteUseSubRanges = runtime.GOOS == "darwin"
func (o *Options) BuildAutoRouteRanges(underNetworkExtension bool) ([]netip.Prefix, error) {
var routeRanges []netip.Prefix
if o.AutoRoute && len(o.Inet4Address) > 0 {
var inet4Ranges []netip.Prefix
if len(o.Inet4RouteAddress) > 0 {
inet4Ranges = o.Inet4RouteAddress
} else if autoRouteUseSubRanges && !underNetworkExtension {
inet4Ranges = []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 1}), 8),
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 2}), 7),
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 4}), 6),
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 8}), 5),
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 16}), 4),
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 32}), 3),
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 64}), 2),
netip.PrefixFrom(netip.AddrFrom4([4]byte{0: 128}), 1),
}
} else {
inet4Ranges = []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}
}
if len(o.Inet4RouteExcludeAddress) == 0 {
routeRanges = append(routeRanges, inet4Ranges...)
} else {
var builder netipx.IPSetBuilder
for _, inet4Range := range inet4Ranges {
builder.AddPrefix(inet4Range)
}
for _, prefix := range o.Inet4RouteExcludeAddress {
builder.RemovePrefix(prefix)
}
resultSet, err := builder.IPSet()
if err != nil {
return nil, E.Cause(err, "build IPv4 route address")
}
routeRanges = append(routeRanges, resultSet.Prefixes()...)
}
}
if len(o.Inet6Address) > 0 {
var inet6Ranges []netip.Prefix
if len(o.Inet6RouteAddress) > 0 {
inet6Ranges = o.Inet6RouteAddress
} else if autoRouteUseSubRanges && !underNetworkExtension {
inet6Ranges = []netip.Prefix{
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 1}), 8),
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 2}), 7),
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 4}), 6),
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 8}), 5),
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 16}), 4),
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 32}), 3),
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 64}), 2),
netip.PrefixFrom(netip.AddrFrom16([16]byte{0: 128}), 1),
}
} else {
inet6Ranges = []netip.Prefix{netip.PrefixFrom(netip.IPv6Unspecified(), 0)}
}
if len(o.Inet6RouteExcludeAddress) == 0 {
routeRanges = append(routeRanges, inet6Ranges...)
} else {
var builder netipx.IPSetBuilder
for _, inet6Range := range inet6Ranges {
builder.AddPrefix(inet6Range)
}
for _, prefix := range o.Inet6RouteExcludeAddress {
builder.RemovePrefix(prefix)
}
resultSet, err := builder.IPSet()
if err != nil {
return nil, E.Cause(err, "build IPv6 route address")
}
routeRanges = append(routeRanges, resultSet.Prefixes()...)
}
}
return routeRanges, nil
}

View File

@@ -16,7 +16,10 @@ import (
"github.com/sagernet/sing-tun/internal/winipcfg"
"github.com/sagernet/sing-tun/internal/winsys"
"github.com/sagernet/sing-tun/internal/wintun"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/windnsapi"
"golang.org/x/sys/windows"
@@ -85,38 +88,22 @@ func (t *NativeTun) configure() error {
return E.Cause(err, "set ipv6 dns")
}
}
if len(t.options.Inet4Address) > 0 || len(t.options.Inet6Address) > 0 {
_ = luid.DisableDNSRegistration()
}
if t.options.AutoRoute {
if len(t.options.Inet4Address) > 0 {
if len(t.options.Inet4RouteAddress) > 0 {
for _, addr := range t.options.Inet4RouteAddress {
err := luid.AddRoute(addr, netip.IPv4Unspecified(), 0)
if err != nil {
return E.Cause(err, "add ipv4 route: ", addr)
}
}
routeRanges, err := t.options.BuildAutoRouteRanges(false)
if err != nil {
return err
}
for _, routeRange := range routeRanges {
if routeRange.Addr().Is4() {
err = luid.AddRoute(routeRange, netip.IPv4Unspecified(), 0)
} else {
err := luid.AddRoute(netip.PrefixFrom(netip.IPv4Unspecified(), 0), netip.IPv4Unspecified(), 0)
if err != nil {
return E.Cause(err, "set ipv4 route")
}
err = luid.AddRoute(routeRange, netip.IPv6Unspecified(), 0)
}
}
if len(t.options.Inet6Address) > 0 {
if len(t.options.Inet6RouteAddress) > 0 {
for _, addr := range t.options.Inet6RouteAddress {
err := luid.AddRoute(addr, netip.IPv6Unspecified(), 0)
if err != nil {
return E.Cause(err, "add ipv6 route: ", addr)
}
}
} else {
err := luid.AddRoute(netip.PrefixFrom(netip.IPv6Unspecified(), 0), netip.IPv6Unspecified(), 0)
if err != nil {
return E.Cause(err, "set ipv6 route")
}
}
}
err := windnsapi.FlushResolverCache()
err = windnsapi.FlushResolverCache()
if err != nil {
return err
}
@@ -467,6 +454,15 @@ func (t *NativeTun) write(packetElementList [][]byte) (n int, err error) {
return 0, fmt.Errorf("write failed: %w", err)
}
func (t *NativeTun) CreateVectorisedWriter() N.VectorisedWriter {
return t
}
func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
defer buf.ReleaseMulti(buffers)
return common.Error(t.write(buf.ToSliceMulti(buffers)))
}
func (t *NativeTun) Close() error {
var err error
t.closeOnce.Do(func() {

View File

@@ -3,10 +3,10 @@
package tun
import (
"gvisor.dev/gvisor/pkg/bufferv2"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
)
var _ GVisorTun = (*NativeTun)(nil)
@@ -51,16 +51,16 @@ func (e *WintunEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
func (e *WintunEndpoint) dispatchLoop() {
for {
var buffer bufferv2.Buffer
var packetBuffer buffer.Buffer
err := e.tun.ReadFunc(func(b []byte) {
buffer = bufferv2.MakeWithData(b)
packetBuffer = buffer.MakeWithData(b)
})
if err != nil {
break
}
ihl, ok := buffer.PullUp(0, 1)
ihl, ok := packetBuffer.PullUp(0, 1)
if !ok {
buffer.Release()
packetBuffer.Release()
continue
}
var networkProtocol tcpip.NetworkProtocolNumber
@@ -70,12 +70,12 @@ func (e *WintunEndpoint) dispatchLoop() {
case header.IPv6Version:
networkProtocol = header.IPv6ProtocolNumber
default:
e.tun.Write(buffer.Flatten())
buffer.Release()
e.tun.Write(packetBuffer.Flatten())
packetBuffer.Release()
continue
}
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer,
Payload: packetBuffer,
IsForwardedPacket: true,
})
dispatcher := e.dispatcher
@@ -99,7 +99,11 @@ func (e *WintunEndpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareNone
}
func (e *WintunEndpoint) AddHeader(buffer *stack.PacketBuffer) {
func (e *WintunEndpoint) AddHeader(buffer stack.PacketBufferPtr) {
}
func (e *WintunEndpoint) ParseHeader(ptr stack.PacketBufferPtr) bool {
return true
}
func (e *WintunEndpoint) WritePackets(packetBufferList stack.PacketBufferList) (int, tcpip.Error) {