Compare commits

...

50 Commits

Author SHA1 Message Date
世界
67013b321e Fix race codes 2025-09-12 18:02:59 +08:00
世界
0381a06643 ping: Add destination rewriter 2025-09-11 18:51:00 +08:00
世界
960457abba Update dependencies 2025-09-10 22:52:15 +08:00
世界
7de8ff7f20 ping: Fix read ICMPv6 from gVisor 2025-09-09 23:22:25 +08:00
世界
a8cb01e6df Prevent panic when wintun dll fails to load 2025-09-09 18:04:19 +08:00
世界
b76e852f59 ping: Fix reject 2025-08-27 20:29:08 +08:00
世界
d5865f2135 Fix gvisor loopback address 2025-08-27 16:35:53 +08:00
世界
ff49ece55d Fix checksum changes 2025-08-27 15:47:20 +08:00
世界
adc106bcf6 Improve checksum usages 2025-08-27 11:07:26 +08:00
wwqgtxx
f9bbb15bfb ping: Reduce the cost of readMsg 2025-08-27 08:45:12 +08:00
世界
4c43f4af12 ping: Fix not handling reject 2025-08-27 00:30:40 +08:00
世界
055fe13ec0 Revert "ping: Increate mapping capacity"
This reverts commit a24ab73aca.
2025-08-26 23:54:57 +08:00
世界
e8d7fc1bb2 ping: Do not use connected socket on macOS 2025-08-26 23:43:37 +08:00
世界
a24ab73aca ping: Increate mapping capacity 2025-08-26 14:41:14 +08:00
世界
b5f3fecc25 ping: Fix linux route rules 2025-08-26 14:30:21 +08:00
wwqgtxx
4fb5702443 ping: Add needFilter 2025-08-26 13:20:50 +08:00
wwqgtxx
6e4e045620 ping: Limit old requests 2025-08-26 11:47:14 +08:00
世界
79e2d3b56d ping: Clean old requests 2025-08-26 11:09:50 +08:00
世界
e6c64e3f18 Update icmp network name 2025-08-25 22:43:59 +08:00
世界
59e42c0d1f ping: Fix rewriter 2025-08-25 22:04:16 +08:00
世界
ce55929883 ping: Fix missing bind for unix socket 2025-08-25 21:28:11 +08:00
世界
144683d882 ping: Add filter to destination 2025-08-25 20:57:18 +08:00
世界
d0ff7b6f6c Replace usages of common/atomic 2025-08-25 19:59:21 +08:00
wwqgtxx
854e40dc40 ping: Fix icmp checksum 2025-08-25 11:51:22 +08:00
wwqgtxx
a0b34a4be9 ping: Code cleanup 2025-08-25 11:08:57 +08:00
世界
548f51cc9d ping: Fix test 2025-08-25 10:48:44 +08:00
wwqgtxx
ce050baa58 ping: Add bitwiseID 2025-08-25 10:32:59 +08:00
wwqgtxx
06ddb3e0a7 ping: Add comments 2025-08-25 10:18:22 +08:00
世界
c089ffbd6c ping: Fix read ipv4 header on darwin 2025-08-25 09:40:52 +08:00
世界
fe4e54bb0d ping: check invalid ip header 2025-08-25 00:16:56 +08:00
wwqgtxx
58f331b49e ping: fix network 2025-08-25 00:00:02 +08:00
wwqgtxx
0d3df84673 ping: fix UnprivilegedConn 2025-08-24 23:05:48 +08:00
wwqgtxx
dbd8e28fc8 ping: fix healthCheck panic 2025-08-24 21:20:34 +08:00
世界
ff4941daa4 Pass timeout to PrepareConnection 2025-08-24 18:59:55 +08:00
世界
9532c7f1f6 ping: Update style for socket_linux_unprivileged.go 2025-08-24 18:43:22 +08:00
wwqgtxx
ccfe5c0f0f ping: Rewrite UnprivilegedConn 2025-08-24 18:30:46 +08:00
世界
737ebf01c4 ping: Add timeout to destinations 2025-08-24 18:19:39 +08:00
世界
8f6cc9f62e ping: Fix unprivileged response on linux 2025-08-24 15:15:04 +08:00
世界
3faf8cf679 ping: Add test for ident 2025-08-24 14:07:35 +08:00
wwqgtxx
bee7be8598 ping: fix Logs 2025-08-24 12:50:12 +08:00
世界
d53158b8d7 ping: Add logs 2025-08-24 12:39:56 +08:00
世界
7f41766568 ping: Fix control 2025-08-24 12:39:50 +08:00
世界
dd18aa2b86 ping: Fix on android 2025-08-24 10:47:15 +08:00
世界
86d96064d5 ping: Add gVisor destination 2025-08-24 10:36:16 +08:00
世界
12c9fb6a5d Fix gvisor icmp write 2025-08-23 16:51:29 +08:00
世界
a256dca36b Fix ping response for tun address 2025-08-23 16:43:33 +08:00
世界
f46791bc0d Fix gvisor icmp destination 2025-08-23 16:43:33 +08:00
世界
8dbb51cfb7 Add ping client 2025-08-23 16:16:28 +08:00
世界
036d61a0aa Add ping proxy implementation 2025-08-22 14:21:21 +08:00
世界
933bd2b2d5 Add ping proxy support 2025-08-22 10:56:12 +08:00
39 changed files with 2120 additions and 178 deletions

View File

@@ -26,12 +26,14 @@ jobs:
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ^1.23
go-version: ^1.25.0
- name: Build
run: |
make test
build_go120:
name: Linux (Go 1.20)
go test -c -o ping_test ./ping
sudo ./ping_test -test.v
build_go124:
name: Linux (Go 1.24)
runs-on: ubuntu-latest
steps:
- name: Checkout
@@ -41,13 +43,15 @@ jobs:
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ~1.20
go-version: ~1.24
continue-on-error: true
- name: Build
run: |
make test
build_go121:
name: Linux (Go 1.21)
go test -c -o ping_test ./ping
sudo ./ping_test -test.v
build_go123:
name: Linux (Go 1.23)
runs-on: ubuntu-latest
steps:
- name: Checkout
@@ -57,27 +61,13 @@ jobs:
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ~1.21
continue-on-error: true
- name: Build
run: |
make test
build_go122:
name: Linux (Go 1.22)
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ~1.22
go-version: ~1.23
continue-on-error: true
- name: Build
run: |
make test
go test -c -o ping_test ./ping
sudo ./ping_test -test.v
build_windows:
name: Windows
runs-on: windows-latest
@@ -94,6 +84,7 @@ jobs:
- name: Build
run: |
make test
go test -v ./ping
build_darwin:
name: macOS
runs-on: macos-latest
@@ -109,4 +100,7 @@ jobs:
continue-on-error: true
- name: Build
run: |
make test
make test
go test -v ./ping
go test -c -o ping_test ./ping
sudo ./ping_test -test.v

View File

@@ -29,5 +29,5 @@ lint_install:
test:
go build -v .
go test -bench=. ./internal/checksum_test
#go test -v .
#go test -bench=. ./internal/checksum_test
go test -v .

14
go.mod
View File

@@ -1,28 +1,32 @@
module github.com/sagernet/sing-tun
go 1.20
go 1.23.1
require (
github.com/go-ole/go-ole v1.3.0
github.com/google/btree v1.1.3
github.com/sagernet/fswatch v0.1.1
github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff
github.com/sagernet/gvisor v0.0.0-20250909151924-850a370d8506
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a
github.com/sagernet/nftables v0.3.0-beta.4
github.com/sagernet/sing v0.7.0-beta.1
github.com/sagernet/sing v0.8.0-beta.1
github.com/stretchr/testify v1.11.1
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8
golang.org/x/net v0.26.0
golang.org/x/sys v0.26.0
golang.org/x/net v0.43.0
golang.org/x/sys v0.35.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
github.com/mdlayher/socket v0.4.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/vishvananda/netns v0.0.4 // indirect
golang.org/x/sync v0.7.0 // indirect
golang.org/x/time v0.7.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

24
go.sum
View File

@@ -1,4 +1,5 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
@@ -14,30 +15,37 @@ github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU
github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U=
github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sagernet/fswatch v0.1.1 h1:YqID+93B7VRfqIH3PArW/XpJv5H4OLEVWDfProGoRQs=
github.com/sagernet/fswatch v0.1.1/go.mod h1:nz85laH0mkQqJfaOrqPpkwtU1znMFNVTpT/5oRsVz/o=
github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff h1:mlohw3360Wg1BNGook/UHnISXhUx4Gd/3tVLs5T0nSs=
github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff/go.mod h1:ehZwnT2UpmOWAHFL48XdBhnd4Qu4hN2O3Ji0us3ZHMw=
github.com/sagernet/gvisor v0.0.0-20250909151924-850a370d8506 h1:x/t3XqWshOlWqRuumpvbUvjtEr/6mJuBXAVovPefbUg=
github.com/sagernet/gvisor v0.0.0-20250909151924-850a370d8506/go.mod h1:QkkPEJLw59/tfxgapHta14UL5qMUah5NXhO0Kw2Kan4=
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZNjr6sGeT00J8uU7JF4cNUdb44/Duis=
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM=
github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I=
github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8=
github.com/sagernet/sing v0.7.0-beta.1 h1:2D44KzgeDZwD/R4Ts8jwSUHTRR238a1FpXDrl7l4tVw=
github.com/sagernet/sing v0.7.0-beta.1/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/sagernet/sing v0.8.0-beta.1 h1:tBOdh/K/EBdXWuBxUJsZONyxDzyfzjdCF1Yq57QtpE4=
github.com/sagernet/sing v0.8.0-beta.1/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M=
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y=
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 h1:yixxcjnhBmY0nkL253HFVIm0JsFHwrHdT3Yh6szTnfY=
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8/go.mod h1:jj3sYF3dwk5D+ghuXyeI3r5MFf+NT2An6/9dOA95KSI=
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -276,8 +276,8 @@ func (b ICMPv6) Payload() []byte {
// ICMPv6ChecksumParams contains parameters to calculate ICMPv6 checksum.
type ICMPv6ChecksumParams struct {
Header ICMPv6
Src tcpip.Address
Dst tcpip.Address
Src []byte
Dst []byte
PayloadCsum uint16
PayloadLen int
}
@@ -287,7 +287,7 @@ type ICMPv6ChecksumParams struct {
func ICMPv6Checksum(params ICMPv6ChecksumParams) uint16 {
h := params.Header
xsum := PseudoHeaderChecksum(ICMPv6ProtocolNumber, params.Src.AsSlice(), params.Dst.AsSlice(), uint16(len(h)+params.PayloadLen))
xsum := PseudoHeaderChecksum(ICMPv6ProtocolNumber, params.Src, params.Dst, uint16(len(h)+params.PayloadLen))
xsum = checksum.Combine(xsum, params.PayloadCsum)
// h[2:4] is the checksum itself, skip it to avoid checksumming the checksum.

View File

@@ -86,18 +86,26 @@ type Network interface {
// SourceAddress returns the value of the "source address" field.
SourceAddress() tcpip.Address
SourceAddr() netip.Addr
SourceAddressSlice() []byte
// DestinationAddress returns the value of the "destination address"
// field.
DestinationAddress() tcpip.Address
DestinationAddr() netip.Addr
DestinationAddressSlice() []byte
// Checksum returns the value of the "checksum" field.
Checksum() uint16
// SetSourceAddress sets the value of the "source address" field.
SetSourceAddress(tcpip.Address)
SetSourceAddr(netip.Addr)
// SetDestinationAddress sets the value of the "destination address"
// field.
SetDestinationAddress(tcpip.Address)

View File

@@ -315,6 +315,10 @@ func (b IPv4) Flags() uint8 {
return uint8(binary.BigEndian.Uint16(b[flagsFO:]) >> 13)
}
func (b IPv4) FlagsDarwinRaw() uint8 {
return uint8(binary.BigEndian.Uint16(b[flagsFO:]) >> 13)
}
// More returns whether the more fragments flag is set.
func (b IPv4) More() bool {
return b.Flags()&IPv4FlagMoreFragments != 0
@@ -330,11 +334,19 @@ func (b IPv4) FragmentOffset() uint16 {
return binary.BigEndian.Uint16(b[flagsFO:]) << 3
}
func (b IPv4) FragmentOffsetDarwinRaw() uint16 {
return common.NativeEndian.Uint16(b[flagsFO:]) << 3
}
// TotalLength returns the "total length" field of the IPv4 header.
func (b IPv4) TotalLength() uint16 {
return binary.BigEndian.Uint16(b[IPv4TotalLenOffset:])
}
func (b IPv4) TotalLengthDarwinRaw() uint16 {
return common.NativeEndian.Uint16(b[IPv4TotalLenOffset:]) + uint16(b.HeaderLength())
}
// Checksum returns the checksum field of the IPv4 header.
func (b IPv4) Checksum() uint16 {
return binary.BigEndian.Uint16(b[xsum:])
@@ -428,6 +440,10 @@ func (b IPv4) SetTotalLength(totalLength uint16) {
binary.BigEndian.PutUint16(b[IPv4TotalLenOffset:], totalLength)
}
func (b IPv4) SetTotalLengthDarwinRaw(totalLength uint16) {
common.NativeEndian.PutUint16(b[IPv4TotalLenOffset:], totalLength)
}
// SetChecksum sets the checksum field of the IPv4 header.
func (b IPv4) SetChecksum(v uint16) {
checksum.Put(b[xsum:], v)
@@ -440,6 +456,11 @@ func (b IPv4) SetFlagsFragmentOffset(flags uint8, offset uint16) {
binary.BigEndian.PutUint16(b[flagsFO:], v)
}
func (b IPv4) SetFlagsFragmentOffsetDarwinRaw(flags uint8, offset uint16) {
v := (uint16(flags) << 13) | (offset >> 3)
common.NativeEndian.PutUint16(b[flagsFO:], v)
}
// SetID sets the identification field.
func (b IPv4) SetID(v uint16) {
binary.BigEndian.PutUint16(b[id:], v)
@@ -458,7 +479,10 @@ func (b IPv4) SetDestinationAddress(addr tcpip.Address) {
// CalculateChecksum calculates the checksum of the IPv4 header.
func (b IPv4) CalculateChecksum() uint16 {
return checksum.Checksum(b[:b.HeaderLength()], 0)
// return checksum.Checksum(b[:b.HeaderLength()], 0)
xsum0 := checksum.Checksum(b[:xsum], 0)
xsum0 = checksum.Checksum(b[xsum+2:b.HeaderLength()], xsum0)
return xsum0
}
// Encode encodes all the fields of the IPv4 header.
@@ -550,7 +574,8 @@ func (b IPv4) IsChecksumValid() bool {
// same set of octets, including the checksum field. If the result
// is all 1 bits (-0 in 1's complement arithmetic), the check
// succeeds.
return b.CalculateChecksum() == 0xffff
//return b.CalculateChecksum() == 0xffff
return checksum.Checksum(b[:b.HeaderLength()], 0) == 0xffff
}
// IsV4MulticastAddress determines if the provided address is an IPv4 multicast

View File

@@ -351,14 +351,18 @@ func (b TCP) SetUrgentPointer(urgentPointer uint16) {
// and the checksum of the segment data.
func (b TCP) CalculateChecksum(partialChecksum uint16) uint16 {
// Calculate the rest of the checksum.
return checksum.Checksum(b[:b.DataOffset()], partialChecksum)
// return checksum.Checksum(b[:b.DataOffset()], partialChecksum)
xsum := checksum.Checksum(b[:TCPChecksumOffset], partialChecksum)
xsum = checksum.Checksum(b[TCPChecksumOffset+2:b.DataOffset()], xsum)
return xsum
}
// IsChecksumValid returns true iff the TCP header's checksum is valid.
func (b TCP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum, payloadLength uint16) bool {
xsum := PseudoHeaderChecksum(TCPProtocolNumber, src.AsSlice(), dst.AsSlice(), uint16(b.DataOffset())+payloadLength)
xsum = checksum.Combine(xsum, payloadChecksum)
return b.CalculateChecksum(xsum) == 0xffff
// return b.CalculateChecksum(xsum) == 0xffff
return checksum.Checksum(b[:b.DataOffset()], xsum) == 0xffff
}
// Options returns a slice that holds the unparsed TCP options in the segment.

View File

@@ -113,15 +113,18 @@ func (b UDP) SetLength(length uint16) {
// CalculateChecksum calculates the checksum of the UDP packet, given the
// checksum of the network-layer pseudo-header and the checksum of the payload.
func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 {
// Calculate the rest of the checksum.
return checksum.Checksum(b[:UDPMinimumSize], partialChecksum)
// Calculate the rest of the checksum.\
// return checksum.Checksum(b[:UDPMinimumSize], partialChecksum)
xsum := checksum.Checksum(b[:udpChecksum], partialChecksum)
xsum = checksum.Checksum(b[udpChecksum+2:UDPMinimumSize], xsum)
return xsum
}
// IsChecksumValid returns true iff the UDP header's checksum is valid.
func (b UDP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum uint16) bool {
xsum := PseudoHeaderChecksum(UDPProtocolNumber, dst.AsSlice(), src.AsSlice(), b.Length())
xsum = checksum.Combine(xsum, payloadChecksum)
return b.CalculateChecksum(xsum) == 0xffff
return checksum.Checksum(b[:UDPMinimumSize], xsum) == 0xffff
}
// Encode encodes all the fields of the UDP header.

View File

@@ -39,6 +39,10 @@ func closeAdapter(wintun *Adapter) {
// deterministically. If it is set to nil, the GUID is chosen by the system at random,
// and hence a new NLA entry is created for each new adapter.
func CreateAdapter(name string, tunnelType string, requestedGUID *windows.GUID) (wintun *Adapter, err error) {
err = procWintunCloseAdapter.Find()
if err != nil {
return
}
var name16 *uint16
name16, err = windows.UTF16PtrFromString(name)
if err != nil {

View File

@@ -5,9 +5,9 @@ package tun
import (
"errors"
"sync"
"sync/atomic"
"time"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/control"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/common/x/list"

View File

@@ -1,42 +0,0 @@
package tun
import (
"strconv"
"github.com/sagernet/sing-tun/internal/gtcpip"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
F "github.com/sagernet/sing/common/format"
N "github.com/sagernet/sing/common/network"
)
func NetworkName(network uint8) string {
switch tcpip.TransportProtocolNumber(network) {
case header.TCPProtocolNumber:
return N.NetworkTCP
case header.UDPProtocolNumber:
return N.NetworkUDP
case header.ICMPv4ProtocolNumber:
return N.NetworkICMPv4
case header.ICMPv6ProtocolNumber:
return N.NetworkICMPv6
}
return F.ToString(network)
}
func NetworkFromName(name string) uint8 {
switch name {
case N.NetworkTCP:
return uint8(header.TCPProtocolNumber)
case N.NetworkUDP:
return uint8(header.UDPProtocolNumber)
case N.NetworkICMPv4:
return uint8(header.ICMPv4ProtocolNumber)
case N.NetworkICMPv6:
return uint8(header.ICMPv6ProtocolNumber)
}
parseNetwork, err := strconv.ParseUint(name, 10, 8)
if err != nil {
return 0
}
return uint8(parseNetwork)
}

16
ping/cmsg_unix.go Normal file
View File

@@ -0,0 +1,16 @@
//go:build !windows
package ping
import (
"golang.org/x/net/ipv6"
)
func parseIPv6ControlMessage(cmsg []byte) (*ipv6.ControlMessage, error) {
var controlMessage ipv6.ControlMessage
err := controlMessage.Parse(cmsg)
if err != nil {
return nil, err
}
return &controlMessage, nil
}

47
ping/cmsg_windows.go Normal file
View File

@@ -0,0 +1,47 @@
package ping
import (
"fmt"
"unsafe"
"github.com/sagernet/sing/common"
"golang.org/x/net/ipv6"
"golang.org/x/sys/windows"
)
const (
IPV6_HOPLIMIT = 21
IPV6_TCLASS = 39
IPV6_RECVTCLASS = 40
)
var (
alignedSizeofCmsghdr = (sizeofCmsghdr + cmsgAlignTo - 1) & ^(cmsgAlignTo - 1)
sizeofCmsghdr = int(unsafe.Sizeof(windows.WSACMSGHDR{}))
cmsgAlignTo = int(unsafe.Sizeof(uintptr(0)))
)
func cmsgAlign(n int) int {
return (n + cmsgAlignTo - 1) & ^(cmsgAlignTo - 1)
}
func parseIPv6ControlMessage(cmsg []byte) (*ipv6.ControlMessage, error) {
var controlMessage ipv6.ControlMessage
for len(cmsg) >= sizeofCmsghdr {
cmsghdr := (*windows.WSACMSGHDR)(unsafe.Pointer(unsafe.SliceData(cmsg)))
msgLen := int(cmsghdr.Len)
msgSize := cmsgAlign(msgLen)
if msgLen < sizeofCmsghdr || msgSize > len(cmsg) {
return nil, fmt.Errorf("invalid control message length %d", cmsghdr.Len)
}
switch cmsghdr.Type {
case IPV6_TCLASS:
controlMessage.TrafficClass = int(common.NativeEndian.Uint32(cmsg[alignedSizeofCmsghdr : alignedSizeofCmsghdr+4]))
case IPV6_HOPLIMIT:
controlMessage.HopLimit = int(common.NativeEndian.Uint32(cmsg[alignedSizeofCmsghdr : alignedSizeofCmsghdr+4]))
}
cmsg = cmsg[msgSize:]
}
return &controlMessage, nil
}

224
ping/destination.go Normal file
View File

@@ -0,0 +1,224 @@
package ping
import (
"context"
"errors"
"net/netip"
"os"
"runtime"
"sync"
"time"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/control"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
)
var _ tun.DirectRouteDestination = (*Destination)(nil)
type Destination struct {
conn *Conn
ctx context.Context
logger logger.ContextLogger
destination netip.Addr
routeContext tun.DirectRouteContext
timeout time.Duration
requestAccess sync.Mutex
requests map[pingRequest]time.Time
}
type pingRequest struct {
Source netip.Addr
Destination netip.Addr
Identifier uint16
Sequence uint16
}
func ConnectDestination(
ctx context.Context,
logger logger.ContextLogger,
controlFunc control.Func,
destination netip.Addr,
routeContext tun.DirectRouteContext,
timeout time.Duration,
) (tun.DirectRouteDestination, error) {
var (
conn *Conn
err error
)
switch runtime.GOOS {
case "darwin", "ios", "windows":
conn, err = Connect(ctx, false, controlFunc, destination)
default:
conn, err = Connect(ctx, true, controlFunc, destination)
if errors.Is(err, os.ErrPermission) {
conn, err = Connect(ctx, false, controlFunc, destination)
}
}
if err != nil {
return nil, err
}
d := &Destination{
conn: conn,
ctx: ctx,
logger: logger,
destination: destination,
routeContext: routeContext,
timeout: timeout,
requests: make(map[pingRequest]time.Time),
}
go d.loopRead()
return d, nil
}
func (d *Destination) loopRead() {
defer d.Close()
for {
buffer := buf.NewPacket()
err := d.conn.SetReadDeadline(time.Now().Add(d.timeout))
if err != nil {
d.logger.ErrorContext(d.ctx, E.Cause(err, "set read deadline for ICMP conn"))
}
err = d.conn.ReadIP(buffer)
if err != nil {
buffer.Release()
if !E.IsClosed(err) {
d.logger.ErrorContext(d.ctx, E.Cause(err, "receive ICMP echo reply"))
}
return
}
if !d.destination.Is6() {
ipHdr := header.IPv4(buffer.Bytes())
if !ipHdr.IsValid(buffer.Len()) {
d.logger.ErrorContext(d.ctx, E.New("invalid IPv4 header received"))
continue
}
if ipHdr.PayloadLength() < header.ICMPv4MinimumSize {
d.logger.ErrorContext(d.ctx, E.New("invalid ICMPv4 header received"))
continue
}
icmpHdr := header.ICMPv4(ipHdr.Payload())
if d.needFilter() {
if icmpHdr.Type() != header.ICMPv4EchoReply {
continue
}
var requestExists bool
request := pingRequest{Source: ipHdr.DestinationAddr(), Destination: ipHdr.SourceAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()}
d.requestAccess.Lock()
_, loaded := d.requests[request]
if loaded {
requestExists = true
delete(d.requests, request)
}
d.requestAccess.Unlock()
if !requestExists {
continue
}
}
d.logger.TraceContext(d.ctx, "read ICMPv4 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
} else {
ipHdr := header.IPv6(buffer.Bytes())
if !ipHdr.IsValid(buffer.Len()) {
d.logger.ErrorContext(d.ctx, E.New("invalid IPv6 header received"))
continue
}
if ipHdr.PayloadLength() < header.ICMPv6MinimumSize {
d.logger.ErrorContext(d.ctx, E.New("invalid ICMPv6 header received"))
continue
}
icmpHdr := header.ICMPv6(ipHdr.Payload())
if d.needFilter() {
if icmpHdr.Type() != header.ICMPv6EchoReply {
continue
}
var requestExists bool
request := pingRequest{Source: ipHdr.DestinationAddr(), Destination: ipHdr.SourceAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()}
d.requestAccess.Lock()
_, loaded := d.requests[request]
if loaded {
requestExists = true
delete(d.requests, request)
}
d.requestAccess.Unlock()
if !requestExists {
continue
}
}
d.logger.TraceContext(d.ctx, "read ICMPv6 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
}
err = d.routeContext.WritePacket(buffer.Bytes())
if err != nil {
d.logger.ErrorContext(d.ctx, E.Cause(err, "write ICMP echo reply"))
}
buffer.Release()
}
}
func (d *Destination) WritePacket(packet *buf.Buffer) error {
if !d.destination.Is6() {
ipHdr := header.IPv4(packet.Bytes())
if !ipHdr.IsValid(packet.Len()) {
return E.New("invalid IPv4 header")
}
if ipHdr.PayloadLength() < header.ICMPv4MinimumSize {
return E.New("invalid ICMPv4 header")
}
icmpHdr := header.ICMPv4(ipHdr.Payload())
if d.needFilter() {
d.registerRequest(pingRequest{Source: ipHdr.SourceAddr(), Destination: ipHdr.DestinationAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()})
}
d.logger.TraceContext(d.ctx, "write ICMPv4 echo request from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
} else {
ipHdr := header.IPv6(packet.Bytes())
if !ipHdr.IsValid(packet.Len()) {
return E.New("invalid IPv6 header")
}
if ipHdr.PayloadLength() < header.ICMPv6MinimumSize {
return E.New("invalid ICMPv6 header")
}
icmpHdr := header.ICMPv6(ipHdr.Payload())
if d.needFilter() {
d.registerRequest(pingRequest{Source: ipHdr.SourceAddr(), Destination: ipHdr.DestinationAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()})
}
d.logger.TraceContext(d.ctx, "write ICMPv6 echo request from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
}
return d.conn.WriteIP(packet)
}
func (d *Destination) needFilter() bool {
return runtime.GOOS != "windows" && !d.conn.isLinuxUnprivileged()
}
func (d *Destination) registerRequest(request pingRequest) {
const requestsLimit = 1024
d.requestAccess.Lock()
defer d.requestAccess.Unlock()
now := time.Now()
var (
oldestRequest pingRequest
oldestCreateAt = now
)
for oldRequest, createdAt := range d.requests {
if now.Sub(createdAt) > d.timeout {
delete(d.requests, oldRequest)
} else if createdAt.Before(oldestCreateAt) {
oldestRequest = oldRequest
oldestCreateAt = createdAt
}
}
if len(d.requests) > requestsLimit {
delete(d.requests, oldestRequest)
}
d.requests[request] = now
}
func (d *Destination) Close() error {
return d.conn.Close()
}
func (d *Destination) IsClosed() bool {
return d.conn.IsClosed()
}

129
ping/destination_gvisor.go Normal file
View File

@@ -0,0 +1,129 @@
//go:build with_gvisor
package ping
import (
"context"
"net/netip"
"time"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport"
"github.com/sagernet/gvisor/pkg/waiter"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
)
var _ tun.DirectRouteDestination = (*GVisorDestination)(nil)
type GVisorDestination struct {
ctx context.Context
logger logger.ContextLogger
endpoint tcpip.Endpoint
conn *gonet.TCPConn
rewriter *SourceRewriter
timeout time.Duration
}
func ConnectGVisor(
ctx context.Context, logger logger.ContextLogger,
sourceAddress, destinationAddress netip.Addr,
routeContext tun.DirectRouteContext,
stack *stack.Stack,
bindAddress4, bindAddress6 netip.Addr,
timeout time.Duration,
) (*GVisorDestination, error) {
var (
bindAddress tcpip.Address
wq waiter.Queue
endpoint tcpip.Endpoint
gErr tcpip.Error
)
if !destinationAddress.Is6() {
if !bindAddress4.IsValid() {
return nil, E.New("missing IPv4 interface address")
}
bindAddress = tun.AddressFromAddr(bindAddress4)
endpoint, gErr = stack.NewRawEndpoint(header.ICMPv4ProtocolNumber, header.IPv4ProtocolNumber, &wq, true)
} else {
if !bindAddress6.IsValid() {
return nil, E.New("missing IPv6 interface address")
}
bindAddress = tun.AddressFromAddr(bindAddress6)
endpoint, gErr = stack.NewRawEndpoint(header.ICMPv6ProtocolNumber, header.IPv6ProtocolNumber, &wq, true)
}
if gErr != nil {
return nil, gonet.TranslateNetstackError(gErr)
}
gErr = endpoint.Bind(tcpip.FullAddress{
NIC: 1,
Addr: bindAddress,
})
if gErr != nil {
return nil, gonet.TranslateNetstackError(gErr)
}
gErr = endpoint.Connect(tcpip.FullAddress{
NIC: 1,
Addr: tun.AddressFromAddr(destinationAddress),
})
if gErr != nil {
return nil, gonet.TranslateNetstackError(gErr)
}
endpoint.SocketOptions().SetHeaderIncluded(true)
rewriter := NewSourceRewriter(ctx, logger, bindAddress4, bindAddress6)
rewriter.CreateSession(tun.DirectRouteSession{Source: sourceAddress, Destination: destinationAddress}, routeContext)
destination := &GVisorDestination{
ctx: ctx,
logger: logger,
endpoint: endpoint,
conn: gonet.NewTCPConn(&wq, endpoint),
rewriter: rewriter,
timeout: timeout,
}
go destination.loopRead()
return destination, nil
}
func (d *GVisorDestination) loopRead() {
defer d.endpoint.Close()
for {
buffer := buf.NewPacket()
err := d.conn.SetReadDeadline(time.Now().Add(d.timeout))
if err != nil {
d.logger.ErrorContext(d.ctx, E.Cause(err, "set read deadline for ICMP conn"))
}
n, err := d.conn.Read(buffer.FreeBytes())
if err != nil {
buffer.Release()
if !E.IsClosed(err) {
d.logger.ErrorContext(d.ctx, E.Cause(err, "receive ICMP echo reply"))
}
return
}
buffer.Truncate(n)
_, err = d.rewriter.WriteBack(buffer.Bytes())
if err != nil {
d.logger.ErrorContext(d.ctx, E.Cause(err, "write ICMP echo reply"))
}
buffer.Release()
}
}
func (d *GVisorDestination) WritePacket(packet *buf.Buffer) error {
d.rewriter.RewritePacket(packet.Bytes())
return common.Error(d.conn.Write(packet.Bytes()))
}
func (d *GVisorDestination) Close() error {
return d.conn.Close()
}
func (d *GVisorDestination) IsClosed() bool {
return transport.DatagramEndpointState(d.endpoint.State()) == transport.DatagramEndpointStateClosed
}

View File

@@ -0,0 +1,79 @@
package ping
import (
"net/netip"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
"github.com/sagernet/sing/common/buf"
)
type DestinationWriter struct {
tun.DirectRouteDestination
destination netip.Addr
}
func NewDestinationWriter(routeDestination tun.DirectRouteDestination, destination netip.Addr) *DestinationWriter {
return &DestinationWriter{routeDestination, destination}
}
func (w *DestinationWriter) WritePacket(packet *buf.Buffer) error {
var ipHdr header.Network
switch header.IPVersion(packet.Bytes()) {
case header.IPv4Version:
ipHdr = header.IPv4(packet.Bytes())
case header.IPv6Version:
ipHdr = header.IPv6(packet.Bytes())
default:
return w.DirectRouteDestination.WritePacket(packet)
}
ipHdr.SetDestinationAddr(w.destination)
if ipHdr4, isIPv4 := ipHdr.(header.IPv4); isIPv4 {
ipHdr4.SetChecksum(^ipHdr4.CalculateChecksum())
}
if ipHdr.TransportProtocol() == header.ICMPv6ProtocolNumber {
icmpHdr := header.ICMPv6(ipHdr.Payload())
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
Src: ipHdr.SourceAddressSlice(),
Dst: ipHdr.DestinationAddressSlice(),
}))
}
return w.DirectRouteDestination.WritePacket(packet)
}
type ContextDestinationWriter struct {
tun.DirectRouteContext
destination netip.Addr
}
func NewContextDestinationWriter(context tun.DirectRouteContext, destination netip.Addr) *ContextDestinationWriter {
return &ContextDestinationWriter{
context, destination,
}
}
func (w *ContextDestinationWriter) WritePacket(packet []byte) error {
var ipHdr header.Network
switch header.IPVersion(packet) {
case header.IPv4Version:
ipHdr = header.IPv4(packet)
case header.IPv6Version:
ipHdr = header.IPv6(packet)
default:
return w.DirectRouteContext.WritePacket(packet)
}
ipHdr.SetSourceAddr(w.destination)
if ipHdr4, isIPv4 := ipHdr.(header.IPv4); isIPv4 {
ipHdr4.SetChecksum(^ipHdr4.CalculateChecksum())
}
if ipHdr.TransportProtocol() == header.ICMPv6ProtocolNumber {
icmpHdr := header.ICMPv6(ipHdr.Payload())
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
Src: ipHdr.SourceAddressSlice(),
Dst: ipHdr.DestinationAddressSlice(),
}))
}
return w.DirectRouteContext.WritePacket(packet)
}

24
ping/destination_test.go Normal file
View File

@@ -0,0 +1,24 @@
package ping_test
import (
"context"
"net/netip"
"testing"
"time"
"github.com/sagernet/sing-tun/ping"
"github.com/sagernet/sing/common/logger"
"github.com/stretchr/testify/require"
)
func TestIsClosed(t *testing.T) {
t.Parallel()
destination, err := ping.ConnectDestination(context.Background(), logger.NOP(), nil, netip.MustParseAddr("1.1.1.1"), nil, 30*time.Second)
require.NoError(t, err)
defer destination.Close()
time.Sleep(1 * time.Second)
require.False(t, destination.IsClosed())
destination.Close()
require.True(t, destination.IsClosed())
}

292
ping/ping.go Normal file
View File

@@ -0,0 +1,292 @@
package ping
import (
"context"
"net"
"net/netip"
"reflect"
"runtime"
"sync/atomic"
"time"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/control"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
type Conn struct {
ctx context.Context
privileged bool
conn net.Conn
destination netip.Addr
source common.TypedValue[netip.Addr]
closed atomic.Bool
readMsg func(b, oob []byte) (n, oobn int, addr netip.Addr, err error)
}
func Connect(ctx context.Context, privileged bool, controlFunc control.Func, destination netip.Addr) (*Conn, error) {
c := &Conn{
ctx: ctx,
privileged: privileged,
destination: destination,
}
err := c.connect(controlFunc)
if err != nil {
return nil, err
}
return c, nil
}
func (c *Conn) connect(controlFunc control.Func) (err error) {
if c.isLinuxUnprivileged() {
c.conn, err = newUnprivilegedConn(c.ctx, controlFunc, c.destination)
} else {
c.conn, err = connect(c.privileged, controlFunc, c.destination)
}
if err != nil {
return err
}
if ipConn, isIPConn := common.Cast[*net.IPConn](c.conn); isIPConn {
c.readMsg = func(b, oob []byte) (n, oobn int, addr netip.Addr, err error) {
var ipAddr *net.IPAddr
n, oobn, _, ipAddr, err = ipConn.ReadMsgIP(b, oob)
if err == nil {
addr = M.AddrFromNet(ipAddr)
}
return
}
} else if udpConn, isUDPConn := common.Cast[*net.UDPConn](c.conn); isUDPConn {
c.readMsg = func(b, oob []byte) (n, oobn int, addr netip.Addr, err error) {
var addrPort netip.AddrPort
n, oobn, _, addrPort, err = udpConn.ReadMsgUDPAddrPort(b, oob)
if err == nil {
addr = addrPort.Addr()
}
return
}
} else if unprivilegedConn, isUnprivilegedConn := c.conn.(*UnprivilegedConn); isUnprivilegedConn {
c.readMsg = unprivilegedConn.ReadMsg
} else {
return E.New("unsupported conn type: ", reflect.TypeOf(c.conn))
}
return
}
func (c *Conn) isLinuxUnprivileged() bool {
return (runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged
}
func (c *Conn) ReadIP(buffer *buf.Buffer) error {
if c.destination.Is6() || c.isLinuxUnprivileged() {
if !c.destination.Is6() {
oob := ipv4.NewControlMessage(ipv4.FlagTTL)
buffer.Advance(header.IPv4MinimumSize)
var ttl int
// tos int
n, oobn, addr, err := c.readMsg(buffer.FreeBytes(), oob)
if err != nil {
return err
}
buffer.Truncate(n)
if oobn > 0 {
var controlMessage ipv4.ControlMessage
err = controlMessage.Parse(oob[:oobn])
if err != nil {
return err
}
ttl = controlMessage.TTL
}
if !c.isLinuxUnprivileged() {
icmpHdr := header.ICMPv4(buffer.Bytes())
icmpHdr.SetIdent(^icmpHdr.Ident())
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
}
ipHdr := header.IPv4(buffer.ExtendHeader(header.IPv4MinimumSize))
ipHdr.Encode(&header.IPv4Fields{
// TOS: uint8(tos),
SrcAddr: addr,
DstAddr: c.source.Load(),
Protocol: uint8(header.ICMPv4ProtocolNumber),
TTL: uint8(ttl),
TotalLength: uint16(buffer.Len()),
})
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
} else {
oob := make([]byte, 1024)
buffer.Advance(header.IPv6MinimumSize)
var (
hopLimit int
trafficClass int
)
n, oobn, addr, err := c.readMsg(buffer.FreeBytes(), oob)
if err != nil {
return err
}
buffer.Truncate(n)
if oobn > 0 {
var controlMessage *ipv6.ControlMessage
controlMessage, err = parseIPv6ControlMessage(oob[:oobn])
if err != nil {
return err
}
hopLimit = controlMessage.HopLimit
trafficClass = controlMessage.TrafficClass
}
icmpHdr := header.ICMPv6(buffer.Bytes())
if !c.isLinuxUnprivileged() {
icmpHdr.SetIdent(^icmpHdr.Ident())
}
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
Src: addr.AsSlice(),
Dst: c.source.Load().AsSlice(),
}))
ipHdr := header.IPv6(buffer.ExtendHeader(header.IPv6MinimumSize))
ipHdr.Encode(&header.IPv6Fields{
TrafficClass: uint8(trafficClass),
PayloadLength: uint16(buffer.Len() - header.IPv6MinimumSize),
TransportProtocol: header.ICMPv6ProtocolNumber,
HopLimit: uint8(hopLimit),
SrcAddr: addr,
DstAddr: c.source.Load(),
})
}
} else {
_, err := buffer.ReadOnceFrom(c.conn)
if err != nil {
return err
}
if !c.destination.Is6() {
ipHdr := header.IPv4(buffer.Bytes())
if runtime.GOOS == "darwin" || runtime.GOOS == "ios" {
// MacOS have different TotalLen and FragOff in ipv4 header from socket api:
// https://stackoverflow.com/questions/13829712/mac-changes-ip-total-length-field/15881825#15881825
// but in the tun api still same data format as other system
ipHdr.SetTotalLength(ipHdr.TotalLengthDarwinRaw())
ipHdr.SetFlagsFragmentOffset(ipHdr.FlagsDarwinRaw(), ipHdr.FragmentOffsetDarwinRaw())
}
if !ipHdr.IsValid(buffer.Len()) {
return E.New("invalid IPv4 header received")
}
ipHdr.SetDestinationAddr(c.source.Load())
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
icmpHdr := header.ICMPv4(ipHdr.Payload())
if !c.isLinuxUnprivileged() {
icmpHdr.SetIdent(^icmpHdr.Ident())
}
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
} else {
ipHdr := header.IPv6(buffer.Bytes())
if !ipHdr.IsValid(buffer.Len()) {
return E.New("invalid IPv6 header received")
}
ipHdr.SetDestinationAddr(c.source.Load())
icmpHdr := header.ICMPv6(ipHdr.Payload())
if !c.isLinuxUnprivileged() {
icmpHdr.SetIdent(^icmpHdr.Ident())
}
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
Src: ipHdr.SourceAddressSlice(),
Dst: ipHdr.DestinationAddressSlice(),
}))
}
}
return nil
}
func (c *Conn) ReadICMP(buffer *buf.Buffer) error {
_, err := buffer.ReadOnceFrom(c.conn)
if err != nil {
return err
}
if !c.isLinuxUnprivileged() {
if !c.destination.Is6() {
ipHdr := header.IPv4(buffer.Bytes())
buffer.Advance(int(ipHdr.HeaderLength()))
icmpHdr := header.ICMPv4(buffer.Bytes())
icmpHdr.SetIdent(^icmpHdr.Ident())
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
} else {
icmpHdr := header.ICMPv6(buffer.Bytes())
icmpHdr.SetIdent(^icmpHdr.Ident())
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
Src: c.destination.AsSlice(),
Dst: c.source.Load().AsSlice(),
}))
}
}
return nil
}
func (c *Conn) WriteIP(buffer *buf.Buffer) error {
defer buffer.Release()
if !c.destination.Is6() {
ipHdr := header.IPv4(buffer.Bytes())
if !c.isLinuxUnprivileged() {
icmpHdr := header.ICMPv4(ipHdr.Payload())
icmpHdr.SetIdent(^icmpHdr.Ident())
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
}
c.source.Store(M.AddrFromIP(ipHdr.SourceAddressSlice()))
return common.Error(c.conn.Write(ipHdr.Payload()))
} else {
ipHdr := header.IPv6(buffer.Bytes())
if !c.isLinuxUnprivileged() {
icmpHdr := header.ICMPv6(ipHdr.Payload())
icmpHdr.SetIdent(^icmpHdr.Ident())
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
Src: ipHdr.SourceAddressSlice(),
Dst: ipHdr.DestinationAddressSlice(),
}))
}
c.source.Store(M.AddrFromIP(ipHdr.SourceAddressSlice()))
return common.Error(c.conn.Write(ipHdr.Payload()))
}
}
func (c *Conn) WriteICMP(buffer *buf.Buffer) error {
defer buffer.Release()
if !c.isLinuxUnprivileged() {
if !c.destination.Is6() {
icmpHdr := header.ICMPv4(buffer.Bytes())
icmpHdr.SetIdent(^icmpHdr.Ident())
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
} else {
icmpHdr := header.ICMPv6(buffer.Bytes())
icmpHdr.SetIdent(^icmpHdr.Ident())
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
Src: c.source.Load().AsSlice(),
Dst: c.destination.AsSlice(),
}))
}
}
return common.Error(c.conn.Write(buffer.Bytes()))
}
func (c *Conn) SetLocalAddr(addr netip.Addr) {
c.source.Store(addr)
}
func (c *Conn) SetReadDeadline(t time.Time) error {
return c.conn.SetReadDeadline(t)
}
func (c *Conn) Close() error {
defer c.closed.Store(true)
return c.conn.Close()
}
func (c *Conn) IsClosed() bool {
return c.closed.Load()
}

198
ping/ping_test.go Normal file
View File

@@ -0,0 +1,198 @@
package ping_test
import (
"context"
"net/netip"
"os"
"runtime"
"testing"
"time"
"github.com/sagernet/gvisor/pkg/rand"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
"github.com/sagernet/sing-tun/ping"
"github.com/sagernet/sing/common/buf"
"github.com/stretchr/testify/require"
)
func TestPing(t *testing.T) {
t.Parallel()
const addr4 = "127.0.0.1"
t.Run("ipv4", func(t *testing.T) {
t.Run("unprivileged", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.SkipNow()
}
t.Run("read-icmp", func(t *testing.T) {
testPingIPv4ReadICMP(t, false, addr4)
})
t.Run("read-ip", func(t *testing.T) {
testPingIPv4ReadIP(t, false, addr4)
})
})
t.Run("privileged", func(t *testing.T) {
if runtime.GOOS != "windows" && os.Getuid() != 0 {
t.SkipNow()
}
t.Run("read-icmp", func(t *testing.T) {
testPingIPv4ReadICMP(t, true, addr4)
})
t.Run("read-ip", func(t *testing.T) {
testPingIPv4ReadIP(t, true, addr4)
})
})
})
// const addr6 = "2606:4700:4700::1001"
const addr6 = "::1"
t.Run("ipv6", func(t *testing.T) {
t.Run("unprivileged", func(t *testing.T) {
if runtime.GOOS == "windows" {
t.SkipNow()
}
t.Run("read-icmp", func(t *testing.T) {
testPingIPv6ReadICMP(t, false, addr6)
})
t.Run("read-ip", func(t *testing.T) {
testPingIPv6ReadIP(t, false, addr6)
})
})
t.Run("privileged", func(t *testing.T) {
if runtime.GOOS != "windows" && os.Getuid() != 0 {
t.SkipNow()
}
t.Run("read-icmp", func(t *testing.T) {
testPingIPv6ReadICMP(t, true, addr6)
})
t.Run("read-ip", func(t *testing.T) {
testPingIPv6ReadIP(t, true, addr6)
})
})
})
}
func testPingIPv4ReadIP(t *testing.T, privileged bool, addr string) {
conn, err := ping.Connect(context.Background(), privileged, nil, netip.MustParseAddr(addr))
if runtime.GOOS == "linux" && err != nil && err.Error() == "socket(): permission denied" {
t.SkipNow()
}
require.NoError(t, err)
request := make(header.ICMPv4, header.ICMPv4MinimumSize)
request.SetType(header.ICMPv4Echo)
request.SetIdent(uint16(rand.Uint32()))
request.SetChecksum(header.ICMPv4Checksum(request, 0))
err = conn.WriteICMP(buf.As(request).ToOwned())
require.NoError(t, err)
conn.SetLocalAddr(netip.MustParseAddr("127.0.0.1"))
require.NoError(t, conn.SetReadDeadline(time.Now().Add(3*time.Second)))
response := buf.NewPacket()
err = conn.ReadIP(response)
require.NoError(t, err)
if runtime.GOOS == "linux" && privileged {
response.Reset()
err = conn.ReadIP(response)
require.NoError(t, err)
}
ipHdr := header.IPv4(response.Bytes())
require.NotZero(t, ipHdr.TTL())
icmpHdr := header.ICMPv4(ipHdr.Payload())
require.Equal(t, header.ICMPv4EchoReply, icmpHdr.Type())
require.Equal(t, request.Ident(), icmpHdr.Ident())
}
func testPingIPv4ReadICMP(t *testing.T, privileged bool, addr string) {
conn, err := ping.Connect(context.Background(), privileged, nil, netip.MustParseAddr(addr))
if runtime.GOOS == "linux" && err != nil && err.Error() == "socket(): permission denied" {
t.SkipNow()
}
require.NoError(t, err)
request := make(header.ICMPv4, header.ICMPv4MinimumSize)
request.SetType(header.ICMPv4Echo)
request.SetIdent(uint16(rand.Uint32()))
request.SetChecksum(header.ICMPv4Checksum(request, 0))
err = conn.WriteICMP(buf.As(request).ToOwned())
require.NoError(t, err)
require.NoError(t, conn.SetReadDeadline(time.Now().Add(3*time.Second)))
response := buf.NewPacket()
err = conn.ReadICMP(response)
require.NoError(t, err)
if runtime.GOOS == "linux" && privileged {
response.Reset()
err = conn.ReadICMP(response)
require.NoError(t, err)
}
icmpHdr := header.ICMPv4(response.Bytes())
require.Equal(t, header.ICMPv4EchoReply, icmpHdr.Type())
require.Equal(t, request.Ident(), icmpHdr.Ident())
}
func testPingIPv6ReadIP(t *testing.T, privileged bool, addr string) {
conn, err := ping.Connect(context.Background(), privileged, nil, netip.MustParseAddr(addr))
if runtime.GOOS == "linux" && err != nil && err.Error() == "socket(): permission denied" {
t.SkipNow()
}
require.NoError(t, err)
request := make(header.ICMPv6, header.ICMPv6MinimumSize)
request.SetType(header.ICMPv6EchoRequest)
request.SetIdent(uint16(rand.Uint32()))
err = conn.WriteICMP(buf.As(request).ToOwned())
require.NoError(t, err)
conn.SetLocalAddr(netip.MustParseAddr("::1"))
require.NoError(t, conn.SetReadDeadline(time.Now().Add(3*time.Second)))
response := buf.NewPacket()
err = conn.ReadIP(response)
require.NoError(t, err)
if runtime.GOOS == "darwin" || runtime.GOOS == "linux" && privileged {
response.Reset()
err = conn.ReadIP(response)
require.NoError(t, err)
}
ipHdr := header.IPv6(response.Bytes())
require.NotZero(t, ipHdr.HopLimit())
icmpHdr := header.ICMPv6(ipHdr.Payload())
require.Equal(t, header.ICMPv6EchoReply, icmpHdr.Type())
require.Equal(t, request.Ident(), icmpHdr.Ident())
}
func testPingIPv6ReadICMP(t *testing.T, privileged bool, addr string) {
conn, err := ping.Connect(context.Background(), privileged, nil, netip.MustParseAddr(addr))
if runtime.GOOS == "linux" && err != nil && err.Error() == "socket(): permission denied" {
t.SkipNow()
}
require.NoError(t, err)
request := make(header.ICMPv6, header.ICMPv6MinimumSize)
request.SetType(header.ICMPv6EchoRequest)
request.SetIdent(uint16(rand.Uint32()))
err = conn.WriteICMP(buf.As(request).ToOwned())
require.NoError(t, err)
require.NoError(t, conn.SetReadDeadline(time.Now().Add(3*time.Second)))
response := buf.NewPacket()
err = conn.ReadICMP(response)
require.NoError(t, err)
if runtime.GOOS == "darwin" || runtime.GOOS == "linux" && privileged {
response.Reset()
err = conn.ReadICMP(response)
require.NoError(t, err)
}
icmpHdr := header.ICMPv6(response.Bytes())
require.Equal(t, header.ICMPv6EchoReply, icmpHdr.Type())
require.Equal(t, request.Ident(), icmpHdr.Ident())
}

View File

@@ -0,0 +1,189 @@
package ping
import (
"context"
"net"
"net/netip"
"os"
"sync"
"time"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/control"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/pipe"
)
type UnprivilegedConn struct {
ctx context.Context
cancel context.CancelFunc
controlFunc control.Func
destination netip.Addr
receiveChan chan *unprivilegedResponse
readDeadline pipe.Deadline
mappingAccess sync.Mutex
mapping map[uint16]net.Conn
}
type unprivilegedResponse struct {
Buffer *buf.Buffer
Cmsg *buf.Buffer
Addr netip.Addr
}
func newUnprivilegedConn(ctx context.Context, controlFunc control.Func, destination netip.Addr) (net.Conn, error) {
conn, err := connect(false, controlFunc, destination)
if err != nil {
return nil, err
}
conn.Close()
ctx, cancel := context.WithCancel(ctx)
return &UnprivilegedConn{
ctx: ctx,
cancel: cancel,
controlFunc: controlFunc,
destination: destination,
receiveChan: make(chan *unprivilegedResponse),
readDeadline: pipe.MakeDeadline(),
mapping: make(map[uint16]net.Conn),
}, nil
}
func (c *UnprivilegedConn) Read(b []byte) (n int, err error) {
select {
case packet := <-c.receiveChan:
n = copy(b, packet.Buffer.Bytes())
packet.Buffer.Release()
packet.Cmsg.Release()
return
case <-c.readDeadline.Wait():
return 0, os.ErrDeadlineExceeded
case <-c.ctx.Done():
return 0, os.ErrClosed
}
}
func (c *UnprivilegedConn) ReadMsg(b []byte, oob []byte) (n, oobn int, addr netip.Addr, err error) {
select {
case packet := <-c.receiveChan:
n = copy(b, packet.Buffer.Bytes())
oobn = copy(oob, packet.Cmsg.Bytes())
addr = packet.Addr
packet.Buffer.Release()
packet.Cmsg.Release()
return
case <-c.readDeadline.Wait():
return 0, 0, netip.Addr{}, os.ErrDeadlineExceeded
case <-c.ctx.Done():
return 0, 0, netip.Addr{}, os.ErrClosed
}
}
func (c *UnprivilegedConn) Write(b []byte) (n int, err error) {
var identifier uint16
if !c.destination.Is6() {
icmpHdr := header.ICMPv4(b)
identifier = icmpHdr.Ident()
} else {
icmpHdr := header.ICMPv6(b)
identifier = icmpHdr.Ident()
}
c.mappingAccess.Lock()
if c.ctx.Err() != nil {
return 0, c.ctx.Err()
}
conn, loaded := c.mapping[identifier]
if !loaded {
conn, err = connect(false, c.controlFunc, c.destination)
if err != nil {
c.mappingAccess.Unlock()
return
}
go c.fetchResponse(conn.(*net.UDPConn), identifier)
c.mapping[identifier] = conn
}
c.mappingAccess.Unlock()
n, err = conn.Write(b)
if err != nil {
c.removeConn(conn.(*net.UDPConn), identifier)
}
return
}
func (c *UnprivilegedConn) fetchResponse(conn *net.UDPConn, identifier uint16) {
defer c.removeConn(conn, identifier)
for {
buffer := buf.NewPacket()
cmsgBuffer := buf.NewSize(1024)
n, oobN, _, addr, err := conn.ReadMsgUDPAddrPort(buffer.FreeBytes(), cmsgBuffer.FreeBytes())
if err != nil {
buffer.Release()
cmsgBuffer.Release()
return
}
buffer.Truncate(n)
cmsgBuffer.Truncate(oobN)
if !c.destination.Is6() {
icmpHdr := header.ICMPv4(buffer.Bytes())
icmpHdr.SetIdent(identifier)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
} else {
icmpHdr := header.ICMPv6(buffer.Bytes())
icmpHdr.SetIdent(identifier)
// offload checksum here since we don't have source address here
}
select {
case c.receiveChan <- &unprivilegedResponse{
Buffer: buffer,
Cmsg: cmsgBuffer,
Addr: addr.Addr(),
}:
case <-c.ctx.Done():
buffer.Release()
cmsgBuffer.Release()
return
}
}
}
func (c *UnprivilegedConn) removeConn(conn *net.UDPConn, identifier uint16) {
c.mappingAccess.Lock()
defer c.mappingAccess.Unlock()
_ = conn.Close()
delete(c.mapping, identifier)
}
func (c *UnprivilegedConn) Close() error {
c.mappingAccess.Lock()
defer c.mappingAccess.Unlock()
c.cancel()
for _, conn := range c.mapping {
_ = conn.Close()
}
common.ClearMap(c.mapping)
return nil
}
func (c *UnprivilegedConn) LocalAddr() net.Addr {
return M.Socksaddr{}
}
func (c *UnprivilegedConn) RemoteAddr() net.Addr {
return M.SocksaddrFrom(c.destination, 0).UDPAddr()
}
func (c *UnprivilegedConn) SetDeadline(t time.Time) error {
return os.ErrInvalid
}
func (c *UnprivilegedConn) SetReadDeadline(t time.Time) error {
c.readDeadline.Set(t)
return nil
}
func (c *UnprivilegedConn) SetWriteDeadline(t time.Time) error {
return os.ErrInvalid
}

110
ping/socket_unix.go Normal file
View File

@@ -0,0 +1,110 @@
//go:build unix
package ping
import (
"net"
"net/netip"
"os"
"runtime"
"syscall"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/control"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"golang.org/x/sys/unix"
)
func connect(privileged bool, controlFunc control.Func, destination netip.Addr) (net.Conn, error) {
var (
network string
fd int
err error
)
if destination.Is4() {
network = "ip4" // like std's netFD.ctrlNetwork
if !privileged {
fd, err = unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_ICMP)
} else {
fd, err = unix.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_ICMP)
}
} else {
network = "ip6" // like std's netFD.ctrlNetwork
if !privileged {
fd, err = unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_ICMPV6)
} else {
fd, err = unix.Socket(unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_ICMPV6)
}
}
if err != nil {
return nil, E.Cause(err, "socket()")
}
file := os.NewFile(uintptr(fd), "datagram-oriented icmp")
defer file.Close()
if controlFunc != nil {
var syscallConn syscall.RawConn
syscallConn, err = file.SyscallConn()
if err != nil {
return nil, err
}
err = controlFunc(network, destination.String(), syscallConn)
if err != nil {
return nil, err
}
}
if destination.Is4() && (runtime.GOOS == "linux" || runtime.GOOS == "android") {
//err = unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_RECVTOS, 1)
//if err != nil {
// return nil, err
//}
err = unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_RECVTTL, 1)
if err != nil {
return nil, E.Cause(err, "setsockopt()")
}
}
if destination.Is6() {
err = unix.SetsockoptInt(fd, unix.IPPROTO_IPV6, unix.IPV6_RECVHOPLIMIT, 1)
if err != nil {
return nil, E.Cause(err, "setsockopt()")
}
err = unix.SetsockoptInt(fd, unix.IPPROTO_IPV6, unix.IPV6_RECVTCLASS, 1)
if err != nil {
return nil, E.Cause(err, "setsockopt()")
}
}
var bindAddress netip.Addr
if !destination.Is6() {
bindAddress = netip.AddrFrom4([4]byte{})
} else {
bindAddress = netip.AddrFrom16([16]byte{})
}
err = unix.Bind(fd, M.AddrPortToSockaddr(netip.AddrPortFrom(bindAddress, 0)))
if err != nil {
return nil, err
}
if runtime.GOOS == "darwin" && !privileged {
// When running in NetworkExtension on macOS, write to connected socket results in EPIPE.
var packetConn net.PacketConn
packetConn, err = net.FilePacketConn(file)
if err != nil {
return nil, err
}
return bufio.NewBindPacketConn(packetConn, M.SocksaddrFrom(destination, 0).UDPAddr()), nil
} else {
err = unix.Connect(fd, M.AddrPortToSockaddr(netip.AddrPortFrom(destination, 0)))
if err != nil {
return nil, err
}
var conn net.Conn
conn, err = net.FileConn(file)
if err != nil {
return nil, err
}
return conn, nil
}
}

38
ping/socket_windows.go Normal file
View File

@@ -0,0 +1,38 @@
package ping
import (
"net"
"net/netip"
"syscall"
"github.com/sagernet/sing/common/control"
"golang.org/x/sys/windows"
)
func connect(privileged bool, controlFunc control.Func, destination netip.Addr) (net.Conn, error) {
var dialer net.Dialer
dialer.Control = controlFunc
if destination.Is6() {
dialer.Control = control.Append(dialer.Control, func(network, address string, conn syscall.RawConn) error {
return control.Raw(conn, func(fd uintptr) error {
err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_HOPLIMIT, 1)
if err != nil {
return err
}
err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_RECVTCLASS, 1)
if err != nil {
return err
}
return nil
})
})
}
var network string
if destination.Is4() {
network = "ip4:icmp"
} else {
network = "ip6:ipv6-icmp"
}
return dialer.Dial(network, destination.String())
}

150
ping/source_rewriter.go Normal file
View File

@@ -0,0 +1,150 @@
package ping
import (
"context"
"net/netip"
"sync"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
"github.com/sagernet/sing/common/logger"
)
type SourceRewriter struct {
ctx context.Context
logger logger.ContextLogger
access sync.RWMutex
sessions map[tun.DirectRouteSession]tun.DirectRouteContext
sourceAddress map[uint16]netip.Addr
inet4Address netip.Addr
inet6Address netip.Addr
}
func NewSourceRewriter(ctx context.Context, logger logger.ContextLogger, inet4Address netip.Addr, inet6Address netip.Addr) *SourceRewriter {
return &SourceRewriter{
ctx: ctx,
logger: logger,
sessions: make(map[tun.DirectRouteSession]tun.DirectRouteContext),
sourceAddress: make(map[uint16]netip.Addr),
inet4Address: inet4Address,
inet6Address: inet6Address,
}
}
func (m *SourceRewriter) CreateSession(session tun.DirectRouteSession, context tun.DirectRouteContext) {
m.access.Lock()
m.sessions[session] = context
m.access.Unlock()
}
func (m *SourceRewriter) DeleteSession(session tun.DirectRouteSession) {
m.access.Lock()
delete(m.sessions, session)
m.access.Unlock()
}
func (m *SourceRewriter) RewritePacket(packet []byte) {
var ipHdr header.Network
var bindAddr netip.Addr
switch header.IPVersion(packet) {
case header.IPv4Version:
ipHdr = header.IPv4(packet)
bindAddr = m.inet4Address
case header.IPv6Version:
ipHdr = header.IPv6(packet)
bindAddr = m.inet6Address
default:
return
}
sourceAddr := ipHdr.SourceAddr()
ipHdr.SetSourceAddr(bindAddr)
if ipHdr4, isIPv4 := ipHdr.(header.IPv4); isIPv4 {
ipHdr4.SetChecksum(^ipHdr4.CalculateChecksum())
}
switch ipHdr.TransportProtocol() {
case header.ICMPv4ProtocolNumber:
icmpHdr := header.ICMPv4(ipHdr.Payload())
m.access.Lock()
m.sourceAddress[icmpHdr.Ident()] = sourceAddr
m.access.Unlock()
m.logger.TraceContext(m.ctx, "write ICMPv4 echo request from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
case header.ICMPv6ProtocolNumber:
icmpHdr := header.ICMPv6(ipHdr.Payload())
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
Src: ipHdr.SourceAddressSlice(),
Dst: ipHdr.DestinationAddressSlice(),
}))
m.access.Lock()
m.sourceAddress[icmpHdr.Ident()] = sourceAddr
m.access.Unlock()
m.logger.TraceContext(m.ctx, "write ICMPv6 echo request from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
}
}
func (m *SourceRewriter) WriteBack(packet []byte) (bool, error) {
var ipHdr header.Network
var routeSession tun.DirectRouteSession
switch header.IPVersion(packet) {
case header.IPv4Version:
ipHdr = header.IPv4(packet)
routeSession.Destination = ipHdr.SourceAddr()
case header.IPv6Version:
ipHdr = header.IPv6(packet)
routeSession.Destination = ipHdr.SourceAddr()
default:
return false, nil
}
switch ipHdr.TransportProtocol() {
case header.ICMPv4ProtocolNumber:
icmpHdr := header.ICMPv4(ipHdr.Payload())
m.access.Lock()
ident := icmpHdr.Ident()
source, loaded := m.sourceAddress[ident]
if !loaded {
m.access.Unlock()
return false, nil
}
delete(m.sourceAddress, icmpHdr.Ident())
m.access.Unlock()
routeSession.Source = source
case header.ICMPv6ProtocolNumber:
icmpHdr := header.ICMPv6(ipHdr.Payload())
m.access.Lock()
ident := icmpHdr.Ident()
source, loaded := m.sourceAddress[ident]
if !loaded {
m.access.Unlock()
return false, nil
}
delete(m.sourceAddress, icmpHdr.Ident())
m.access.Unlock()
routeSession.Source = source
default:
return false, nil
}
m.access.RLock()
context, loaded := m.sessions[routeSession]
m.access.RUnlock()
if !loaded {
return false, nil
}
ipHdr.SetDestinationAddr(routeSession.Source)
if ipHdr4, isIPv4 := ipHdr.(header.IPv4); isIPv4 {
ipHdr4.SetChecksum(^ipHdr4.CalculateChecksum())
}
switch ipHdr.TransportProtocol() {
case header.ICMPv4ProtocolNumber:
icmpHdr := header.ICMPv4(ipHdr.Payload())
m.logger.TraceContext(m.ctx, "read ICMPv4 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
case header.ICMPv6ProtocolNumber:
icmpHdr := header.ICMPv6(ipHdr.Payload())
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
Src: ipHdr.SourceAddressSlice(),
Dst: ipHdr.DestinationAddressSlice(),
}))
m.logger.TraceContext(m.ctx, "read ICMPv6 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
}
return true, context.WritePacket(packet)
}

View File

@@ -143,12 +143,26 @@ func (r *autoRedirect) setupNFTables() error {
}
}
chainPreRoutingUDP := nft.AddChain(&nftables.Chain{
Name: "prerouting_udp",
Name: "prerouting_udp_icmp",
Table: table,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityRef(*nftables.ChainPriorityNATDest + 2),
Type: nftables.ChainTypeFilter,
})
ipProto := &nftables.Set{
Table: table,
Anonymous: true,
Constant: true,
KeyType: nftables.TypeInetProto,
}
err = nft.AddSet(ipProto, []nftables.SetElement{
{Key: []byte{unix.IPPROTO_UDP}},
{Key: []byte{unix.IPPROTO_ICMP}},
{Key: []byte{unix.IPPROTO_ICMPV6}},
})
if err != nil {
return err
}
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chainPreRoutingUDP,
@@ -157,10 +171,11 @@ func (r *autoRedirect) setupNFTables() error {
Key: expr.MetaKeyL4PROTO,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{unix.IPPROTO_UDP},
&expr.Lookup{
SourceRegister: 1,
SetID: ipProto.ID,
SetName: ipProto.Name,
Invert: true,
},
&expr.Verdict{
Kind: expr.VerdictReturn,

View File

@@ -7,9 +7,9 @@ import (
"errors"
"net"
"net/netip"
"sync/atomic"
"time"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/control"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"

61
route_direct.go Normal file
View File

@@ -0,0 +1,61 @@
package tun
import (
"net/netip"
"time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/contrab/freelru"
"github.com/sagernet/sing/contrab/maphash"
)
type DirectRouteDestination interface {
WritePacket(packet *buf.Buffer) error
Close() error
IsClosed() bool
}
type DirectRouteSession struct {
// IPVersion uint8
// Network uint8
Source netip.Addr
Destination netip.Addr
}
type DirectRouteMapping struct {
mapping freelru.Cache[DirectRouteSession, DirectRouteDestination]
timeout time.Duration
}
func NewDirectRouteMapping(timeout time.Duration) *DirectRouteMapping {
mapping := common.Must1(freelru.NewSharded[DirectRouteSession, DirectRouteDestination](1024, maphash.NewHasher[DirectRouteSession]().Hash32))
mapping.SetHealthCheck(func(session DirectRouteSession, action DirectRouteDestination) bool {
if action != nil {
return !action.IsClosed()
}
return true
})
mapping.SetOnEvict(func(session DirectRouteSession, action DirectRouteDestination) {
if action != nil {
action.Close()
}
})
mapping.SetLifetime(timeout)
return &DirectRouteMapping{mapping, timeout}
}
func (m *DirectRouteMapping) Lookup(session DirectRouteSession, constructor func(timeout time.Duration) (DirectRouteDestination, error)) (DirectRouteDestination, error) {
var (
created DirectRouteDestination
err error
)
action, _, ok := m.mapping.GetAndRefreshOrAdd(session, func() (DirectRouteDestination, bool) {
created, err = constructor(m.timeout)
return created, err == nil
})
if !ok {
return nil, err
}
return action, nil
}

View File

@@ -12,7 +12,10 @@ import (
"github.com/sagernet/sing/common/logger"
)
var ErrDrop = E.New("drop connections by rule")
var (
ErrDrop = E.New("drop by rule")
ErrReset = E.New("reset by rule")
)
type Stack interface {
Start() error

View File

@@ -15,6 +15,7 @@ import (
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/raw"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
E "github.com/sagernet/sing/common/exceptions"
@@ -28,6 +29,8 @@ const DefaultNIC tcpip.NICID = 1
type GVisor struct {
ctx context.Context
tun GVisorTun
inet4Address netip.Addr
inet6Address netip.Addr
inet4LoopbackAddress []netip.Addr
inet6LoopbackAddress []netip.Addr
udpTimeout time.Duration
@@ -52,9 +55,22 @@ func NewGVisor(
return nil, E.New("gVisor stack is unsupported on current platform")
}
var (
inet4Address netip.Addr
inet6Address netip.Addr
)
if len(options.TunOptions.Inet4Address) > 0 {
inet4Address = options.TunOptions.Inet4Address[0].Addr()
}
if len(options.TunOptions.Inet6Address) > 0 {
inet6Address = options.TunOptions.Inet6Address[0].Addr()
}
gStack := &GVisor{
ctx: options.Context,
tun: gTun,
inet4Address: inet4Address,
inet6Address: inet6Address,
inet4LoopbackAddress: options.TunOptions.Inet4LoopbackAddress,
inet6LoopbackAddress: options.TunOptions.Inet6LoopbackAddress,
udpTimeout: options.UDPTimeout,
@@ -71,12 +87,16 @@ func (t *GVisor) Start() error {
return err
}
linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, t.tun}
ipStack, err := NewGVisorStackWithOptions(linkEndpoint, nicOptions)
ipStack, err := NewGVisorStackWithOptions(linkEndpoint, nicOptions, false)
if err != nil {
return err
}
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, NewTCPForwarderWithLoopback(t.ctx, ipStack, t.handler, t.inet4LoopbackAddress, t.inet6LoopbackAddress, t.tun).HandlePacket)
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket)
icmpForwarder := NewICMPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout)
icmpForwarder.SetLocalAddresses(t.inet4Address, t.inet6Address)
ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber4, icmpForwarder.HandlePacket)
ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber6, icmpForwarder.HandlePacket)
t.stack = ipStack
t.endpoint = linkEndpoint
return nil
@@ -111,11 +131,11 @@ func AddrFromAddress(address tcpip.Address) netip.Addr {
}
func NewGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
return NewGVisorStackWithOptions(ep, stack.NICOptions{})
return NewGVisorStackWithOptions(ep, stack.NICOptions{}, false)
}
func NewGVisorStackWithOptions(ep stack.LinkEndpoint, opts stack.NICOptions) (*stack.Stack, error) {
ipStack := stack.New(stack.Options{
func NewGVisorStackWithOptions(ep stack.LinkEndpoint, opts stack.NICOptions, allowRawEndpoint bool) (*stack.Stack, error) {
stackOptions := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,
ipv6.NewProtocol,
@@ -126,7 +146,11 @@ func NewGVisorStackWithOptions(ep stack.LinkEndpoint, opts stack.NICOptions) (*s
icmp.NewProtocol4,
icmp.NewProtocol6,
},
})
}
if allowRawEndpoint {
stackOptions.RawFactory = new(raw.EndpointFactory)
}
ipStack := stack.New(stackOptions)
err := ipStack.CreateNICWithOptions(DefaultNIC, ep, opts)
if err != nil {
return nil, gonet.TranslateNetstackError(err)

244
stack_gvisor_icmp.go Normal file
View File

@@ -0,0 +1,244 @@
//go:build with_gvisor
package tun
import (
"context"
"errors"
"net/netip"
"sync"
"time"
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/header/parse"
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type ICMPForwarder struct {
ctx context.Context
stack *stack.Stack
inet4Address netip.Addr
inet6Address netip.Addr
handler Handler
mapping *DirectRouteMapping
}
func NewICMPForwarder(
ctx context.Context,
stack *stack.Stack,
handler Handler,
timeout time.Duration,
) *ICMPForwarder {
return &ICMPForwarder{
ctx: ctx,
stack: stack,
handler: handler,
mapping: NewDirectRouteMapping(timeout),
}
}
func (f *ICMPForwarder) SetLocalAddresses(inet4Address, inet6Address netip.Addr) {
f.inet4Address = inet4Address
f.inet6Address = inet6Address
}
func (f *ICMPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
ipHdr := header.IPv4(pkt.NetworkHeader().Slice())
icmpHdr := header.ICMPv4(pkt.TransportHeader().Slice())
if icmpHdr.Type() != header.ICMPv4Echo || icmpHdr.Code() != 0 {
return false
}
sourceAddr := M.AddrFromIP(ipHdr.SourceAddressSlice())
destinationAddr := M.AddrFromIP(ipHdr.DestinationAddressSlice())
if destinationAddr != f.inet4Address {
action, err := f.mapping.Lookup(DirectRouteSession{Source: sourceAddr, Destination: destinationAddr}, func(timeout time.Duration) (DirectRouteDestination, error) {
return f.handler.PrepareConnection(
N.NetworkICMP,
M.SocksaddrFrom(sourceAddr, 0),
M.SocksaddrFrom(destinationAddr, 0),
&ICMPBackWriter{
stack: f.stack,
packet: pkt,
source: ipHdr.SourceAddress(),
sourceNetwork: header.IPv4ProtocolNumber,
},
timeout,
)
})
if errors.Is(err, ErrReset) {
gWriteUnreachable(f.stack, pkt)
return true
} else if errors.Is(err, ErrDrop) {
return true
}
if action != nil {
// TODO: handle error
_ = icmpWritePacketBuffer(action, pkt)
return true
}
}
icmpHdr.SetType(header.ICMPv4EchoReply)
sourceAddress := ipHdr.SourceAddress()
ipHdr.SetSourceAddress(ipHdr.DestinationAddress())
ipHdr.SetDestinationAddress(sourceAddress)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], pkt.Data().Checksum()))
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
outgoingEP, gErr := f.stack.GetNetworkEndpoint(DefaultNIC, header.IPv4ProtocolNumber)
if gErr != nil {
// TODO: log error
return true
}
route, gErr := f.stack.FindRoute(
DefaultNIC,
id.LocalAddress,
id.RemoteAddress,
header.IPv6ProtocolNumber,
false,
)
if gErr != nil {
// TODO: log error
return true
}
defer route.Release()
outgoingEP.(ipv4.ExportedEndpoint).WritePacketDirect(route, pkt)
return true
} else {
ipHdr := header.IPv6(pkt.NetworkHeader().Slice())
icmpHdr := header.ICMPv6(pkt.TransportHeader().Slice())
if icmpHdr.Type() != header.ICMPv6EchoRequest || icmpHdr.Code() != 0 {
return false
}
sourceAddr := M.AddrFromIP(ipHdr.SourceAddressSlice())
destinationAddr := M.AddrFromIP(ipHdr.DestinationAddressSlice())
if destinationAddr != f.inet6Address {
action, err := f.mapping.Lookup(DirectRouteSession{Source: sourceAddr, Destination: destinationAddr}, func(timeout time.Duration) (DirectRouteDestination, error) {
return f.handler.PrepareConnection(
N.NetworkICMP,
M.SocksaddrFrom(sourceAddr, 0),
M.SocksaddrFrom(destinationAddr, 0),
&ICMPBackWriter{
stack: f.stack,
packet: pkt,
source: ipHdr.SourceAddress(),
sourceNetwork: header.IPv6ProtocolNumber,
},
timeout,
)
})
if errors.Is(err, ErrReset) {
gWriteUnreachable(f.stack, pkt)
return true
} else if errors.Is(err, ErrDrop) {
return true
}
if action != nil {
// TODO: handle error
pkt.IncRef()
_ = icmpWritePacketBuffer(action, pkt)
return true
}
}
icmpHdr.SetType(header.ICMPv6EchoReply)
sourceAddress := ipHdr.SourceAddress()
ipHdr.SetSourceAddress(ipHdr.DestinationAddress())
ipHdr.SetDestinationAddress(sourceAddress)
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
Src: ipHdr.SourceAddress(),
Dst: ipHdr.DestinationAddress(),
PayloadCsum: pkt.Data().Checksum(),
PayloadLen: pkt.Data().Size(),
}))
outgoingEP, gErr := f.stack.GetNetworkEndpoint(DefaultNIC, header.IPv4ProtocolNumber)
if gErr != nil {
// TODO: log error
return true
}
route, gErr := f.stack.FindRoute(
DefaultNIC,
id.LocalAddress,
id.RemoteAddress,
header.IPv6ProtocolNumber,
false,
)
if gErr != nil {
// TODO: log error
return true
}
defer route.Release()
outgoingEP.(ipv6.ExportedEndpoint).WritePacketDirect(route, pkt)
return true
}
}
type ICMPBackWriter struct {
access sync.Mutex
stack *stack.Stack
packet *stack.PacketBuffer
source tcpip.Address
sourceNetwork tcpip.NetworkProtocolNumber
}
func (w *ICMPBackWriter) WritePacket(p []byte) error {
if w.sourceNetwork == header.IPv4ProtocolNumber {
route, err := w.stack.FindRoute(
DefaultNIC,
header.IPv4(p).SourceAddress(),
w.source,
w.sourceNetwork,
false,
)
if err != nil {
return gonet.TranslateNetstackError(err)
}
defer route.Release()
packet := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(p),
})
defer packet.DecRef()
parse.IPv4(packet)
err = route.WritePacketDirect(packet)
if err != nil {
return gonet.TranslateNetstackError(err)
}
} else {
route, err := w.stack.FindRoute(
DefaultNIC,
header.IPv6(p).SourceAddress(),
w.source,
w.sourceNetwork,
false,
)
if err != nil {
return gonet.TranslateNetstackError(err)
}
defer route.Release()
packet := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(p),
})
parse.IPv6(packet)
defer packet.DecRef()
err = route.WritePacketDirect(packet)
if err != nil {
return gonet.TranslateNetstackError(err)
}
}
return nil
}
func icmpWritePacketBuffer(action DirectRouteDestination, packetBuffer *stack.PacketBuffer) error {
packetSlice := packetBuffer.NetworkHeader().Slice()
packetSlice = append(packetSlice, packetBuffer.TransportHeader().Slice()...)
packetSlice = append(packetSlice, packetBuffer.Data().AsRange().ToSlice()...)
return action.WritePacket(buf.As(packetSlice).ToOwned())
}

View File

@@ -52,7 +52,7 @@ func (f *TCPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pac
ipHdr.SetSourceAddressWithChecksumUpdate(inet4LoopbackAddress)
tcpHdr := header.TCP(pkt.TransportHeader().Slice())
tcpHdr.SetChecksum(0)
tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum(
tcpHdr.SetChecksum(^checksum.Combine(pkt.Data().Checksum(), tcpHdr.CalculateChecksum(
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddress(), ipHdr.DestinationAddress(), ipHdr.PayloadLength()),
)))
f.tun.WritePacket(pkt)
@@ -66,7 +66,7 @@ func (f *TCPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pac
ipHdr.SetSourceAddress(inet6LoopbackAddress)
tcpHdr := header.TCP(pkt.TransportHeader().Slice())
tcpHdr.SetChecksum(0)
tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum(
tcpHdr.SetChecksum(^checksum.Combine(pkt.Data().Checksum(), tcpHdr.CalculateChecksum(
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddress(), ipHdr.DestinationAddress(), ipHdr.PayloadLength()),
)))
f.tun.WritePacket(pkt)
@@ -79,7 +79,7 @@ func (f *TCPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pac
func (f *TCPForwarder) Forward(r *tcp.ForwarderRequest) {
source := M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort)
destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort)
pErr := f.handler.PrepareConnection(N.NetworkTCP, source, destination)
_, pErr := f.handler.PrepareConnection(N.NetworkTCP, source, destination, nil, 0)
if pErr != nil {
r.Complete(!errors.Is(pErr, ErrDrop))
return

View File

@@ -58,7 +58,7 @@ func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pac
func rangeIterate(r stack.Range, fn func(*buffer.View))
func (f *UDPForwarder) PreparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) {
pErr := f.handler.PrepareConnection(N.NetworkUDP, source, destination)
_, pErr := f.handler.PrepareConnection(N.NetworkUDP, source, destination, nil, 0)
if pErr != nil {
if !errors.Is(pErr, ErrDrop) {
gWriteUnreachable(f.stack, userData.(*stack.PacketBuffer))

View File

@@ -237,7 +237,7 @@ func (m *Mixed) processIPv4(ipHdr header.IPv4) (writeBack bool, err error) {
pkt.DecRef()
return
case header.ICMPv4ProtocolNumber:
err = m.processIPv4ICMP(ipHdr, ipHdr.Payload())
writeBack, err = m.processIPv4ICMP(ipHdr, ipHdr.Payload())
}
return
}
@@ -259,7 +259,7 @@ func (m *Mixed) processIPv6(ipHdr header.IPv6) (writeBack bool, err error) {
m.endpoint.InjectInbound(tcpip.NetworkProtocolNumber(header.IPv6ProtocolNumber), pkt)
pkt.DecRef()
case header.ICMPv6ProtocolNumber:
err = m.processIPv6ICMP(ipHdr, ipHdr.Payload())
writeBack, err = m.processIPv6ICMP(ipHdr, ipHdr.Payload())
}
return
}

View File

@@ -31,10 +31,10 @@ type System struct {
logger logger.Logger
inet4Prefixes []netip.Prefix
inet6Prefixes []netip.Prefix
inet4ServerAddress netip.Addr
inet4Address netip.Addr
inet6ServerAddress netip.Addr
inet4NextAddress netip.Addr
inet6Address netip.Addr
inet6NextAddress netip.Addr
broadcastAddr netip.Addr
inet4LoopbackAddress []netip.Addr
inet6LoopbackAddress []netip.Addr
@@ -45,6 +45,7 @@ type System struct {
tcpPort6 uint16
tcpNat *TCPNat
udpNat *udpnat.Service
directNat *DirectRouteMapping
bindInterface bool
interfaceFinder control.InterfaceFinder
frontHeadroom int
@@ -81,17 +82,17 @@ func NewSystem(options StackOptions) (Stack, error) {
if !HasNextAddress(options.TunOptions.Inet4Address[0], 1) {
return nil, E.New("need one more IPv4 address in first prefix for system stack")
}
stack.inet4ServerAddress = options.TunOptions.Inet4Address[0].Addr()
stack.inet4Address = stack.inet4ServerAddress.Next()
stack.inet4Address = options.TunOptions.Inet4Address[0].Addr()
stack.inet4NextAddress = stack.inet4Address.Next()
}
if len(options.TunOptions.Inet6Address) > 0 {
if !HasNextAddress(options.TunOptions.Inet6Address[0], 1) {
return nil, E.New("need one more IPv6 address in first prefix for system stack")
}
stack.inet6ServerAddress = options.TunOptions.Inet6Address[0].Addr()
stack.inet6Address = stack.inet6ServerAddress.Next()
stack.inet6Address = options.TunOptions.Inet6Address[0].Addr()
stack.inet6NextAddress = stack.inet6Address.Next()
}
if !stack.inet4Address.IsValid() && !stack.inet6Address.IsValid() {
if !stack.inet4NextAddress.IsValid() && !stack.inet6NextAddress.IsValid() {
return nil, E.New("missing interface address")
}
return stack, nil
@@ -127,9 +128,9 @@ func (s *System) start() error {
}
var tcpListener net.Listener
var err error
if s.inet4Address.IsValid() {
if s.inet4NextAddress.IsValid() {
for i := 0; i < 3; i++ {
tcpListener, err = listener.Listen(s.ctx, "tcp4", net.JoinHostPort(s.inet4ServerAddress.String(), "0"))
tcpListener, err = listener.Listen(s.ctx, "tcp4", net.JoinHostPort(s.inet4Address.String(), "0"))
if !retryableListenError(err) {
break
}
@@ -142,9 +143,9 @@ func (s *System) start() error {
s.tcpPort = M.SocksaddrFromNet(tcpListener.Addr()).Port
go s.acceptLoop(tcpListener)
}
if s.inet6Address.IsValid() {
if s.inet6NextAddress.IsValid() {
for i := 0; i < 3; i++ {
tcpListener, err = listener.Listen(s.ctx, "tcp6", net.JoinHostPort(s.inet6ServerAddress.String(), "0"))
tcpListener, err = listener.Listen(s.ctx, "tcp6", net.JoinHostPort(s.inet6Address.String(), "0"))
if !retryableListenError(err) {
break
}
@@ -159,6 +160,7 @@ func (s *System) start() error {
}
s.tcpNat = NewNat(s.ctx, s.udpTimeout)
s.udpNat = udpnat.New(s.handler, s.preparePacketConnection, s.udpTimeout, false)
s.directNat = NewDirectRouteMapping(s.udpTimeout)
return nil
}
@@ -361,7 +363,10 @@ func (s *System) processIPv4(ipHdr header.IPv4) (writeBack bool, err error) {
writeBack = false
err = s.processIPv4UDP(ipHdr, ipHdr.Payload())
case header.ICMPv4ProtocolNumber:
err = s.processIPv4ICMP(ipHdr, ipHdr.Payload())
writeBack, err = s.processIPv4ICMP(ipHdr, ipHdr.Payload())
}
if err != nil {
writeBack = false
}
return
}
@@ -377,7 +382,10 @@ func (s *System) processIPv6(ipHdr header.IPv6) (writeBack bool, err error) {
case header.UDPProtocolNumber:
err = s.processIPv6UDP(ipHdr, ipHdr.Payload())
case header.ICMPv6ProtocolNumber:
err = s.processIPv6ICMP(ipHdr, ipHdr.Payload())
writeBack, err = s.processIPv6ICMP(ipHdr, ipHdr.Payload())
}
if err != nil {
writeBack = false
}
return
}
@@ -387,7 +395,7 @@ func (s *System) processIPv4TCP(ipHdr header.IPv4, tcpHdr header.TCP) (bool, err
destination := netip.AddrPortFrom(ipHdr.DestinationAddr(), tcpHdr.DestinationPort())
if !destination.Addr().IsGlobalUnicast() {
return false, nil
} else if source.Addr() == s.inet4ServerAddress && source.Port() == s.tcpPort {
} else if source.Addr() == s.inet4Address && source.Port() == s.tcpPort {
session := s.tcpNat.LookupBack(destination.Port())
if session == nil {
return false, E.New("ipv4: tcp: session not found: ", destination.Port())
@@ -415,21 +423,19 @@ func (s *System) processIPv4TCP(ipHdr header.IPv4, tcpHdr header.TCP) (bool, err
return false, s.resetIPv4TCP(ipHdr, tcpHdr)
}
}
ipHdr.SetSourceAddr(s.inet4Address)
ipHdr.SetSourceAddr(s.inet4NextAddress)
tcpHdr.SetSourcePort(natPort)
ipHdr.SetDestinationAddr(s.inet4ServerAddress)
ipHdr.SetDestinationAddr(s.inet4Address)
tcpHdr.SetDestinationPort(s.tcpPort)
}
}
if !s.txChecksumOffload {
tcpHdr.SetChecksum(0)
tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum(
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
)))
} else {
tcpHdr.SetChecksum(0)
}
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
return true, nil
}
@@ -470,7 +476,6 @@ func (s *System) resetIPv4TCP(origIPHdr header.IPv4, origTCPHdr header.TCP) erro
if !s.txChecksumOffload {
tcpHdr.SetChecksum(^tcpHdr.CalculateChecksum(header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), header.TCPMinimumSize)))
}
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
if PacketOffset > 0 {
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version)
@@ -485,7 +490,7 @@ func (s *System) processIPv6TCP(ipHdr header.IPv6, tcpHdr header.TCP) (bool, err
destination := netip.AddrPortFrom(ipHdr.DestinationAddr(), tcpHdr.DestinationPort())
if !destination.Addr().IsGlobalUnicast() {
return false, nil
} else if source.Addr() == s.inet6ServerAddress && source.Port() == s.tcpPort6 {
} else if source.Addr() == s.inet6Address && source.Port() == s.tcpPort6 {
session := s.tcpNat.LookupBack(destination.Port())
if session == nil {
return false, E.New("ipv6: tcp: session not found: ", destination.Port())
@@ -513,14 +518,13 @@ func (s *System) processIPv6TCP(ipHdr header.IPv6, tcpHdr header.TCP) (bool, err
return false, s.resetIPv6TCP(ipHdr, tcpHdr)
}
}
ipHdr.SetSourceAddr(s.inet6Address)
ipHdr.SetSourceAddr(s.inet6NextAddress)
tcpHdr.SetSourcePort(natPort)
ipHdr.SetDestinationAddr(s.inet6ServerAddress)
ipHdr.SetDestinationAddr(s.inet6Address)
tcpHdr.SetDestinationPort(s.tcpPort6)
}
}
if !s.txChecksumOffload {
tcpHdr.SetChecksum(0)
tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum(
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
)))
@@ -601,7 +605,7 @@ func (s *System) processIPv6UDP(ipHdr header.IPv6, udpHdr header.UDP) error {
}
func (s *System) preparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) {
pErr := s.handler.PrepareConnection(N.NetworkUDP, source, destination)
_, pErr := s.handler.PrepareConnection(N.NetworkUDP, source, destination, nil, 0)
if pErr != nil {
if !errors.Is(pErr, ErrDrop) {
if source.IsIPv4() {
@@ -643,18 +647,40 @@ func (s *System) preparePacketConnection(source M.Socksaddr, destination M.Socks
return true, s.ctx, writer, nil
}
func (s *System) processIPv4ICMP(ipHdr header.IPv4, icmpHdr header.ICMPv4) error {
func (s *System) processIPv4ICMP(ipHdr header.IPv4, icmpHdr header.ICMPv4) (bool, error) {
if icmpHdr.Type() != header.ICMPv4Echo || icmpHdr.Code() != 0 {
return nil
return false, nil
}
sourceAddr := ipHdr.SourceAddr()
destinationAddr := ipHdr.DestinationAddr()
if destinationAddr != s.inet4Address {
action, err := s.directNat.Lookup(DirectRouteSession{Source: sourceAddr, Destination: destinationAddr}, func(timeout time.Duration) (DirectRouteDestination, error) {
return s.handler.PrepareConnection(
N.NetworkICMP,
M.SocksaddrFrom(sourceAddr, 0),
M.SocksaddrFrom(destinationAddr, 0),
&systemICMPDirectPacketWriter4{s.tun, s.frontHeadroom + PacketOffset, sourceAddr},
timeout,
)
})
if err != nil {
if errors.Is(err, ErrReset) {
return false, s.rejectIPv4WithICMP(ipHdr, header.ICMPv4HostUnreachable)
} else if errors.Is(err, ErrDrop) {
return false, nil
}
}
if action != nil {
return false, action.WritePacket(buf.As(ipHdr).ToOwned())
}
}
icmpHdr.SetType(header.ICMPv4EchoReply)
sourceAddress := ipHdr.SourceAddr()
ipHdr.SetSourceAddr(ipHdr.DestinationAddr())
ipHdr.SetDestinationAddr(sourceAddress)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
ipHdr.SetChecksum(0)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
return nil
return true, nil
}
func (s *System) rejectIPv4WithICMP(ipHdr header.IPv4, code header.ICMPv4Code) error {
@@ -686,7 +712,7 @@ func (s *System) rejectIPv4WithICMP(ipHdr header.IPv4, code header.ICMPv4Code) e
icmpHdr := header.ICMPv4(newIPHdr.Payload())
icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetCode(code)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(payload, 0)))
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(ipHdr.Payload(), 0)))
copy(icmpHdr.Payload(), payload)
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET
@@ -696,9 +722,30 @@ func (s *System) rejectIPv4WithICMP(ipHdr header.IPv4, code header.ICMPv4Code) e
return common.Error(s.tun.Write(newPacket.Bytes()))
}
func (s *System) processIPv6ICMP(ipHdr header.IPv6, icmpHdr header.ICMPv6) error {
func (s *System) processIPv6ICMP(ipHdr header.IPv6, icmpHdr header.ICMPv6) (bool, error) {
if icmpHdr.Type() != header.ICMPv6EchoRequest || icmpHdr.Code() != 0 {
return nil
return false, nil
}
sourceAddr := ipHdr.SourceAddr()
destinationAddr := ipHdr.DestinationAddr()
if destinationAddr != s.inet6Address {
action, err := s.directNat.Lookup(DirectRouteSession{Source: sourceAddr, Destination: destinationAddr}, func(timeout time.Duration) (DirectRouteDestination, error) {
return s.handler.PrepareConnection(
N.NetworkICMP,
M.SocksaddrFrom(sourceAddr, 0),
M.SocksaddrFrom(destinationAddr, 0),
&systemICMPDirectPacketWriter6{s.tun, s.frontHeadroom + PacketOffset, sourceAddr},
timeout,
)
})
if errors.Is(err, ErrReset) {
return false, s.rejectIPv6WithICMP(ipHdr, header.ICMPv6AddressUnreachable)
} else if errors.Is(err, ErrDrop) {
return false, nil
}
if action != nil {
return false, action.WritePacket(buf.As(ipHdr).ToOwned())
}
}
icmpHdr.SetType(header.ICMPv6EchoReply)
sourceAddress := ipHdr.SourceAddr()
@@ -706,10 +753,10 @@ func (s *System) processIPv6ICMP(ipHdr header.IPv6, icmpHdr header.ICMPv6) error
ipHdr.SetDestinationAddr(sourceAddress)
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
Src: ipHdr.SourceAddress(),
Dst: ipHdr.DestinationAddress(),
Src: ipHdr.SourceAddressSlice(),
Dst: ipHdr.DestinationAddressSlice(),
}))
return nil
return true, nil
}
func (s *System) rejectIPv6WithICMP(ipHdr header.IPv6, code header.ICMPv6Code) error {
@@ -742,8 +789,8 @@ func (s *System) rejectIPv6WithICMP(ipHdr header.IPv6, code header.ICMPv6Code) e
icmpHdr.SetCode(code)
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr[:header.ICMPv6DstUnreachableMinimumSize],
Src: newIPHdr.SourceAddress(),
Dst: newIPHdr.DestinationAddress(),
Src: newIPHdr.SourceAddressSlice(),
Dst: newIPHdr.DestinationAddressSlice(),
PayloadCsum: checksum.Checksum(payload, 0),
PayloadLen: len(payload),
}))
@@ -779,14 +826,12 @@ func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.S
udpHdr.SetSourcePort(destination.Port)
udpHdr.SetLength(uint16(buffer.Len() + header.UDPMinimumSize))
if !w.txChecksumOffload {
udpHdr.SetChecksum(0)
udpHdr.SetChecksum(^checksum.Checksum(udpHdr.Payload(), udpHdr.CalculateChecksum(
header.PseudoHeaderChecksum(header.UDPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
)))
} else {
udpHdr.SetChecksum(0)
}
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
if PacketOffset > 0 {
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version)
@@ -820,7 +865,6 @@ func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.S
udpHdr.SetSourcePort(destination.Port)
udpHdr.SetLength(udpLen)
if !w.txChecksumOffload {
udpHdr.SetChecksum(0)
udpHdr.SetChecksum(^checksum.Checksum(udpHdr.Payload(), udpHdr.CalculateChecksum(
header.PseudoHeaderChecksum(header.UDPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
)))
@@ -834,3 +878,46 @@ func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.S
}
return common.Error(w.tun.Write(newPacket.Bytes()))
}
type systemICMPDirectPacketWriter4 struct {
tun Tun
frontHeadroom int
source netip.Addr
}
func (w *systemICMPDirectPacketWriter4) WritePacket(p []byte) error {
newPacket := buf.NewSize(w.frontHeadroom + len(p))
defer newPacket.Release()
newPacket.Resize(w.frontHeadroom, 0)
newPacket.Write(p)
ipHdr := header.IPv4(newPacket.Bytes())
ipHdr.SetDestinationAddr(w.source)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
if PacketOffset > 0 {
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version)
} else {
newPacket.Advance(-w.frontHeadroom)
}
return common.Error(w.tun.Write(newPacket.Bytes()))
}
type systemICMPDirectPacketWriter6 struct {
tun Tun
frontHeadroom int
source netip.Addr
}
func (w *systemICMPDirectPacketWriter6) WritePacket(p []byte) error {
newPacket := buf.NewSize(w.frontHeadroom + len(p))
defer newPacket.Release()
newPacket.Resize(w.frontHeadroom, 0)
newPacket.Write(p)
ipHdr := header.IPv6(newPacket.Bytes())
ipHdr.SetDestinationAddr(w.source)
if PacketOffset > 0 {
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv6Version)
} else {
newPacket.Advance(-w.frontHeadroom)
}
return common.Error(w.tun.Write(newPacket.Bytes()))
}

View File

@@ -11,6 +11,7 @@ import (
)
type TCPNat struct {
timeout time.Duration
portIndex uint16
portAccess sync.RWMutex
addrAccess sync.RWMutex
@@ -19,6 +20,7 @@ type TCPNat struct {
}
type TCPSession struct {
sync.Mutex
Source netip.AddrPort
Destination netip.AddrPort
LastActive time.Time
@@ -26,38 +28,41 @@ type TCPSession struct {
func NewNat(ctx context.Context, timeout time.Duration) *TCPNat {
natMap := &TCPNat{
timeout: timeout,
portIndex: 10000,
addrMap: make(map[netip.AddrPort]uint16),
portMap: make(map[uint16]*TCPSession),
}
go natMap.loopCheckTimeout(ctx, timeout)
go natMap.loopCheckTimeout(ctx)
return natMap
}
func (n *TCPNat) loopCheckTimeout(ctx context.Context, timeout time.Duration) {
ticker := time.NewTicker(timeout)
func (n *TCPNat) loopCheckTimeout(ctx context.Context) {
ticker := time.NewTicker(n.timeout)
defer ticker.Stop()
for {
select {
case <-ticker.C:
n.checkTimeout(timeout)
n.checkTimeout()
case <-ctx.Done():
return
}
}
}
func (n *TCPNat) checkTimeout(timeout time.Duration) {
func (n *TCPNat) checkTimeout() {
now := time.Now()
n.portAccess.Lock()
defer n.portAccess.Unlock()
n.addrAccess.Lock()
defer n.addrAccess.Unlock()
for natPort, session := range n.portMap {
if now.Sub(session.LastActive) > timeout {
session.Lock()
if now.Sub(session.LastActive) > n.timeout {
delete(n.addrMap, session.Source)
delete(n.portMap, natPort)
}
session.Unlock()
}
}
@@ -66,7 +71,11 @@ func (n *TCPNat) LookupBack(port uint16) *TCPSession {
session := n.portMap[port]
n.portAccess.RUnlock()
if session != nil {
session.LastActive = time.Now()
session.Lock()
if time.Since(session.LastActive) > time.Second {
session.LastActive = time.Now()
}
session.Unlock()
}
return session
}
@@ -78,7 +87,7 @@ func (n *TCPNat) Lookup(source netip.AddrPort, destination netip.AddrPort, handl
if loaded {
return port, nil
}
pErr := handler.PrepareConnection(N.NetworkTCP, M.SocksaddrFromNetIP(source), M.SocksaddrFromNetIP(destination))
_, pErr := handler.PrepareConnection(N.NetworkTCP, M.SocksaddrFromNetIP(source), M.SocksaddrFromNetIP(destination), nil, 0)
if pErr != nil {
return 0, pErr
}

View File

@@ -5,6 +5,7 @@ import (
"syscall"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
"github.com/sagernet/sing/common"
)
func PacketIPVersion(packet []byte) int {
@@ -13,6 +14,7 @@ func PacketIPVersion(packet []byte) int {
func PacketFillHeader(packet []byte, ipVersion int) {
if PacketOffset > 0 {
common.ClearArray(packet[:3])
switch ipVersion {
case header.IPv4Version:
packet[3] = syscall.AF_INET

13
tun.go
View File

@@ -7,6 +7,7 @@ import (
"runtime"
"strconv"
"strings"
"time"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/control"
@@ -18,11 +19,21 @@ import (
)
type Handler interface {
PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr) error
PrepareConnection(
network string,
source M.Socksaddr,
destination M.Socksaddr,
routeContext DirectRouteContext,
timeout time.Duration,
) (DirectRouteDestination, error)
N.TCPConnectionHandlerEx
N.UDPConnectionHandlerEx
}
type DirectRouteContext interface {
WritePacket(packet []byte) error
}
type Tun interface {
io.ReadWriter
Name() (string, error)

View File

@@ -816,14 +816,6 @@ func (t *NativeTun) rules() []*netlink.Rule {
it.Family = unix.AF_INET
rules = append(rules, it)
}
if p4 && !t.options.StrictRoute {
it = netlink.NewRule()
it.Priority = priority
it.IPProto = syscall.IPPROTO_ICMP
it.Goto = nopPriority
it.Family = unix.AF_INET
rules = append(rules, it)
}
if p6 {
it = netlink.NewRule()
it.Priority = priority6
@@ -834,16 +826,6 @@ func (t *NativeTun) rules() []*netlink.Rule {
it.Family = unix.AF_INET6
rules = append(rules, it)
}
if p6 && !t.options.StrictRoute {
it = netlink.NewRule()
it.Priority = priority6
it.IPProto = syscall.IPPROTO_ICMPV6
it.Goto = nopPriority
it.Family = unix.AF_INET6
rules = append(rules, it)
priority6++
}
}
if p4 {
it = netlink.NewRule()

View File

@@ -9,6 +9,7 @@ import (
"net/netip"
"os"
"sync"
"sync/atomic"
"time"
"unsafe"
@@ -16,7 +17,6 @@ import (
"github.com/sagernet/sing-tun/internal/winsys"
"github.com/sagernet/sing-tun/internal/wintun"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/atomic"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/windnsapi"