Compare commits
50 Commits
v0.7.0
...
v0.8.0-bet
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
67013b321e | ||
|
|
0381a06643 | ||
|
|
960457abba | ||
|
|
7de8ff7f20 | ||
|
|
a8cb01e6df | ||
|
|
b76e852f59 | ||
|
|
d5865f2135 | ||
|
|
ff49ece55d | ||
|
|
adc106bcf6 | ||
|
|
f9bbb15bfb | ||
|
|
4c43f4af12 | ||
|
|
055fe13ec0 | ||
|
|
e8d7fc1bb2 | ||
|
|
a24ab73aca | ||
|
|
b5f3fecc25 | ||
|
|
4fb5702443 | ||
|
|
6e4e045620 | ||
|
|
79e2d3b56d | ||
|
|
e6c64e3f18 | ||
|
|
59e42c0d1f | ||
|
|
ce55929883 | ||
|
|
144683d882 | ||
|
|
d0ff7b6f6c | ||
|
|
854e40dc40 | ||
|
|
a0b34a4be9 | ||
|
|
548f51cc9d | ||
|
|
ce050baa58 | ||
|
|
06ddb3e0a7 | ||
|
|
c089ffbd6c | ||
|
|
fe4e54bb0d | ||
|
|
58f331b49e | ||
|
|
0d3df84673 | ||
|
|
dbd8e28fc8 | ||
|
|
ff4941daa4 | ||
|
|
9532c7f1f6 | ||
|
|
ccfe5c0f0f | ||
|
|
737ebf01c4 | ||
|
|
8f6cc9f62e | ||
|
|
3faf8cf679 | ||
|
|
bee7be8598 | ||
|
|
d53158b8d7 | ||
|
|
7f41766568 | ||
|
|
dd18aa2b86 | ||
|
|
86d96064d5 | ||
|
|
12c9fb6a5d | ||
|
|
a256dca36b | ||
|
|
f46791bc0d | ||
|
|
8dbb51cfb7 | ||
|
|
036d61a0aa | ||
|
|
933bd2b2d5 |
42
.github/workflows/test.yml
vendored
42
.github/workflows/test.yml
vendored
@@ -26,12 +26,14 @@ jobs:
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ^1.23
|
||||
go-version: ^1.25.0
|
||||
- name: Build
|
||||
run: |
|
||||
make test
|
||||
build_go120:
|
||||
name: Linux (Go 1.20)
|
||||
go test -c -o ping_test ./ping
|
||||
sudo ./ping_test -test.v
|
||||
build_go124:
|
||||
name: Linux (Go 1.24)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -41,13 +43,15 @@ jobs:
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ~1.20
|
||||
go-version: ~1.24
|
||||
continue-on-error: true
|
||||
- name: Build
|
||||
run: |
|
||||
make test
|
||||
build_go121:
|
||||
name: Linux (Go 1.21)
|
||||
go test -c -o ping_test ./ping
|
||||
sudo ./ping_test -test.v
|
||||
build_go123:
|
||||
name: Linux (Go 1.23)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -57,27 +61,13 @@ jobs:
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ~1.21
|
||||
continue-on-error: true
|
||||
- name: Build
|
||||
run: |
|
||||
make test
|
||||
build_go122:
|
||||
name: Linux (Go 1.22)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ~1.22
|
||||
go-version: ~1.23
|
||||
continue-on-error: true
|
||||
- name: Build
|
||||
run: |
|
||||
make test
|
||||
go test -c -o ping_test ./ping
|
||||
sudo ./ping_test -test.v
|
||||
build_windows:
|
||||
name: Windows
|
||||
runs-on: windows-latest
|
||||
@@ -94,6 +84,7 @@ jobs:
|
||||
- name: Build
|
||||
run: |
|
||||
make test
|
||||
go test -v ./ping
|
||||
build_darwin:
|
||||
name: macOS
|
||||
runs-on: macos-latest
|
||||
@@ -109,4 +100,7 @@ jobs:
|
||||
continue-on-error: true
|
||||
- name: Build
|
||||
run: |
|
||||
make test
|
||||
make test
|
||||
go test -v ./ping
|
||||
go test -c -o ping_test ./ping
|
||||
sudo ./ping_test -test.v
|
||||
4
Makefile
4
Makefile
@@ -29,5 +29,5 @@ lint_install:
|
||||
|
||||
test:
|
||||
go build -v .
|
||||
go test -bench=. ./internal/checksum_test
|
||||
#go test -v .
|
||||
#go test -bench=. ./internal/checksum_test
|
||||
go test -v .
|
||||
|
||||
14
go.mod
14
go.mod
@@ -1,28 +1,32 @@
|
||||
module github.com/sagernet/sing-tun
|
||||
|
||||
go 1.20
|
||||
go 1.23.1
|
||||
|
||||
require (
|
||||
github.com/go-ole/go-ole v1.3.0
|
||||
github.com/google/btree v1.1.3
|
||||
github.com/sagernet/fswatch v0.1.1
|
||||
github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff
|
||||
github.com/sagernet/gvisor v0.0.0-20250909151924-850a370d8506
|
||||
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a
|
||||
github.com/sagernet/nftables v0.3.0-beta.4
|
||||
github.com/sagernet/sing v0.7.0-beta.1
|
||||
github.com/sagernet/sing v0.8.0-beta.1
|
||||
github.com/stretchr/testify v1.11.1
|
||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
|
||||
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8
|
||||
golang.org/x/net v0.26.0
|
||||
golang.org/x/sys v0.26.0
|
||||
golang.org/x/net v0.43.0
|
||||
golang.org/x/sys v0.35.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/fsnotify/fsnotify v1.7.0 // indirect
|
||||
github.com/google/go-cmp v0.6.0 // indirect
|
||||
github.com/josharian/native v1.1.0 // indirect
|
||||
github.com/mdlayher/netlink v1.7.2 // indirect
|
||||
github.com/mdlayher/socket v0.4.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/vishvananda/netns v0.0.4 // indirect
|
||||
golang.org/x/sync v0.7.0 // indirect
|
||||
golang.org/x/time v0.7.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
24
go.sum
24
go.sum
@@ -1,4 +1,5 @@
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
|
||||
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
|
||||
github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
|
||||
@@ -14,30 +15,37 @@ github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU
|
||||
github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U=
|
||||
github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/sagernet/fswatch v0.1.1 h1:YqID+93B7VRfqIH3PArW/XpJv5H4OLEVWDfProGoRQs=
|
||||
github.com/sagernet/fswatch v0.1.1/go.mod h1:nz85laH0mkQqJfaOrqPpkwtU1znMFNVTpT/5oRsVz/o=
|
||||
github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff h1:mlohw3360Wg1BNGook/UHnISXhUx4Gd/3tVLs5T0nSs=
|
||||
github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff/go.mod h1:ehZwnT2UpmOWAHFL48XdBhnd4Qu4hN2O3Ji0us3ZHMw=
|
||||
github.com/sagernet/gvisor v0.0.0-20250909151924-850a370d8506 h1:x/t3XqWshOlWqRuumpvbUvjtEr/6mJuBXAVovPefbUg=
|
||||
github.com/sagernet/gvisor v0.0.0-20250909151924-850a370d8506/go.mod h1:QkkPEJLw59/tfxgapHta14UL5qMUah5NXhO0Kw2Kan4=
|
||||
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZNjr6sGeT00J8uU7JF4cNUdb44/Duis=
|
||||
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM=
|
||||
github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I=
|
||||
github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8=
|
||||
github.com/sagernet/sing v0.7.0-beta.1 h1:2D44KzgeDZwD/R4Ts8jwSUHTRR238a1FpXDrl7l4tVw=
|
||||
github.com/sagernet/sing v0.7.0-beta.1/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
|
||||
github.com/sagernet/sing v0.8.0-beta.1 h1:tBOdh/K/EBdXWuBxUJsZONyxDzyfzjdCF1Yq57QtpE4=
|
||||
github.com/sagernet/sing v0.8.0-beta.1/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
||||
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M=
|
||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y=
|
||||
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 h1:yixxcjnhBmY0nkL253HFVIm0JsFHwrHdT3Yh6szTnfY=
|
||||
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8/go.mod h1:jj3sYF3dwk5D+ghuXyeI3r5MFf+NT2An6/9dOA95KSI=
|
||||
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
|
||||
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
|
||||
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
|
||||
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
|
||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
|
||||
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
|
||||
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
|
||||
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
@@ -276,8 +276,8 @@ func (b ICMPv6) Payload() []byte {
|
||||
// ICMPv6ChecksumParams contains parameters to calculate ICMPv6 checksum.
|
||||
type ICMPv6ChecksumParams struct {
|
||||
Header ICMPv6
|
||||
Src tcpip.Address
|
||||
Dst tcpip.Address
|
||||
Src []byte
|
||||
Dst []byte
|
||||
PayloadCsum uint16
|
||||
PayloadLen int
|
||||
}
|
||||
@@ -287,7 +287,7 @@ type ICMPv6ChecksumParams struct {
|
||||
func ICMPv6Checksum(params ICMPv6ChecksumParams) uint16 {
|
||||
h := params.Header
|
||||
|
||||
xsum := PseudoHeaderChecksum(ICMPv6ProtocolNumber, params.Src.AsSlice(), params.Dst.AsSlice(), uint16(len(h)+params.PayloadLen))
|
||||
xsum := PseudoHeaderChecksum(ICMPv6ProtocolNumber, params.Src, params.Dst, uint16(len(h)+params.PayloadLen))
|
||||
xsum = checksum.Combine(xsum, params.PayloadCsum)
|
||||
|
||||
// h[2:4] is the checksum itself, skip it to avoid checksumming the checksum.
|
||||
|
||||
@@ -86,18 +86,26 @@ type Network interface {
|
||||
// SourceAddress returns the value of the "source address" field.
|
||||
SourceAddress() tcpip.Address
|
||||
|
||||
SourceAddr() netip.Addr
|
||||
|
||||
SourceAddressSlice() []byte
|
||||
|
||||
// DestinationAddress returns the value of the "destination address"
|
||||
// field.
|
||||
DestinationAddress() tcpip.Address
|
||||
|
||||
DestinationAddr() netip.Addr
|
||||
|
||||
DestinationAddressSlice() []byte
|
||||
|
||||
// Checksum returns the value of the "checksum" field.
|
||||
Checksum() uint16
|
||||
|
||||
// SetSourceAddress sets the value of the "source address" field.
|
||||
SetSourceAddress(tcpip.Address)
|
||||
|
||||
SetSourceAddr(netip.Addr)
|
||||
|
||||
// SetDestinationAddress sets the value of the "destination address"
|
||||
// field.
|
||||
SetDestinationAddress(tcpip.Address)
|
||||
|
||||
@@ -315,6 +315,10 @@ func (b IPv4) Flags() uint8 {
|
||||
return uint8(binary.BigEndian.Uint16(b[flagsFO:]) >> 13)
|
||||
}
|
||||
|
||||
func (b IPv4) FlagsDarwinRaw() uint8 {
|
||||
return uint8(binary.BigEndian.Uint16(b[flagsFO:]) >> 13)
|
||||
}
|
||||
|
||||
// More returns whether the more fragments flag is set.
|
||||
func (b IPv4) More() bool {
|
||||
return b.Flags()&IPv4FlagMoreFragments != 0
|
||||
@@ -330,11 +334,19 @@ func (b IPv4) FragmentOffset() uint16 {
|
||||
return binary.BigEndian.Uint16(b[flagsFO:]) << 3
|
||||
}
|
||||
|
||||
func (b IPv4) FragmentOffsetDarwinRaw() uint16 {
|
||||
return common.NativeEndian.Uint16(b[flagsFO:]) << 3
|
||||
}
|
||||
|
||||
// TotalLength returns the "total length" field of the IPv4 header.
|
||||
func (b IPv4) TotalLength() uint16 {
|
||||
return binary.BigEndian.Uint16(b[IPv4TotalLenOffset:])
|
||||
}
|
||||
|
||||
func (b IPv4) TotalLengthDarwinRaw() uint16 {
|
||||
return common.NativeEndian.Uint16(b[IPv4TotalLenOffset:]) + uint16(b.HeaderLength())
|
||||
}
|
||||
|
||||
// Checksum returns the checksum field of the IPv4 header.
|
||||
func (b IPv4) Checksum() uint16 {
|
||||
return binary.BigEndian.Uint16(b[xsum:])
|
||||
@@ -428,6 +440,10 @@ func (b IPv4) SetTotalLength(totalLength uint16) {
|
||||
binary.BigEndian.PutUint16(b[IPv4TotalLenOffset:], totalLength)
|
||||
}
|
||||
|
||||
func (b IPv4) SetTotalLengthDarwinRaw(totalLength uint16) {
|
||||
common.NativeEndian.PutUint16(b[IPv4TotalLenOffset:], totalLength)
|
||||
}
|
||||
|
||||
// SetChecksum sets the checksum field of the IPv4 header.
|
||||
func (b IPv4) SetChecksum(v uint16) {
|
||||
checksum.Put(b[xsum:], v)
|
||||
@@ -440,6 +456,11 @@ func (b IPv4) SetFlagsFragmentOffset(flags uint8, offset uint16) {
|
||||
binary.BigEndian.PutUint16(b[flagsFO:], v)
|
||||
}
|
||||
|
||||
func (b IPv4) SetFlagsFragmentOffsetDarwinRaw(flags uint8, offset uint16) {
|
||||
v := (uint16(flags) << 13) | (offset >> 3)
|
||||
common.NativeEndian.PutUint16(b[flagsFO:], v)
|
||||
}
|
||||
|
||||
// SetID sets the identification field.
|
||||
func (b IPv4) SetID(v uint16) {
|
||||
binary.BigEndian.PutUint16(b[id:], v)
|
||||
@@ -458,7 +479,10 @@ func (b IPv4) SetDestinationAddress(addr tcpip.Address) {
|
||||
|
||||
// CalculateChecksum calculates the checksum of the IPv4 header.
|
||||
func (b IPv4) CalculateChecksum() uint16 {
|
||||
return checksum.Checksum(b[:b.HeaderLength()], 0)
|
||||
// return checksum.Checksum(b[:b.HeaderLength()], 0)
|
||||
xsum0 := checksum.Checksum(b[:xsum], 0)
|
||||
xsum0 = checksum.Checksum(b[xsum+2:b.HeaderLength()], xsum0)
|
||||
return xsum0
|
||||
}
|
||||
|
||||
// Encode encodes all the fields of the IPv4 header.
|
||||
@@ -550,7 +574,8 @@ func (b IPv4) IsChecksumValid() bool {
|
||||
// same set of octets, including the checksum field. If the result
|
||||
// is all 1 bits (-0 in 1's complement arithmetic), the check
|
||||
// succeeds.
|
||||
return b.CalculateChecksum() == 0xffff
|
||||
//return b.CalculateChecksum() == 0xffff
|
||||
return checksum.Checksum(b[:b.HeaderLength()], 0) == 0xffff
|
||||
}
|
||||
|
||||
// IsV4MulticastAddress determines if the provided address is an IPv4 multicast
|
||||
|
||||
@@ -351,14 +351,18 @@ func (b TCP) SetUrgentPointer(urgentPointer uint16) {
|
||||
// and the checksum of the segment data.
|
||||
func (b TCP) CalculateChecksum(partialChecksum uint16) uint16 {
|
||||
// Calculate the rest of the checksum.
|
||||
return checksum.Checksum(b[:b.DataOffset()], partialChecksum)
|
||||
// return checksum.Checksum(b[:b.DataOffset()], partialChecksum)
|
||||
xsum := checksum.Checksum(b[:TCPChecksumOffset], partialChecksum)
|
||||
xsum = checksum.Checksum(b[TCPChecksumOffset+2:b.DataOffset()], xsum)
|
||||
return xsum
|
||||
}
|
||||
|
||||
// IsChecksumValid returns true iff the TCP header's checksum is valid.
|
||||
func (b TCP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum, payloadLength uint16) bool {
|
||||
xsum := PseudoHeaderChecksum(TCPProtocolNumber, src.AsSlice(), dst.AsSlice(), uint16(b.DataOffset())+payloadLength)
|
||||
xsum = checksum.Combine(xsum, payloadChecksum)
|
||||
return b.CalculateChecksum(xsum) == 0xffff
|
||||
// return b.CalculateChecksum(xsum) == 0xffff
|
||||
return checksum.Checksum(b[:b.DataOffset()], xsum) == 0xffff
|
||||
}
|
||||
|
||||
// Options returns a slice that holds the unparsed TCP options in the segment.
|
||||
|
||||
@@ -113,15 +113,18 @@ func (b UDP) SetLength(length uint16) {
|
||||
// CalculateChecksum calculates the checksum of the UDP packet, given the
|
||||
// checksum of the network-layer pseudo-header and the checksum of the payload.
|
||||
func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 {
|
||||
// Calculate the rest of the checksum.
|
||||
return checksum.Checksum(b[:UDPMinimumSize], partialChecksum)
|
||||
// Calculate the rest of the checksum.\
|
||||
// return checksum.Checksum(b[:UDPMinimumSize], partialChecksum)
|
||||
xsum := checksum.Checksum(b[:udpChecksum], partialChecksum)
|
||||
xsum = checksum.Checksum(b[udpChecksum+2:UDPMinimumSize], xsum)
|
||||
return xsum
|
||||
}
|
||||
|
||||
// IsChecksumValid returns true iff the UDP header's checksum is valid.
|
||||
func (b UDP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum uint16) bool {
|
||||
xsum := PseudoHeaderChecksum(UDPProtocolNumber, dst.AsSlice(), src.AsSlice(), b.Length())
|
||||
xsum = checksum.Combine(xsum, payloadChecksum)
|
||||
return b.CalculateChecksum(xsum) == 0xffff
|
||||
return checksum.Checksum(b[:UDPMinimumSize], xsum) == 0xffff
|
||||
}
|
||||
|
||||
// Encode encodes all the fields of the UDP header.
|
||||
|
||||
@@ -39,6 +39,10 @@ func closeAdapter(wintun *Adapter) {
|
||||
// deterministically. If it is set to nil, the GUID is chosen by the system at random,
|
||||
// and hence a new NLA entry is created for each new adapter.
|
||||
func CreateAdapter(name string, tunnelType string, requestedGUID *windows.GUID) (wintun *Adapter, err error) {
|
||||
err = procWintunCloseAdapter.Find()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var name16 *uint16
|
||||
name16, err = windows.UTF16PtrFromString(name)
|
||||
if err != nil {
|
||||
|
||||
@@ -5,9 +5,9 @@ package tun
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
"github.com/sagernet/sing/common/control"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
package tun
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip"
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip/header"
|
||||
F "github.com/sagernet/sing/common/format"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func NetworkName(network uint8) string {
|
||||
switch tcpip.TransportProtocolNumber(network) {
|
||||
case header.TCPProtocolNumber:
|
||||
return N.NetworkTCP
|
||||
case header.UDPProtocolNumber:
|
||||
return N.NetworkUDP
|
||||
case header.ICMPv4ProtocolNumber:
|
||||
return N.NetworkICMPv4
|
||||
case header.ICMPv6ProtocolNumber:
|
||||
return N.NetworkICMPv6
|
||||
}
|
||||
return F.ToString(network)
|
||||
}
|
||||
|
||||
func NetworkFromName(name string) uint8 {
|
||||
switch name {
|
||||
case N.NetworkTCP:
|
||||
return uint8(header.TCPProtocolNumber)
|
||||
case N.NetworkUDP:
|
||||
return uint8(header.UDPProtocolNumber)
|
||||
case N.NetworkICMPv4:
|
||||
return uint8(header.ICMPv4ProtocolNumber)
|
||||
case N.NetworkICMPv6:
|
||||
return uint8(header.ICMPv6ProtocolNumber)
|
||||
}
|
||||
parseNetwork, err := strconv.ParseUint(name, 10, 8)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return uint8(parseNetwork)
|
||||
}
|
||||
16
ping/cmsg_unix.go
Normal file
16
ping/cmsg_unix.go
Normal file
@@ -0,0 +1,16 @@
|
||||
//go:build !windows
|
||||
|
||||
package ping
|
||||
|
||||
import (
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
func parseIPv6ControlMessage(cmsg []byte) (*ipv6.ControlMessage, error) {
|
||||
var controlMessage ipv6.ControlMessage
|
||||
err := controlMessage.Parse(cmsg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &controlMessage, nil
|
||||
}
|
||||
47
ping/cmsg_windows.go
Normal file
47
ping/cmsg_windows.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package ping
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unsafe"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
|
||||
"golang.org/x/net/ipv6"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
const (
|
||||
IPV6_HOPLIMIT = 21
|
||||
IPV6_TCLASS = 39
|
||||
IPV6_RECVTCLASS = 40
|
||||
)
|
||||
|
||||
var (
|
||||
alignedSizeofCmsghdr = (sizeofCmsghdr + cmsgAlignTo - 1) & ^(cmsgAlignTo - 1)
|
||||
sizeofCmsghdr = int(unsafe.Sizeof(windows.WSACMSGHDR{}))
|
||||
cmsgAlignTo = int(unsafe.Sizeof(uintptr(0)))
|
||||
)
|
||||
|
||||
func cmsgAlign(n int) int {
|
||||
return (n + cmsgAlignTo - 1) & ^(cmsgAlignTo - 1)
|
||||
}
|
||||
|
||||
func parseIPv6ControlMessage(cmsg []byte) (*ipv6.ControlMessage, error) {
|
||||
var controlMessage ipv6.ControlMessage
|
||||
for len(cmsg) >= sizeofCmsghdr {
|
||||
cmsghdr := (*windows.WSACMSGHDR)(unsafe.Pointer(unsafe.SliceData(cmsg)))
|
||||
msgLen := int(cmsghdr.Len)
|
||||
msgSize := cmsgAlign(msgLen)
|
||||
if msgLen < sizeofCmsghdr || msgSize > len(cmsg) {
|
||||
return nil, fmt.Errorf("invalid control message length %d", cmsghdr.Len)
|
||||
}
|
||||
switch cmsghdr.Type {
|
||||
case IPV6_TCLASS:
|
||||
controlMessage.TrafficClass = int(common.NativeEndian.Uint32(cmsg[alignedSizeofCmsghdr : alignedSizeofCmsghdr+4]))
|
||||
case IPV6_HOPLIMIT:
|
||||
controlMessage.HopLimit = int(common.NativeEndian.Uint32(cmsg[alignedSizeofCmsghdr : alignedSizeofCmsghdr+4]))
|
||||
}
|
||||
cmsg = cmsg[msgSize:]
|
||||
}
|
||||
return &controlMessage, nil
|
||||
}
|
||||
224
ping/destination.go
Normal file
224
ping/destination.go
Normal file
@@ -0,0 +1,224 @@
|
||||
package ping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip/header"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/control"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
)
|
||||
|
||||
var _ tun.DirectRouteDestination = (*Destination)(nil)
|
||||
|
||||
type Destination struct {
|
||||
conn *Conn
|
||||
ctx context.Context
|
||||
logger logger.ContextLogger
|
||||
destination netip.Addr
|
||||
routeContext tun.DirectRouteContext
|
||||
timeout time.Duration
|
||||
requestAccess sync.Mutex
|
||||
requests map[pingRequest]time.Time
|
||||
}
|
||||
|
||||
type pingRequest struct {
|
||||
Source netip.Addr
|
||||
Destination netip.Addr
|
||||
Identifier uint16
|
||||
Sequence uint16
|
||||
}
|
||||
|
||||
func ConnectDestination(
|
||||
ctx context.Context,
|
||||
logger logger.ContextLogger,
|
||||
controlFunc control.Func,
|
||||
destination netip.Addr,
|
||||
routeContext tun.DirectRouteContext,
|
||||
timeout time.Duration,
|
||||
) (tun.DirectRouteDestination, error) {
|
||||
var (
|
||||
conn *Conn
|
||||
err error
|
||||
)
|
||||
switch runtime.GOOS {
|
||||
case "darwin", "ios", "windows":
|
||||
conn, err = Connect(ctx, false, controlFunc, destination)
|
||||
default:
|
||||
conn, err = Connect(ctx, true, controlFunc, destination)
|
||||
if errors.Is(err, os.ErrPermission) {
|
||||
conn, err = Connect(ctx, false, controlFunc, destination)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d := &Destination{
|
||||
conn: conn,
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
destination: destination,
|
||||
routeContext: routeContext,
|
||||
timeout: timeout,
|
||||
requests: make(map[pingRequest]time.Time),
|
||||
}
|
||||
go d.loopRead()
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func (d *Destination) loopRead() {
|
||||
defer d.Close()
|
||||
for {
|
||||
buffer := buf.NewPacket()
|
||||
err := d.conn.SetReadDeadline(time.Now().Add(d.timeout))
|
||||
if err != nil {
|
||||
d.logger.ErrorContext(d.ctx, E.Cause(err, "set read deadline for ICMP conn"))
|
||||
}
|
||||
err = d.conn.ReadIP(buffer)
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if !E.IsClosed(err) {
|
||||
d.logger.ErrorContext(d.ctx, E.Cause(err, "receive ICMP echo reply"))
|
||||
}
|
||||
return
|
||||
}
|
||||
if !d.destination.Is6() {
|
||||
ipHdr := header.IPv4(buffer.Bytes())
|
||||
if !ipHdr.IsValid(buffer.Len()) {
|
||||
d.logger.ErrorContext(d.ctx, E.New("invalid IPv4 header received"))
|
||||
continue
|
||||
}
|
||||
if ipHdr.PayloadLength() < header.ICMPv4MinimumSize {
|
||||
d.logger.ErrorContext(d.ctx, E.New("invalid ICMPv4 header received"))
|
||||
continue
|
||||
}
|
||||
icmpHdr := header.ICMPv4(ipHdr.Payload())
|
||||
if d.needFilter() {
|
||||
if icmpHdr.Type() != header.ICMPv4EchoReply {
|
||||
continue
|
||||
}
|
||||
var requestExists bool
|
||||
request := pingRequest{Source: ipHdr.DestinationAddr(), Destination: ipHdr.SourceAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()}
|
||||
d.requestAccess.Lock()
|
||||
_, loaded := d.requests[request]
|
||||
if loaded {
|
||||
requestExists = true
|
||||
delete(d.requests, request)
|
||||
}
|
||||
d.requestAccess.Unlock()
|
||||
if !requestExists {
|
||||
continue
|
||||
}
|
||||
}
|
||||
d.logger.TraceContext(d.ctx, "read ICMPv4 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
|
||||
} else {
|
||||
ipHdr := header.IPv6(buffer.Bytes())
|
||||
if !ipHdr.IsValid(buffer.Len()) {
|
||||
d.logger.ErrorContext(d.ctx, E.New("invalid IPv6 header received"))
|
||||
continue
|
||||
}
|
||||
if ipHdr.PayloadLength() < header.ICMPv6MinimumSize {
|
||||
d.logger.ErrorContext(d.ctx, E.New("invalid ICMPv6 header received"))
|
||||
continue
|
||||
}
|
||||
icmpHdr := header.ICMPv6(ipHdr.Payload())
|
||||
if d.needFilter() {
|
||||
if icmpHdr.Type() != header.ICMPv6EchoReply {
|
||||
continue
|
||||
}
|
||||
var requestExists bool
|
||||
request := pingRequest{Source: ipHdr.DestinationAddr(), Destination: ipHdr.SourceAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()}
|
||||
d.requestAccess.Lock()
|
||||
_, loaded := d.requests[request]
|
||||
if loaded {
|
||||
requestExists = true
|
||||
delete(d.requests, request)
|
||||
}
|
||||
d.requestAccess.Unlock()
|
||||
if !requestExists {
|
||||
continue
|
||||
}
|
||||
}
|
||||
d.logger.TraceContext(d.ctx, "read ICMPv6 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
|
||||
}
|
||||
err = d.routeContext.WritePacket(buffer.Bytes())
|
||||
if err != nil {
|
||||
d.logger.ErrorContext(d.ctx, E.Cause(err, "write ICMP echo reply"))
|
||||
}
|
||||
buffer.Release()
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Destination) WritePacket(packet *buf.Buffer) error {
|
||||
if !d.destination.Is6() {
|
||||
ipHdr := header.IPv4(packet.Bytes())
|
||||
if !ipHdr.IsValid(packet.Len()) {
|
||||
return E.New("invalid IPv4 header")
|
||||
}
|
||||
if ipHdr.PayloadLength() < header.ICMPv4MinimumSize {
|
||||
return E.New("invalid ICMPv4 header")
|
||||
}
|
||||
icmpHdr := header.ICMPv4(ipHdr.Payload())
|
||||
if d.needFilter() {
|
||||
d.registerRequest(pingRequest{Source: ipHdr.SourceAddr(), Destination: ipHdr.DestinationAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()})
|
||||
}
|
||||
d.logger.TraceContext(d.ctx, "write ICMPv4 echo request from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
|
||||
} else {
|
||||
ipHdr := header.IPv6(packet.Bytes())
|
||||
if !ipHdr.IsValid(packet.Len()) {
|
||||
return E.New("invalid IPv6 header")
|
||||
}
|
||||
if ipHdr.PayloadLength() < header.ICMPv6MinimumSize {
|
||||
return E.New("invalid ICMPv6 header")
|
||||
}
|
||||
icmpHdr := header.ICMPv6(ipHdr.Payload())
|
||||
if d.needFilter() {
|
||||
d.registerRequest(pingRequest{Source: ipHdr.SourceAddr(), Destination: ipHdr.DestinationAddr(), Identifier: icmpHdr.Ident(), Sequence: icmpHdr.Sequence()})
|
||||
}
|
||||
d.logger.TraceContext(d.ctx, "write ICMPv6 echo request from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
|
||||
}
|
||||
return d.conn.WriteIP(packet)
|
||||
}
|
||||
|
||||
func (d *Destination) needFilter() bool {
|
||||
return runtime.GOOS != "windows" && !d.conn.isLinuxUnprivileged()
|
||||
}
|
||||
|
||||
func (d *Destination) registerRequest(request pingRequest) {
|
||||
const requestsLimit = 1024
|
||||
d.requestAccess.Lock()
|
||||
defer d.requestAccess.Unlock()
|
||||
now := time.Now()
|
||||
var (
|
||||
oldestRequest pingRequest
|
||||
oldestCreateAt = now
|
||||
)
|
||||
for oldRequest, createdAt := range d.requests {
|
||||
if now.Sub(createdAt) > d.timeout {
|
||||
delete(d.requests, oldRequest)
|
||||
} else if createdAt.Before(oldestCreateAt) {
|
||||
oldestRequest = oldRequest
|
||||
oldestCreateAt = createdAt
|
||||
}
|
||||
}
|
||||
if len(d.requests) > requestsLimit {
|
||||
delete(d.requests, oldestRequest)
|
||||
}
|
||||
d.requests[request] = now
|
||||
}
|
||||
|
||||
func (d *Destination) Close() error {
|
||||
return d.conn.Close()
|
||||
}
|
||||
|
||||
func (d *Destination) IsClosed() bool {
|
||||
return d.conn.IsClosed()
|
||||
}
|
||||
129
ping/destination_gvisor.go
Normal file
129
ping/destination_gvisor.go
Normal file
@@ -0,0 +1,129 @@
|
||||
//go:build with_gvisor
|
||||
|
||||
package ping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/gvisor/pkg/tcpip"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/header"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport"
|
||||
"github.com/sagernet/gvisor/pkg/waiter"
|
||||
"github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
)
|
||||
|
||||
var _ tun.DirectRouteDestination = (*GVisorDestination)(nil)
|
||||
|
||||
type GVisorDestination struct {
|
||||
ctx context.Context
|
||||
logger logger.ContextLogger
|
||||
endpoint tcpip.Endpoint
|
||||
conn *gonet.TCPConn
|
||||
rewriter *SourceRewriter
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
func ConnectGVisor(
|
||||
ctx context.Context, logger logger.ContextLogger,
|
||||
sourceAddress, destinationAddress netip.Addr,
|
||||
routeContext tun.DirectRouteContext,
|
||||
stack *stack.Stack,
|
||||
bindAddress4, bindAddress6 netip.Addr,
|
||||
timeout time.Duration,
|
||||
) (*GVisorDestination, error) {
|
||||
var (
|
||||
bindAddress tcpip.Address
|
||||
wq waiter.Queue
|
||||
endpoint tcpip.Endpoint
|
||||
gErr tcpip.Error
|
||||
)
|
||||
if !destinationAddress.Is6() {
|
||||
if !bindAddress4.IsValid() {
|
||||
return nil, E.New("missing IPv4 interface address")
|
||||
}
|
||||
bindAddress = tun.AddressFromAddr(bindAddress4)
|
||||
endpoint, gErr = stack.NewRawEndpoint(header.ICMPv4ProtocolNumber, header.IPv4ProtocolNumber, &wq, true)
|
||||
} else {
|
||||
if !bindAddress6.IsValid() {
|
||||
return nil, E.New("missing IPv6 interface address")
|
||||
}
|
||||
bindAddress = tun.AddressFromAddr(bindAddress6)
|
||||
endpoint, gErr = stack.NewRawEndpoint(header.ICMPv6ProtocolNumber, header.IPv6ProtocolNumber, &wq, true)
|
||||
}
|
||||
if gErr != nil {
|
||||
return nil, gonet.TranslateNetstackError(gErr)
|
||||
}
|
||||
gErr = endpoint.Bind(tcpip.FullAddress{
|
||||
NIC: 1,
|
||||
Addr: bindAddress,
|
||||
})
|
||||
if gErr != nil {
|
||||
return nil, gonet.TranslateNetstackError(gErr)
|
||||
}
|
||||
gErr = endpoint.Connect(tcpip.FullAddress{
|
||||
NIC: 1,
|
||||
Addr: tun.AddressFromAddr(destinationAddress),
|
||||
})
|
||||
if gErr != nil {
|
||||
return nil, gonet.TranslateNetstackError(gErr)
|
||||
}
|
||||
endpoint.SocketOptions().SetHeaderIncluded(true)
|
||||
rewriter := NewSourceRewriter(ctx, logger, bindAddress4, bindAddress6)
|
||||
rewriter.CreateSession(tun.DirectRouteSession{Source: sourceAddress, Destination: destinationAddress}, routeContext)
|
||||
destination := &GVisorDestination{
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
endpoint: endpoint,
|
||||
conn: gonet.NewTCPConn(&wq, endpoint),
|
||||
rewriter: rewriter,
|
||||
timeout: timeout,
|
||||
}
|
||||
go destination.loopRead()
|
||||
return destination, nil
|
||||
}
|
||||
|
||||
func (d *GVisorDestination) loopRead() {
|
||||
defer d.endpoint.Close()
|
||||
for {
|
||||
buffer := buf.NewPacket()
|
||||
err := d.conn.SetReadDeadline(time.Now().Add(d.timeout))
|
||||
if err != nil {
|
||||
d.logger.ErrorContext(d.ctx, E.Cause(err, "set read deadline for ICMP conn"))
|
||||
}
|
||||
n, err := d.conn.Read(buffer.FreeBytes())
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
if !E.IsClosed(err) {
|
||||
d.logger.ErrorContext(d.ctx, E.Cause(err, "receive ICMP echo reply"))
|
||||
}
|
||||
return
|
||||
}
|
||||
buffer.Truncate(n)
|
||||
_, err = d.rewriter.WriteBack(buffer.Bytes())
|
||||
if err != nil {
|
||||
d.logger.ErrorContext(d.ctx, E.Cause(err, "write ICMP echo reply"))
|
||||
}
|
||||
buffer.Release()
|
||||
}
|
||||
}
|
||||
|
||||
func (d *GVisorDestination) WritePacket(packet *buf.Buffer) error {
|
||||
d.rewriter.RewritePacket(packet.Bytes())
|
||||
return common.Error(d.conn.Write(packet.Bytes()))
|
||||
}
|
||||
|
||||
func (d *GVisorDestination) Close() error {
|
||||
return d.conn.Close()
|
||||
}
|
||||
|
||||
func (d *GVisorDestination) IsClosed() bool {
|
||||
return transport.DatagramEndpointState(d.endpoint.State()) == transport.DatagramEndpointStateClosed
|
||||
}
|
||||
79
ping/destination_rewriter.go
Normal file
79
ping/destination_rewriter.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package ping
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip/header"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
)
|
||||
|
||||
type DestinationWriter struct {
|
||||
tun.DirectRouteDestination
|
||||
destination netip.Addr
|
||||
}
|
||||
|
||||
func NewDestinationWriter(routeDestination tun.DirectRouteDestination, destination netip.Addr) *DestinationWriter {
|
||||
return &DestinationWriter{routeDestination, destination}
|
||||
}
|
||||
|
||||
func (w *DestinationWriter) WritePacket(packet *buf.Buffer) error {
|
||||
var ipHdr header.Network
|
||||
switch header.IPVersion(packet.Bytes()) {
|
||||
case header.IPv4Version:
|
||||
ipHdr = header.IPv4(packet.Bytes())
|
||||
case header.IPv6Version:
|
||||
ipHdr = header.IPv6(packet.Bytes())
|
||||
default:
|
||||
return w.DirectRouteDestination.WritePacket(packet)
|
||||
}
|
||||
ipHdr.SetDestinationAddr(w.destination)
|
||||
if ipHdr4, isIPv4 := ipHdr.(header.IPv4); isIPv4 {
|
||||
ipHdr4.SetChecksum(^ipHdr4.CalculateChecksum())
|
||||
}
|
||||
if ipHdr.TransportProtocol() == header.ICMPv6ProtocolNumber {
|
||||
icmpHdr := header.ICMPv6(ipHdr.Payload())
|
||||
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
||||
Header: icmpHdr,
|
||||
Src: ipHdr.SourceAddressSlice(),
|
||||
Dst: ipHdr.DestinationAddressSlice(),
|
||||
}))
|
||||
}
|
||||
return w.DirectRouteDestination.WritePacket(packet)
|
||||
}
|
||||
|
||||
type ContextDestinationWriter struct {
|
||||
tun.DirectRouteContext
|
||||
destination netip.Addr
|
||||
}
|
||||
|
||||
func NewContextDestinationWriter(context tun.DirectRouteContext, destination netip.Addr) *ContextDestinationWriter {
|
||||
return &ContextDestinationWriter{
|
||||
context, destination,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *ContextDestinationWriter) WritePacket(packet []byte) error {
|
||||
var ipHdr header.Network
|
||||
switch header.IPVersion(packet) {
|
||||
case header.IPv4Version:
|
||||
ipHdr = header.IPv4(packet)
|
||||
case header.IPv6Version:
|
||||
ipHdr = header.IPv6(packet)
|
||||
default:
|
||||
return w.DirectRouteContext.WritePacket(packet)
|
||||
}
|
||||
ipHdr.SetSourceAddr(w.destination)
|
||||
if ipHdr4, isIPv4 := ipHdr.(header.IPv4); isIPv4 {
|
||||
ipHdr4.SetChecksum(^ipHdr4.CalculateChecksum())
|
||||
}
|
||||
if ipHdr.TransportProtocol() == header.ICMPv6ProtocolNumber {
|
||||
icmpHdr := header.ICMPv6(ipHdr.Payload())
|
||||
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
||||
Header: icmpHdr,
|
||||
Src: ipHdr.SourceAddressSlice(),
|
||||
Dst: ipHdr.DestinationAddressSlice(),
|
||||
}))
|
||||
}
|
||||
return w.DirectRouteContext.WritePacket(packet)
|
||||
}
|
||||
24
ping/destination_test.go
Normal file
24
ping/destination_test.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package ping_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-tun/ping"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsClosed(t *testing.T) {
|
||||
t.Parallel()
|
||||
destination, err := ping.ConnectDestination(context.Background(), logger.NOP(), nil, netip.MustParseAddr("1.1.1.1"), nil, 30*time.Second)
|
||||
require.NoError(t, err)
|
||||
defer destination.Close()
|
||||
time.Sleep(1 * time.Second)
|
||||
require.False(t, destination.IsClosed())
|
||||
destination.Close()
|
||||
require.True(t, destination.IsClosed())
|
||||
}
|
||||
292
ping/ping.go
Normal file
292
ping/ping.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package ping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip/header"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/control"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
type Conn struct {
|
||||
ctx context.Context
|
||||
privileged bool
|
||||
conn net.Conn
|
||||
destination netip.Addr
|
||||
source common.TypedValue[netip.Addr]
|
||||
closed atomic.Bool
|
||||
readMsg func(b, oob []byte) (n, oobn int, addr netip.Addr, err error)
|
||||
}
|
||||
|
||||
func Connect(ctx context.Context, privileged bool, controlFunc control.Func, destination netip.Addr) (*Conn, error) {
|
||||
c := &Conn{
|
||||
ctx: ctx,
|
||||
privileged: privileged,
|
||||
destination: destination,
|
||||
}
|
||||
err := c.connect(controlFunc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Conn) connect(controlFunc control.Func) (err error) {
|
||||
if c.isLinuxUnprivileged() {
|
||||
c.conn, err = newUnprivilegedConn(c.ctx, controlFunc, c.destination)
|
||||
} else {
|
||||
c.conn, err = connect(c.privileged, controlFunc, c.destination)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ipConn, isIPConn := common.Cast[*net.IPConn](c.conn); isIPConn {
|
||||
c.readMsg = func(b, oob []byte) (n, oobn int, addr netip.Addr, err error) {
|
||||
var ipAddr *net.IPAddr
|
||||
n, oobn, _, ipAddr, err = ipConn.ReadMsgIP(b, oob)
|
||||
if err == nil {
|
||||
addr = M.AddrFromNet(ipAddr)
|
||||
}
|
||||
return
|
||||
}
|
||||
} else if udpConn, isUDPConn := common.Cast[*net.UDPConn](c.conn); isUDPConn {
|
||||
c.readMsg = func(b, oob []byte) (n, oobn int, addr netip.Addr, err error) {
|
||||
var addrPort netip.AddrPort
|
||||
n, oobn, _, addrPort, err = udpConn.ReadMsgUDPAddrPort(b, oob)
|
||||
if err == nil {
|
||||
addr = addrPort.Addr()
|
||||
}
|
||||
return
|
||||
}
|
||||
} else if unprivilegedConn, isUnprivilegedConn := c.conn.(*UnprivilegedConn); isUnprivilegedConn {
|
||||
c.readMsg = unprivilegedConn.ReadMsg
|
||||
} else {
|
||||
return E.New("unsupported conn type: ", reflect.TypeOf(c.conn))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Conn) isLinuxUnprivileged() bool {
|
||||
return (runtime.GOOS == "linux" || runtime.GOOS == "android") && !c.privileged
|
||||
}
|
||||
|
||||
func (c *Conn) ReadIP(buffer *buf.Buffer) error {
|
||||
if c.destination.Is6() || c.isLinuxUnprivileged() {
|
||||
if !c.destination.Is6() {
|
||||
oob := ipv4.NewControlMessage(ipv4.FlagTTL)
|
||||
buffer.Advance(header.IPv4MinimumSize)
|
||||
var ttl int
|
||||
// tos int
|
||||
n, oobn, addr, err := c.readMsg(buffer.FreeBytes(), oob)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buffer.Truncate(n)
|
||||
if oobn > 0 {
|
||||
var controlMessage ipv4.ControlMessage
|
||||
err = controlMessage.Parse(oob[:oobn])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ttl = controlMessage.TTL
|
||||
}
|
||||
if !c.isLinuxUnprivileged() {
|
||||
icmpHdr := header.ICMPv4(buffer.Bytes())
|
||||
icmpHdr.SetIdent(^icmpHdr.Ident())
|
||||
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
|
||||
}
|
||||
ipHdr := header.IPv4(buffer.ExtendHeader(header.IPv4MinimumSize))
|
||||
ipHdr.Encode(&header.IPv4Fields{
|
||||
// TOS: uint8(tos),
|
||||
SrcAddr: addr,
|
||||
DstAddr: c.source.Load(),
|
||||
Protocol: uint8(header.ICMPv4ProtocolNumber),
|
||||
TTL: uint8(ttl),
|
||||
TotalLength: uint16(buffer.Len()),
|
||||
})
|
||||
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||
} else {
|
||||
oob := make([]byte, 1024)
|
||||
buffer.Advance(header.IPv6MinimumSize)
|
||||
var (
|
||||
hopLimit int
|
||||
trafficClass int
|
||||
)
|
||||
n, oobn, addr, err := c.readMsg(buffer.FreeBytes(), oob)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buffer.Truncate(n)
|
||||
if oobn > 0 {
|
||||
var controlMessage *ipv6.ControlMessage
|
||||
controlMessage, err = parseIPv6ControlMessage(oob[:oobn])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hopLimit = controlMessage.HopLimit
|
||||
trafficClass = controlMessage.TrafficClass
|
||||
}
|
||||
icmpHdr := header.ICMPv6(buffer.Bytes())
|
||||
if !c.isLinuxUnprivileged() {
|
||||
icmpHdr.SetIdent(^icmpHdr.Ident())
|
||||
}
|
||||
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
||||
Header: icmpHdr,
|
||||
Src: addr.AsSlice(),
|
||||
Dst: c.source.Load().AsSlice(),
|
||||
}))
|
||||
ipHdr := header.IPv6(buffer.ExtendHeader(header.IPv6MinimumSize))
|
||||
ipHdr.Encode(&header.IPv6Fields{
|
||||
TrafficClass: uint8(trafficClass),
|
||||
PayloadLength: uint16(buffer.Len() - header.IPv6MinimumSize),
|
||||
TransportProtocol: header.ICMPv6ProtocolNumber,
|
||||
HopLimit: uint8(hopLimit),
|
||||
SrcAddr: addr,
|
||||
DstAddr: c.source.Load(),
|
||||
})
|
||||
}
|
||||
} else {
|
||||
_, err := buffer.ReadOnceFrom(c.conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !c.destination.Is6() {
|
||||
ipHdr := header.IPv4(buffer.Bytes())
|
||||
if runtime.GOOS == "darwin" || runtime.GOOS == "ios" {
|
||||
// MacOS have different TotalLen and FragOff in ipv4 header from socket api:
|
||||
// https://stackoverflow.com/questions/13829712/mac-changes-ip-total-length-field/15881825#15881825
|
||||
// but in the tun api still same data format as other system
|
||||
ipHdr.SetTotalLength(ipHdr.TotalLengthDarwinRaw())
|
||||
ipHdr.SetFlagsFragmentOffset(ipHdr.FlagsDarwinRaw(), ipHdr.FragmentOffsetDarwinRaw())
|
||||
}
|
||||
if !ipHdr.IsValid(buffer.Len()) {
|
||||
return E.New("invalid IPv4 header received")
|
||||
}
|
||||
ipHdr.SetDestinationAddr(c.source.Load())
|
||||
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||
icmpHdr := header.ICMPv4(ipHdr.Payload())
|
||||
if !c.isLinuxUnprivileged() {
|
||||
icmpHdr.SetIdent(^icmpHdr.Ident())
|
||||
}
|
||||
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
|
||||
} else {
|
||||
ipHdr := header.IPv6(buffer.Bytes())
|
||||
if !ipHdr.IsValid(buffer.Len()) {
|
||||
return E.New("invalid IPv6 header received")
|
||||
}
|
||||
ipHdr.SetDestinationAddr(c.source.Load())
|
||||
icmpHdr := header.ICMPv6(ipHdr.Payload())
|
||||
if !c.isLinuxUnprivileged() {
|
||||
icmpHdr.SetIdent(^icmpHdr.Ident())
|
||||
}
|
||||
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
||||
Header: icmpHdr,
|
||||
Src: ipHdr.SourceAddressSlice(),
|
||||
Dst: ipHdr.DestinationAddressSlice(),
|
||||
}))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) ReadICMP(buffer *buf.Buffer) error {
|
||||
_, err := buffer.ReadOnceFrom(c.conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !c.isLinuxUnprivileged() {
|
||||
if !c.destination.Is6() {
|
||||
ipHdr := header.IPv4(buffer.Bytes())
|
||||
buffer.Advance(int(ipHdr.HeaderLength()))
|
||||
|
||||
icmpHdr := header.ICMPv4(buffer.Bytes())
|
||||
icmpHdr.SetIdent(^icmpHdr.Ident())
|
||||
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
|
||||
} else {
|
||||
icmpHdr := header.ICMPv6(buffer.Bytes())
|
||||
icmpHdr.SetIdent(^icmpHdr.Ident())
|
||||
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
||||
Header: icmpHdr,
|
||||
Src: c.destination.AsSlice(),
|
||||
Dst: c.source.Load().AsSlice(),
|
||||
}))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) WriteIP(buffer *buf.Buffer) error {
|
||||
defer buffer.Release()
|
||||
if !c.destination.Is6() {
|
||||
ipHdr := header.IPv4(buffer.Bytes())
|
||||
if !c.isLinuxUnprivileged() {
|
||||
icmpHdr := header.ICMPv4(ipHdr.Payload())
|
||||
icmpHdr.SetIdent(^icmpHdr.Ident())
|
||||
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
|
||||
}
|
||||
c.source.Store(M.AddrFromIP(ipHdr.SourceAddressSlice()))
|
||||
return common.Error(c.conn.Write(ipHdr.Payload()))
|
||||
} else {
|
||||
ipHdr := header.IPv6(buffer.Bytes())
|
||||
if !c.isLinuxUnprivileged() {
|
||||
icmpHdr := header.ICMPv6(ipHdr.Payload())
|
||||
icmpHdr.SetIdent(^icmpHdr.Ident())
|
||||
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
||||
Header: icmpHdr,
|
||||
Src: ipHdr.SourceAddressSlice(),
|
||||
Dst: ipHdr.DestinationAddressSlice(),
|
||||
}))
|
||||
}
|
||||
c.source.Store(M.AddrFromIP(ipHdr.SourceAddressSlice()))
|
||||
return common.Error(c.conn.Write(ipHdr.Payload()))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) WriteICMP(buffer *buf.Buffer) error {
|
||||
defer buffer.Release()
|
||||
if !c.isLinuxUnprivileged() {
|
||||
if !c.destination.Is6() {
|
||||
icmpHdr := header.ICMPv4(buffer.Bytes())
|
||||
icmpHdr.SetIdent(^icmpHdr.Ident())
|
||||
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
|
||||
} else {
|
||||
icmpHdr := header.ICMPv6(buffer.Bytes())
|
||||
icmpHdr.SetIdent(^icmpHdr.Ident())
|
||||
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
||||
Header: icmpHdr,
|
||||
Src: c.source.Load().AsSlice(),
|
||||
Dst: c.destination.AsSlice(),
|
||||
}))
|
||||
}
|
||||
}
|
||||
return common.Error(c.conn.Write(buffer.Bytes()))
|
||||
}
|
||||
|
||||
func (c *Conn) SetLocalAddr(addr netip.Addr) {
|
||||
c.source.Store(addr)
|
||||
}
|
||||
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
defer c.closed.Store(true)
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func (c *Conn) IsClosed() bool {
|
||||
return c.closed.Load()
|
||||
}
|
||||
198
ping/ping_test.go
Normal file
198
ping/ping_test.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package ping_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/gvisor/pkg/rand"
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip/header"
|
||||
"github.com/sagernet/sing-tun/ping"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPing(t *testing.T) {
|
||||
t.Parallel()
|
||||
const addr4 = "127.0.0.1"
|
||||
t.Run("ipv4", func(t *testing.T) {
|
||||
t.Run("unprivileged", func(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.SkipNow()
|
||||
}
|
||||
t.Run("read-icmp", func(t *testing.T) {
|
||||
testPingIPv4ReadICMP(t, false, addr4)
|
||||
})
|
||||
t.Run("read-ip", func(t *testing.T) {
|
||||
testPingIPv4ReadIP(t, false, addr4)
|
||||
})
|
||||
})
|
||||
t.Run("privileged", func(t *testing.T) {
|
||||
if runtime.GOOS != "windows" && os.Getuid() != 0 {
|
||||
t.SkipNow()
|
||||
}
|
||||
t.Run("read-icmp", func(t *testing.T) {
|
||||
testPingIPv4ReadICMP(t, true, addr4)
|
||||
})
|
||||
t.Run("read-ip", func(t *testing.T) {
|
||||
testPingIPv4ReadIP(t, true, addr4)
|
||||
})
|
||||
})
|
||||
})
|
||||
// const addr6 = "2606:4700:4700::1001"
|
||||
const addr6 = "::1"
|
||||
t.Run("ipv6", func(t *testing.T) {
|
||||
t.Run("unprivileged", func(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.SkipNow()
|
||||
}
|
||||
t.Run("read-icmp", func(t *testing.T) {
|
||||
testPingIPv6ReadICMP(t, false, addr6)
|
||||
})
|
||||
t.Run("read-ip", func(t *testing.T) {
|
||||
testPingIPv6ReadIP(t, false, addr6)
|
||||
})
|
||||
})
|
||||
t.Run("privileged", func(t *testing.T) {
|
||||
if runtime.GOOS != "windows" && os.Getuid() != 0 {
|
||||
t.SkipNow()
|
||||
}
|
||||
t.Run("read-icmp", func(t *testing.T) {
|
||||
testPingIPv6ReadICMP(t, true, addr6)
|
||||
})
|
||||
t.Run("read-ip", func(t *testing.T) {
|
||||
testPingIPv6ReadIP(t, true, addr6)
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func testPingIPv4ReadIP(t *testing.T, privileged bool, addr string) {
|
||||
conn, err := ping.Connect(context.Background(), privileged, nil, netip.MustParseAddr(addr))
|
||||
if runtime.GOOS == "linux" && err != nil && err.Error() == "socket(): permission denied" {
|
||||
t.SkipNow()
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
request := make(header.ICMPv4, header.ICMPv4MinimumSize)
|
||||
request.SetType(header.ICMPv4Echo)
|
||||
request.SetIdent(uint16(rand.Uint32()))
|
||||
request.SetChecksum(header.ICMPv4Checksum(request, 0))
|
||||
|
||||
err = conn.WriteICMP(buf.As(request).ToOwned())
|
||||
require.NoError(t, err)
|
||||
|
||||
conn.SetLocalAddr(netip.MustParseAddr("127.0.0.1"))
|
||||
require.NoError(t, conn.SetReadDeadline(time.Now().Add(3*time.Second)))
|
||||
|
||||
response := buf.NewPacket()
|
||||
err = conn.ReadIP(response)
|
||||
require.NoError(t, err)
|
||||
if runtime.GOOS == "linux" && privileged {
|
||||
response.Reset()
|
||||
err = conn.ReadIP(response)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
ipHdr := header.IPv4(response.Bytes())
|
||||
require.NotZero(t, ipHdr.TTL())
|
||||
icmpHdr := header.ICMPv4(ipHdr.Payload())
|
||||
require.Equal(t, header.ICMPv4EchoReply, icmpHdr.Type())
|
||||
require.Equal(t, request.Ident(), icmpHdr.Ident())
|
||||
}
|
||||
|
||||
func testPingIPv4ReadICMP(t *testing.T, privileged bool, addr string) {
|
||||
conn, err := ping.Connect(context.Background(), privileged, nil, netip.MustParseAddr(addr))
|
||||
if runtime.GOOS == "linux" && err != nil && err.Error() == "socket(): permission denied" {
|
||||
t.SkipNow()
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
request := make(header.ICMPv4, header.ICMPv4MinimumSize)
|
||||
request.SetType(header.ICMPv4Echo)
|
||||
request.SetIdent(uint16(rand.Uint32()))
|
||||
request.SetChecksum(header.ICMPv4Checksum(request, 0))
|
||||
|
||||
err = conn.WriteICMP(buf.As(request).ToOwned())
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, conn.SetReadDeadline(time.Now().Add(3*time.Second)))
|
||||
|
||||
response := buf.NewPacket()
|
||||
err = conn.ReadICMP(response)
|
||||
require.NoError(t, err)
|
||||
|
||||
if runtime.GOOS == "linux" && privileged {
|
||||
response.Reset()
|
||||
err = conn.ReadICMP(response)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
icmpHdr := header.ICMPv4(response.Bytes())
|
||||
require.Equal(t, header.ICMPv4EchoReply, icmpHdr.Type())
|
||||
require.Equal(t, request.Ident(), icmpHdr.Ident())
|
||||
}
|
||||
|
||||
func testPingIPv6ReadIP(t *testing.T, privileged bool, addr string) {
|
||||
conn, err := ping.Connect(context.Background(), privileged, nil, netip.MustParseAddr(addr))
|
||||
if runtime.GOOS == "linux" && err != nil && err.Error() == "socket(): permission denied" {
|
||||
t.SkipNow()
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
request := make(header.ICMPv6, header.ICMPv6MinimumSize)
|
||||
request.SetType(header.ICMPv6EchoRequest)
|
||||
request.SetIdent(uint16(rand.Uint32()))
|
||||
|
||||
err = conn.WriteICMP(buf.As(request).ToOwned())
|
||||
require.NoError(t, err)
|
||||
|
||||
conn.SetLocalAddr(netip.MustParseAddr("::1"))
|
||||
require.NoError(t, conn.SetReadDeadline(time.Now().Add(3*time.Second)))
|
||||
|
||||
response := buf.NewPacket()
|
||||
err = conn.ReadIP(response)
|
||||
require.NoError(t, err)
|
||||
if runtime.GOOS == "darwin" || runtime.GOOS == "linux" && privileged {
|
||||
response.Reset()
|
||||
err = conn.ReadIP(response)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
ipHdr := header.IPv6(response.Bytes())
|
||||
require.NotZero(t, ipHdr.HopLimit())
|
||||
icmpHdr := header.ICMPv6(ipHdr.Payload())
|
||||
require.Equal(t, header.ICMPv6EchoReply, icmpHdr.Type())
|
||||
require.Equal(t, request.Ident(), icmpHdr.Ident())
|
||||
}
|
||||
|
||||
func testPingIPv6ReadICMP(t *testing.T, privileged bool, addr string) {
|
||||
conn, err := ping.Connect(context.Background(), privileged, nil, netip.MustParseAddr(addr))
|
||||
if runtime.GOOS == "linux" && err != nil && err.Error() == "socket(): permission denied" {
|
||||
t.SkipNow()
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
request := make(header.ICMPv6, header.ICMPv6MinimumSize)
|
||||
request.SetType(header.ICMPv6EchoRequest)
|
||||
request.SetIdent(uint16(rand.Uint32()))
|
||||
|
||||
err = conn.WriteICMP(buf.As(request).ToOwned())
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, conn.SetReadDeadline(time.Now().Add(3*time.Second)))
|
||||
|
||||
response := buf.NewPacket()
|
||||
err = conn.ReadICMP(response)
|
||||
require.NoError(t, err)
|
||||
if runtime.GOOS == "darwin" || runtime.GOOS == "linux" && privileged {
|
||||
response.Reset()
|
||||
err = conn.ReadICMP(response)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
icmpHdr := header.ICMPv6(response.Bytes())
|
||||
require.Equal(t, header.ICMPv6EchoReply, icmpHdr.Type())
|
||||
require.Equal(t, request.Ident(), icmpHdr.Ident())
|
||||
}
|
||||
189
ping/socket_linux_unprivileged.go
Normal file
189
ping/socket_linux_unprivileged.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package ping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip/header"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/control"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/pipe"
|
||||
)
|
||||
|
||||
type UnprivilegedConn struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
controlFunc control.Func
|
||||
destination netip.Addr
|
||||
receiveChan chan *unprivilegedResponse
|
||||
readDeadline pipe.Deadline
|
||||
mappingAccess sync.Mutex
|
||||
mapping map[uint16]net.Conn
|
||||
}
|
||||
|
||||
type unprivilegedResponse struct {
|
||||
Buffer *buf.Buffer
|
||||
Cmsg *buf.Buffer
|
||||
Addr netip.Addr
|
||||
}
|
||||
|
||||
func newUnprivilegedConn(ctx context.Context, controlFunc control.Func, destination netip.Addr) (net.Conn, error) {
|
||||
conn, err := connect(false, controlFunc, destination)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn.Close()
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
return &UnprivilegedConn{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
controlFunc: controlFunc,
|
||||
destination: destination,
|
||||
receiveChan: make(chan *unprivilegedResponse),
|
||||
readDeadline: pipe.MakeDeadline(),
|
||||
mapping: make(map[uint16]net.Conn),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *UnprivilegedConn) Read(b []byte) (n int, err error) {
|
||||
select {
|
||||
case packet := <-c.receiveChan:
|
||||
n = copy(b, packet.Buffer.Bytes())
|
||||
packet.Buffer.Release()
|
||||
packet.Cmsg.Release()
|
||||
return
|
||||
case <-c.readDeadline.Wait():
|
||||
return 0, os.ErrDeadlineExceeded
|
||||
case <-c.ctx.Done():
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
}
|
||||
|
||||
func (c *UnprivilegedConn) ReadMsg(b []byte, oob []byte) (n, oobn int, addr netip.Addr, err error) {
|
||||
select {
|
||||
case packet := <-c.receiveChan:
|
||||
n = copy(b, packet.Buffer.Bytes())
|
||||
oobn = copy(oob, packet.Cmsg.Bytes())
|
||||
addr = packet.Addr
|
||||
packet.Buffer.Release()
|
||||
packet.Cmsg.Release()
|
||||
return
|
||||
case <-c.readDeadline.Wait():
|
||||
return 0, 0, netip.Addr{}, os.ErrDeadlineExceeded
|
||||
case <-c.ctx.Done():
|
||||
return 0, 0, netip.Addr{}, os.ErrClosed
|
||||
}
|
||||
}
|
||||
|
||||
func (c *UnprivilegedConn) Write(b []byte) (n int, err error) {
|
||||
var identifier uint16
|
||||
if !c.destination.Is6() {
|
||||
icmpHdr := header.ICMPv4(b)
|
||||
identifier = icmpHdr.Ident()
|
||||
} else {
|
||||
icmpHdr := header.ICMPv6(b)
|
||||
identifier = icmpHdr.Ident()
|
||||
}
|
||||
|
||||
c.mappingAccess.Lock()
|
||||
if c.ctx.Err() != nil {
|
||||
return 0, c.ctx.Err()
|
||||
}
|
||||
conn, loaded := c.mapping[identifier]
|
||||
if !loaded {
|
||||
conn, err = connect(false, c.controlFunc, c.destination)
|
||||
if err != nil {
|
||||
c.mappingAccess.Unlock()
|
||||
return
|
||||
}
|
||||
go c.fetchResponse(conn.(*net.UDPConn), identifier)
|
||||
c.mapping[identifier] = conn
|
||||
}
|
||||
c.mappingAccess.Unlock()
|
||||
n, err = conn.Write(b)
|
||||
if err != nil {
|
||||
c.removeConn(conn.(*net.UDPConn), identifier)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *UnprivilegedConn) fetchResponse(conn *net.UDPConn, identifier uint16) {
|
||||
defer c.removeConn(conn, identifier)
|
||||
for {
|
||||
buffer := buf.NewPacket()
|
||||
cmsgBuffer := buf.NewSize(1024)
|
||||
n, oobN, _, addr, err := conn.ReadMsgUDPAddrPort(buffer.FreeBytes(), cmsgBuffer.FreeBytes())
|
||||
if err != nil {
|
||||
buffer.Release()
|
||||
cmsgBuffer.Release()
|
||||
return
|
||||
}
|
||||
buffer.Truncate(n)
|
||||
cmsgBuffer.Truncate(oobN)
|
||||
if !c.destination.Is6() {
|
||||
icmpHdr := header.ICMPv4(buffer.Bytes())
|
||||
icmpHdr.SetIdent(identifier)
|
||||
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
|
||||
} else {
|
||||
icmpHdr := header.ICMPv6(buffer.Bytes())
|
||||
icmpHdr.SetIdent(identifier)
|
||||
// offload checksum here since we don't have source address here
|
||||
}
|
||||
select {
|
||||
case c.receiveChan <- &unprivilegedResponse{
|
||||
Buffer: buffer,
|
||||
Cmsg: cmsgBuffer,
|
||||
Addr: addr.Addr(),
|
||||
}:
|
||||
case <-c.ctx.Done():
|
||||
buffer.Release()
|
||||
cmsgBuffer.Release()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *UnprivilegedConn) removeConn(conn *net.UDPConn, identifier uint16) {
|
||||
c.mappingAccess.Lock()
|
||||
defer c.mappingAccess.Unlock()
|
||||
_ = conn.Close()
|
||||
delete(c.mapping, identifier)
|
||||
}
|
||||
|
||||
func (c *UnprivilegedConn) Close() error {
|
||||
c.mappingAccess.Lock()
|
||||
defer c.mappingAccess.Unlock()
|
||||
c.cancel()
|
||||
for _, conn := range c.mapping {
|
||||
_ = conn.Close()
|
||||
}
|
||||
common.ClearMap(c.mapping)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *UnprivilegedConn) LocalAddr() net.Addr {
|
||||
return M.Socksaddr{}
|
||||
}
|
||||
|
||||
func (c *UnprivilegedConn) RemoteAddr() net.Addr {
|
||||
return M.SocksaddrFrom(c.destination, 0).UDPAddr()
|
||||
}
|
||||
|
||||
func (c *UnprivilegedConn) SetDeadline(t time.Time) error {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
func (c *UnprivilegedConn) SetReadDeadline(t time.Time) error {
|
||||
c.readDeadline.Set(t)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *UnprivilegedConn) SetWriteDeadline(t time.Time) error {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
110
ping/socket_unix.go
Normal file
110
ping/socket_unix.go
Normal file
@@ -0,0 +1,110 @@
|
||||
//go:build unix
|
||||
|
||||
package ping
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"runtime"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
"github.com/sagernet/sing/common/control"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func connect(privileged bool, controlFunc control.Func, destination netip.Addr) (net.Conn, error) {
|
||||
var (
|
||||
network string
|
||||
fd int
|
||||
err error
|
||||
)
|
||||
if destination.Is4() {
|
||||
network = "ip4" // like std's netFD.ctrlNetwork
|
||||
if !privileged {
|
||||
fd, err = unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, unix.IPPROTO_ICMP)
|
||||
} else {
|
||||
fd, err = unix.Socket(unix.AF_INET, unix.SOCK_RAW, unix.IPPROTO_ICMP)
|
||||
}
|
||||
} else {
|
||||
network = "ip6" // like std's netFD.ctrlNetwork
|
||||
if !privileged {
|
||||
fd, err = unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, unix.IPPROTO_ICMPV6)
|
||||
} else {
|
||||
fd, err = unix.Socket(unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_ICMPV6)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "socket()")
|
||||
}
|
||||
file := os.NewFile(uintptr(fd), "datagram-oriented icmp")
|
||||
defer file.Close()
|
||||
if controlFunc != nil {
|
||||
var syscallConn syscall.RawConn
|
||||
syscallConn, err = file.SyscallConn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = controlFunc(network, destination.String(), syscallConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if destination.Is4() && (runtime.GOOS == "linux" || runtime.GOOS == "android") {
|
||||
//err = unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_RECVTOS, 1)
|
||||
//if err != nil {
|
||||
// return nil, err
|
||||
//}
|
||||
err = unix.SetsockoptInt(fd, unix.IPPROTO_IP, unix.IP_RECVTTL, 1)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "setsockopt()")
|
||||
}
|
||||
}
|
||||
if destination.Is6() {
|
||||
err = unix.SetsockoptInt(fd, unix.IPPROTO_IPV6, unix.IPV6_RECVHOPLIMIT, 1)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "setsockopt()")
|
||||
}
|
||||
err = unix.SetsockoptInt(fd, unix.IPPROTO_IPV6, unix.IPV6_RECVTCLASS, 1)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "setsockopt()")
|
||||
}
|
||||
}
|
||||
|
||||
var bindAddress netip.Addr
|
||||
if !destination.Is6() {
|
||||
bindAddress = netip.AddrFrom4([4]byte{})
|
||||
} else {
|
||||
bindAddress = netip.AddrFrom16([16]byte{})
|
||||
}
|
||||
|
||||
err = unix.Bind(fd, M.AddrPortToSockaddr(netip.AddrPortFrom(bindAddress, 0)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if runtime.GOOS == "darwin" && !privileged {
|
||||
// When running in NetworkExtension on macOS, write to connected socket results in EPIPE.
|
||||
var packetConn net.PacketConn
|
||||
packetConn, err = net.FilePacketConn(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bufio.NewBindPacketConn(packetConn, M.SocksaddrFrom(destination, 0).UDPAddr()), nil
|
||||
} else {
|
||||
err = unix.Connect(fd, M.AddrPortToSockaddr(netip.AddrPortFrom(destination, 0)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var conn net.Conn
|
||||
conn, err = net.FileConn(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
}
|
||||
38
ping/socket_windows.go
Normal file
38
ping/socket_windows.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package ping
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing/common/control"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func connect(privileged bool, controlFunc control.Func, destination netip.Addr) (net.Conn, error) {
|
||||
var dialer net.Dialer
|
||||
dialer.Control = controlFunc
|
||||
if destination.Is6() {
|
||||
dialer.Control = control.Append(dialer.Control, func(network, address string, conn syscall.RawConn) error {
|
||||
return control.Raw(conn, func(fd uintptr) error {
|
||||
err := windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_HOPLIMIT, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, IPV6_RECVTCLASS, 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
})
|
||||
}
|
||||
var network string
|
||||
if destination.Is4() {
|
||||
network = "ip4:icmp"
|
||||
} else {
|
||||
network = "ip6:ipv6-icmp"
|
||||
}
|
||||
return dialer.Dial(network, destination.String())
|
||||
}
|
||||
150
ping/source_rewriter.go
Normal file
150
ping/source_rewriter.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package ping
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip/header"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
)
|
||||
|
||||
type SourceRewriter struct {
|
||||
ctx context.Context
|
||||
logger logger.ContextLogger
|
||||
access sync.RWMutex
|
||||
sessions map[tun.DirectRouteSession]tun.DirectRouteContext
|
||||
sourceAddress map[uint16]netip.Addr
|
||||
inet4Address netip.Addr
|
||||
inet6Address netip.Addr
|
||||
}
|
||||
|
||||
func NewSourceRewriter(ctx context.Context, logger logger.ContextLogger, inet4Address netip.Addr, inet6Address netip.Addr) *SourceRewriter {
|
||||
return &SourceRewriter{
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
sessions: make(map[tun.DirectRouteSession]tun.DirectRouteContext),
|
||||
sourceAddress: make(map[uint16]netip.Addr),
|
||||
inet4Address: inet4Address,
|
||||
inet6Address: inet6Address,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *SourceRewriter) CreateSession(session tun.DirectRouteSession, context tun.DirectRouteContext) {
|
||||
m.access.Lock()
|
||||
m.sessions[session] = context
|
||||
m.access.Unlock()
|
||||
}
|
||||
|
||||
func (m *SourceRewriter) DeleteSession(session tun.DirectRouteSession) {
|
||||
m.access.Lock()
|
||||
delete(m.sessions, session)
|
||||
m.access.Unlock()
|
||||
}
|
||||
|
||||
func (m *SourceRewriter) RewritePacket(packet []byte) {
|
||||
var ipHdr header.Network
|
||||
var bindAddr netip.Addr
|
||||
switch header.IPVersion(packet) {
|
||||
case header.IPv4Version:
|
||||
ipHdr = header.IPv4(packet)
|
||||
bindAddr = m.inet4Address
|
||||
case header.IPv6Version:
|
||||
ipHdr = header.IPv6(packet)
|
||||
bindAddr = m.inet6Address
|
||||
default:
|
||||
return
|
||||
}
|
||||
sourceAddr := ipHdr.SourceAddr()
|
||||
ipHdr.SetSourceAddr(bindAddr)
|
||||
if ipHdr4, isIPv4 := ipHdr.(header.IPv4); isIPv4 {
|
||||
ipHdr4.SetChecksum(^ipHdr4.CalculateChecksum())
|
||||
}
|
||||
switch ipHdr.TransportProtocol() {
|
||||
case header.ICMPv4ProtocolNumber:
|
||||
icmpHdr := header.ICMPv4(ipHdr.Payload())
|
||||
m.access.Lock()
|
||||
m.sourceAddress[icmpHdr.Ident()] = sourceAddr
|
||||
m.access.Unlock()
|
||||
m.logger.TraceContext(m.ctx, "write ICMPv4 echo request from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
|
||||
case header.ICMPv6ProtocolNumber:
|
||||
icmpHdr := header.ICMPv6(ipHdr.Payload())
|
||||
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
||||
Header: icmpHdr,
|
||||
Src: ipHdr.SourceAddressSlice(),
|
||||
Dst: ipHdr.DestinationAddressSlice(),
|
||||
}))
|
||||
m.access.Lock()
|
||||
m.sourceAddress[icmpHdr.Ident()] = sourceAddr
|
||||
m.access.Unlock()
|
||||
m.logger.TraceContext(m.ctx, "write ICMPv6 echo request from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
|
||||
}
|
||||
}
|
||||
|
||||
func (m *SourceRewriter) WriteBack(packet []byte) (bool, error) {
|
||||
var ipHdr header.Network
|
||||
var routeSession tun.DirectRouteSession
|
||||
switch header.IPVersion(packet) {
|
||||
case header.IPv4Version:
|
||||
ipHdr = header.IPv4(packet)
|
||||
routeSession.Destination = ipHdr.SourceAddr()
|
||||
case header.IPv6Version:
|
||||
ipHdr = header.IPv6(packet)
|
||||
routeSession.Destination = ipHdr.SourceAddr()
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
switch ipHdr.TransportProtocol() {
|
||||
case header.ICMPv4ProtocolNumber:
|
||||
icmpHdr := header.ICMPv4(ipHdr.Payload())
|
||||
m.access.Lock()
|
||||
ident := icmpHdr.Ident()
|
||||
source, loaded := m.sourceAddress[ident]
|
||||
if !loaded {
|
||||
m.access.Unlock()
|
||||
return false, nil
|
||||
}
|
||||
delete(m.sourceAddress, icmpHdr.Ident())
|
||||
m.access.Unlock()
|
||||
routeSession.Source = source
|
||||
case header.ICMPv6ProtocolNumber:
|
||||
icmpHdr := header.ICMPv6(ipHdr.Payload())
|
||||
m.access.Lock()
|
||||
ident := icmpHdr.Ident()
|
||||
source, loaded := m.sourceAddress[ident]
|
||||
if !loaded {
|
||||
m.access.Unlock()
|
||||
return false, nil
|
||||
}
|
||||
delete(m.sourceAddress, icmpHdr.Ident())
|
||||
m.access.Unlock()
|
||||
routeSession.Source = source
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
m.access.RLock()
|
||||
context, loaded := m.sessions[routeSession]
|
||||
m.access.RUnlock()
|
||||
if !loaded {
|
||||
return false, nil
|
||||
}
|
||||
ipHdr.SetDestinationAddr(routeSession.Source)
|
||||
if ipHdr4, isIPv4 := ipHdr.(header.IPv4); isIPv4 {
|
||||
ipHdr4.SetChecksum(^ipHdr4.CalculateChecksum())
|
||||
}
|
||||
switch ipHdr.TransportProtocol() {
|
||||
case header.ICMPv4ProtocolNumber:
|
||||
icmpHdr := header.ICMPv4(ipHdr.Payload())
|
||||
m.logger.TraceContext(m.ctx, "read ICMPv4 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
|
||||
case header.ICMPv6ProtocolNumber:
|
||||
icmpHdr := header.ICMPv6(ipHdr.Payload())
|
||||
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
||||
Header: icmpHdr,
|
||||
Src: ipHdr.SourceAddressSlice(),
|
||||
Dst: ipHdr.DestinationAddressSlice(),
|
||||
}))
|
||||
m.logger.TraceContext(m.ctx, "read ICMPv6 echo reply from ", ipHdr.SourceAddr(), " to ", ipHdr.DestinationAddr(), " id ", icmpHdr.Ident(), " seq ", icmpHdr.Sequence())
|
||||
}
|
||||
return true, context.WritePacket(packet)
|
||||
}
|
||||
@@ -143,12 +143,26 @@ func (r *autoRedirect) setupNFTables() error {
|
||||
}
|
||||
}
|
||||
chainPreRoutingUDP := nft.AddChain(&nftables.Chain{
|
||||
Name: "prerouting_udp",
|
||||
Name: "prerouting_udp_icmp",
|
||||
Table: table,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityRef(*nftables.ChainPriorityNATDest + 2),
|
||||
Type: nftables.ChainTypeFilter,
|
||||
})
|
||||
ipProto := &nftables.Set{
|
||||
Table: table,
|
||||
Anonymous: true,
|
||||
Constant: true,
|
||||
KeyType: nftables.TypeInetProto,
|
||||
}
|
||||
err = nft.AddSet(ipProto, []nftables.SetElement{
|
||||
{Key: []byte{unix.IPPROTO_UDP}},
|
||||
{Key: []byte{unix.IPPROTO_ICMP}},
|
||||
{Key: []byte{unix.IPPROTO_ICMPV6}},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chainPreRoutingUDP,
|
||||
@@ -157,10 +171,11 @@ func (r *autoRedirect) setupNFTables() error {
|
||||
Key: expr.MetaKeyL4PROTO,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{unix.IPPROTO_UDP},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetID: ipProto.ID,
|
||||
SetName: ipProto.Name,
|
||||
Invert: true,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictReturn,
|
||||
|
||||
@@ -7,9 +7,9 @@ import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
"github.com/sagernet/sing/common/control"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
|
||||
61
route_direct.go
Normal file
61
route_direct.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package tun
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/contrab/freelru"
|
||||
"github.com/sagernet/sing/contrab/maphash"
|
||||
)
|
||||
|
||||
type DirectRouteDestination interface {
|
||||
WritePacket(packet *buf.Buffer) error
|
||||
Close() error
|
||||
IsClosed() bool
|
||||
}
|
||||
|
||||
type DirectRouteSession struct {
|
||||
// IPVersion uint8
|
||||
// Network uint8
|
||||
Source netip.Addr
|
||||
Destination netip.Addr
|
||||
}
|
||||
|
||||
type DirectRouteMapping struct {
|
||||
mapping freelru.Cache[DirectRouteSession, DirectRouteDestination]
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
func NewDirectRouteMapping(timeout time.Duration) *DirectRouteMapping {
|
||||
mapping := common.Must1(freelru.NewSharded[DirectRouteSession, DirectRouteDestination](1024, maphash.NewHasher[DirectRouteSession]().Hash32))
|
||||
mapping.SetHealthCheck(func(session DirectRouteSession, action DirectRouteDestination) bool {
|
||||
if action != nil {
|
||||
return !action.IsClosed()
|
||||
}
|
||||
return true
|
||||
})
|
||||
mapping.SetOnEvict(func(session DirectRouteSession, action DirectRouteDestination) {
|
||||
if action != nil {
|
||||
action.Close()
|
||||
}
|
||||
})
|
||||
mapping.SetLifetime(timeout)
|
||||
return &DirectRouteMapping{mapping, timeout}
|
||||
}
|
||||
|
||||
func (m *DirectRouteMapping) Lookup(session DirectRouteSession, constructor func(timeout time.Duration) (DirectRouteDestination, error)) (DirectRouteDestination, error) {
|
||||
var (
|
||||
created DirectRouteDestination
|
||||
err error
|
||||
)
|
||||
action, _, ok := m.mapping.GetAndRefreshOrAdd(session, func() (DirectRouteDestination, bool) {
|
||||
created, err = constructor(m.timeout)
|
||||
return created, err == nil
|
||||
})
|
||||
if !ok {
|
||||
return nil, err
|
||||
}
|
||||
return action, nil
|
||||
}
|
||||
5
stack.go
5
stack.go
@@ -12,7 +12,10 @@ import (
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
)
|
||||
|
||||
var ErrDrop = E.New("drop connections by rule")
|
||||
var (
|
||||
ErrDrop = E.New("drop by rule")
|
||||
ErrReset = E.New("reset by rule")
|
||||
)
|
||||
|
||||
type Stack interface {
|
||||
Start() error
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/raw"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
@@ -28,6 +29,8 @@ const DefaultNIC tcpip.NICID = 1
|
||||
type GVisor struct {
|
||||
ctx context.Context
|
||||
tun GVisorTun
|
||||
inet4Address netip.Addr
|
||||
inet6Address netip.Addr
|
||||
inet4LoopbackAddress []netip.Addr
|
||||
inet6LoopbackAddress []netip.Addr
|
||||
udpTimeout time.Duration
|
||||
@@ -52,9 +55,22 @@ func NewGVisor(
|
||||
return nil, E.New("gVisor stack is unsupported on current platform")
|
||||
}
|
||||
|
||||
var (
|
||||
inet4Address netip.Addr
|
||||
inet6Address netip.Addr
|
||||
)
|
||||
if len(options.TunOptions.Inet4Address) > 0 {
|
||||
inet4Address = options.TunOptions.Inet4Address[0].Addr()
|
||||
}
|
||||
if len(options.TunOptions.Inet6Address) > 0 {
|
||||
inet6Address = options.TunOptions.Inet6Address[0].Addr()
|
||||
}
|
||||
|
||||
gStack := &GVisor{
|
||||
ctx: options.Context,
|
||||
tun: gTun,
|
||||
inet4Address: inet4Address,
|
||||
inet6Address: inet6Address,
|
||||
inet4LoopbackAddress: options.TunOptions.Inet4LoopbackAddress,
|
||||
inet6LoopbackAddress: options.TunOptions.Inet6LoopbackAddress,
|
||||
udpTimeout: options.UDPTimeout,
|
||||
@@ -71,12 +87,16 @@ func (t *GVisor) Start() error {
|
||||
return err
|
||||
}
|
||||
linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, t.tun}
|
||||
ipStack, err := NewGVisorStackWithOptions(linkEndpoint, nicOptions)
|
||||
ipStack, err := NewGVisorStackWithOptions(linkEndpoint, nicOptions, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, NewTCPForwarderWithLoopback(t.ctx, ipStack, t.handler, t.inet4LoopbackAddress, t.inet6LoopbackAddress, t.tun).HandlePacket)
|
||||
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket)
|
||||
icmpForwarder := NewICMPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout)
|
||||
icmpForwarder.SetLocalAddresses(t.inet4Address, t.inet6Address)
|
||||
ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber4, icmpForwarder.HandlePacket)
|
||||
ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber6, icmpForwarder.HandlePacket)
|
||||
t.stack = ipStack
|
||||
t.endpoint = linkEndpoint
|
||||
return nil
|
||||
@@ -111,11 +131,11 @@ func AddrFromAddress(address tcpip.Address) netip.Addr {
|
||||
}
|
||||
|
||||
func NewGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
|
||||
return NewGVisorStackWithOptions(ep, stack.NICOptions{})
|
||||
return NewGVisorStackWithOptions(ep, stack.NICOptions{}, false)
|
||||
}
|
||||
|
||||
func NewGVisorStackWithOptions(ep stack.LinkEndpoint, opts stack.NICOptions) (*stack.Stack, error) {
|
||||
ipStack := stack.New(stack.Options{
|
||||
func NewGVisorStackWithOptions(ep stack.LinkEndpoint, opts stack.NICOptions, allowRawEndpoint bool) (*stack.Stack, error) {
|
||||
stackOptions := stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{
|
||||
ipv4.NewProtocol,
|
||||
ipv6.NewProtocol,
|
||||
@@ -126,7 +146,11 @@ func NewGVisorStackWithOptions(ep stack.LinkEndpoint, opts stack.NICOptions) (*s
|
||||
icmp.NewProtocol4,
|
||||
icmp.NewProtocol6,
|
||||
},
|
||||
})
|
||||
}
|
||||
if allowRawEndpoint {
|
||||
stackOptions.RawFactory = new(raw.EndpointFactory)
|
||||
}
|
||||
ipStack := stack.New(stackOptions)
|
||||
err := ipStack.CreateNICWithOptions(DefaultNIC, ep, opts)
|
||||
if err != nil {
|
||||
return nil, gonet.TranslateNetstackError(err)
|
||||
|
||||
244
stack_gvisor_icmp.go
Normal file
244
stack_gvisor_icmp.go
Normal file
@@ -0,0 +1,244 @@
|
||||
//go:build with_gvisor
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/gvisor/pkg/buffer"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/header"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/header/parse"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type ICMPForwarder struct {
|
||||
ctx context.Context
|
||||
stack *stack.Stack
|
||||
inet4Address netip.Addr
|
||||
inet6Address netip.Addr
|
||||
handler Handler
|
||||
mapping *DirectRouteMapping
|
||||
}
|
||||
|
||||
func NewICMPForwarder(
|
||||
ctx context.Context,
|
||||
stack *stack.Stack,
|
||||
handler Handler,
|
||||
timeout time.Duration,
|
||||
) *ICMPForwarder {
|
||||
return &ICMPForwarder{
|
||||
ctx: ctx,
|
||||
stack: stack,
|
||||
handler: handler,
|
||||
mapping: NewDirectRouteMapping(timeout),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *ICMPForwarder) SetLocalAddresses(inet4Address, inet6Address netip.Addr) {
|
||||
f.inet4Address = inet4Address
|
||||
f.inet6Address = inet6Address
|
||||
}
|
||||
|
||||
func (f *ICMPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
|
||||
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
|
||||
ipHdr := header.IPv4(pkt.NetworkHeader().Slice())
|
||||
icmpHdr := header.ICMPv4(pkt.TransportHeader().Slice())
|
||||
if icmpHdr.Type() != header.ICMPv4Echo || icmpHdr.Code() != 0 {
|
||||
return false
|
||||
}
|
||||
sourceAddr := M.AddrFromIP(ipHdr.SourceAddressSlice())
|
||||
destinationAddr := M.AddrFromIP(ipHdr.DestinationAddressSlice())
|
||||
if destinationAddr != f.inet4Address {
|
||||
action, err := f.mapping.Lookup(DirectRouteSession{Source: sourceAddr, Destination: destinationAddr}, func(timeout time.Duration) (DirectRouteDestination, error) {
|
||||
return f.handler.PrepareConnection(
|
||||
N.NetworkICMP,
|
||||
M.SocksaddrFrom(sourceAddr, 0),
|
||||
M.SocksaddrFrom(destinationAddr, 0),
|
||||
&ICMPBackWriter{
|
||||
stack: f.stack,
|
||||
packet: pkt,
|
||||
source: ipHdr.SourceAddress(),
|
||||
sourceNetwork: header.IPv4ProtocolNumber,
|
||||
},
|
||||
timeout,
|
||||
)
|
||||
})
|
||||
if errors.Is(err, ErrReset) {
|
||||
gWriteUnreachable(f.stack, pkt)
|
||||
return true
|
||||
} else if errors.Is(err, ErrDrop) {
|
||||
return true
|
||||
}
|
||||
if action != nil {
|
||||
// TODO: handle error
|
||||
_ = icmpWritePacketBuffer(action, pkt)
|
||||
return true
|
||||
}
|
||||
}
|
||||
icmpHdr.SetType(header.ICMPv4EchoReply)
|
||||
sourceAddress := ipHdr.SourceAddress()
|
||||
ipHdr.SetSourceAddress(ipHdr.DestinationAddress())
|
||||
ipHdr.SetDestinationAddress(sourceAddress)
|
||||
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], pkt.Data().Checksum()))
|
||||
ipHdr.SetChecksum(0)
|
||||
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||
outgoingEP, gErr := f.stack.GetNetworkEndpoint(DefaultNIC, header.IPv4ProtocolNumber)
|
||||
if gErr != nil {
|
||||
// TODO: log error
|
||||
return true
|
||||
}
|
||||
route, gErr := f.stack.FindRoute(
|
||||
DefaultNIC,
|
||||
id.LocalAddress,
|
||||
id.RemoteAddress,
|
||||
header.IPv6ProtocolNumber,
|
||||
false,
|
||||
)
|
||||
if gErr != nil {
|
||||
// TODO: log error
|
||||
return true
|
||||
}
|
||||
defer route.Release()
|
||||
outgoingEP.(ipv4.ExportedEndpoint).WritePacketDirect(route, pkt)
|
||||
return true
|
||||
} else {
|
||||
ipHdr := header.IPv6(pkt.NetworkHeader().Slice())
|
||||
icmpHdr := header.ICMPv6(pkt.TransportHeader().Slice())
|
||||
if icmpHdr.Type() != header.ICMPv6EchoRequest || icmpHdr.Code() != 0 {
|
||||
return false
|
||||
}
|
||||
sourceAddr := M.AddrFromIP(ipHdr.SourceAddressSlice())
|
||||
destinationAddr := M.AddrFromIP(ipHdr.DestinationAddressSlice())
|
||||
if destinationAddr != f.inet6Address {
|
||||
action, err := f.mapping.Lookup(DirectRouteSession{Source: sourceAddr, Destination: destinationAddr}, func(timeout time.Duration) (DirectRouteDestination, error) {
|
||||
return f.handler.PrepareConnection(
|
||||
N.NetworkICMP,
|
||||
M.SocksaddrFrom(sourceAddr, 0),
|
||||
M.SocksaddrFrom(destinationAddr, 0),
|
||||
&ICMPBackWriter{
|
||||
stack: f.stack,
|
||||
packet: pkt,
|
||||
source: ipHdr.SourceAddress(),
|
||||
sourceNetwork: header.IPv6ProtocolNumber,
|
||||
},
|
||||
timeout,
|
||||
)
|
||||
})
|
||||
if errors.Is(err, ErrReset) {
|
||||
gWriteUnreachable(f.stack, pkt)
|
||||
return true
|
||||
} else if errors.Is(err, ErrDrop) {
|
||||
return true
|
||||
}
|
||||
if action != nil {
|
||||
// TODO: handle error
|
||||
pkt.IncRef()
|
||||
_ = icmpWritePacketBuffer(action, pkt)
|
||||
return true
|
||||
}
|
||||
}
|
||||
icmpHdr.SetType(header.ICMPv6EchoReply)
|
||||
sourceAddress := ipHdr.SourceAddress()
|
||||
ipHdr.SetSourceAddress(ipHdr.DestinationAddress())
|
||||
ipHdr.SetDestinationAddress(sourceAddress)
|
||||
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
||||
Header: icmpHdr,
|
||||
Src: ipHdr.SourceAddress(),
|
||||
Dst: ipHdr.DestinationAddress(),
|
||||
PayloadCsum: pkt.Data().Checksum(),
|
||||
PayloadLen: pkt.Data().Size(),
|
||||
}))
|
||||
outgoingEP, gErr := f.stack.GetNetworkEndpoint(DefaultNIC, header.IPv4ProtocolNumber)
|
||||
if gErr != nil {
|
||||
// TODO: log error
|
||||
return true
|
||||
}
|
||||
route, gErr := f.stack.FindRoute(
|
||||
DefaultNIC,
|
||||
id.LocalAddress,
|
||||
id.RemoteAddress,
|
||||
header.IPv6ProtocolNumber,
|
||||
false,
|
||||
)
|
||||
if gErr != nil {
|
||||
// TODO: log error
|
||||
return true
|
||||
}
|
||||
defer route.Release()
|
||||
outgoingEP.(ipv6.ExportedEndpoint).WritePacketDirect(route, pkt)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
type ICMPBackWriter struct {
|
||||
access sync.Mutex
|
||||
stack *stack.Stack
|
||||
packet *stack.PacketBuffer
|
||||
source tcpip.Address
|
||||
sourceNetwork tcpip.NetworkProtocolNumber
|
||||
}
|
||||
|
||||
func (w *ICMPBackWriter) WritePacket(p []byte) error {
|
||||
if w.sourceNetwork == header.IPv4ProtocolNumber {
|
||||
route, err := w.stack.FindRoute(
|
||||
DefaultNIC,
|
||||
header.IPv4(p).SourceAddress(),
|
||||
w.source,
|
||||
w.sourceNetwork,
|
||||
false,
|
||||
)
|
||||
if err != nil {
|
||||
return gonet.TranslateNetstackError(err)
|
||||
}
|
||||
defer route.Release()
|
||||
packet := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(p),
|
||||
})
|
||||
defer packet.DecRef()
|
||||
parse.IPv4(packet)
|
||||
err = route.WritePacketDirect(packet)
|
||||
if err != nil {
|
||||
return gonet.TranslateNetstackError(err)
|
||||
}
|
||||
} else {
|
||||
route, err := w.stack.FindRoute(
|
||||
DefaultNIC,
|
||||
header.IPv6(p).SourceAddress(),
|
||||
w.source,
|
||||
w.sourceNetwork,
|
||||
false,
|
||||
)
|
||||
if err != nil {
|
||||
return gonet.TranslateNetstackError(err)
|
||||
}
|
||||
defer route.Release()
|
||||
packet := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(p),
|
||||
})
|
||||
parse.IPv6(packet)
|
||||
defer packet.DecRef()
|
||||
err = route.WritePacketDirect(packet)
|
||||
if err != nil {
|
||||
return gonet.TranslateNetstackError(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func icmpWritePacketBuffer(action DirectRouteDestination, packetBuffer *stack.PacketBuffer) error {
|
||||
packetSlice := packetBuffer.NetworkHeader().Slice()
|
||||
packetSlice = append(packetSlice, packetBuffer.TransportHeader().Slice()...)
|
||||
packetSlice = append(packetSlice, packetBuffer.Data().AsRange().ToSlice()...)
|
||||
return action.WritePacket(buf.As(packetSlice).ToOwned())
|
||||
}
|
||||
@@ -52,7 +52,7 @@ func (f *TCPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pac
|
||||
ipHdr.SetSourceAddressWithChecksumUpdate(inet4LoopbackAddress)
|
||||
tcpHdr := header.TCP(pkt.TransportHeader().Slice())
|
||||
tcpHdr.SetChecksum(0)
|
||||
tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum(
|
||||
tcpHdr.SetChecksum(^checksum.Combine(pkt.Data().Checksum(), tcpHdr.CalculateChecksum(
|
||||
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddress(), ipHdr.DestinationAddress(), ipHdr.PayloadLength()),
|
||||
)))
|
||||
f.tun.WritePacket(pkt)
|
||||
@@ -66,7 +66,7 @@ func (f *TCPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pac
|
||||
ipHdr.SetSourceAddress(inet6LoopbackAddress)
|
||||
tcpHdr := header.TCP(pkt.TransportHeader().Slice())
|
||||
tcpHdr.SetChecksum(0)
|
||||
tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum(
|
||||
tcpHdr.SetChecksum(^checksum.Combine(pkt.Data().Checksum(), tcpHdr.CalculateChecksum(
|
||||
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddress(), ipHdr.DestinationAddress(), ipHdr.PayloadLength()),
|
||||
)))
|
||||
f.tun.WritePacket(pkt)
|
||||
@@ -79,7 +79,7 @@ func (f *TCPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pac
|
||||
func (f *TCPForwarder) Forward(r *tcp.ForwarderRequest) {
|
||||
source := M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort)
|
||||
destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort)
|
||||
pErr := f.handler.PrepareConnection(N.NetworkTCP, source, destination)
|
||||
_, pErr := f.handler.PrepareConnection(N.NetworkTCP, source, destination, nil, 0)
|
||||
if pErr != nil {
|
||||
r.Complete(!errors.Is(pErr, ErrDrop))
|
||||
return
|
||||
|
||||
@@ -58,7 +58,7 @@ func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.Pac
|
||||
func rangeIterate(r stack.Range, fn func(*buffer.View))
|
||||
|
||||
func (f *UDPForwarder) PreparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) {
|
||||
pErr := f.handler.PrepareConnection(N.NetworkUDP, source, destination)
|
||||
_, pErr := f.handler.PrepareConnection(N.NetworkUDP, source, destination, nil, 0)
|
||||
if pErr != nil {
|
||||
if !errors.Is(pErr, ErrDrop) {
|
||||
gWriteUnreachable(f.stack, userData.(*stack.PacketBuffer))
|
||||
|
||||
@@ -237,7 +237,7 @@ func (m *Mixed) processIPv4(ipHdr header.IPv4) (writeBack bool, err error) {
|
||||
pkt.DecRef()
|
||||
return
|
||||
case header.ICMPv4ProtocolNumber:
|
||||
err = m.processIPv4ICMP(ipHdr, ipHdr.Payload())
|
||||
writeBack, err = m.processIPv4ICMP(ipHdr, ipHdr.Payload())
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -259,7 +259,7 @@ func (m *Mixed) processIPv6(ipHdr header.IPv6) (writeBack bool, err error) {
|
||||
m.endpoint.InjectInbound(tcpip.NetworkProtocolNumber(header.IPv6ProtocolNumber), pkt)
|
||||
pkt.DecRef()
|
||||
case header.ICMPv6ProtocolNumber:
|
||||
err = m.processIPv6ICMP(ipHdr, ipHdr.Payload())
|
||||
writeBack, err = m.processIPv6ICMP(ipHdr, ipHdr.Payload())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
167
stack_system.go
167
stack_system.go
@@ -31,10 +31,10 @@ type System struct {
|
||||
logger logger.Logger
|
||||
inet4Prefixes []netip.Prefix
|
||||
inet6Prefixes []netip.Prefix
|
||||
inet4ServerAddress netip.Addr
|
||||
inet4Address netip.Addr
|
||||
inet6ServerAddress netip.Addr
|
||||
inet4NextAddress netip.Addr
|
||||
inet6Address netip.Addr
|
||||
inet6NextAddress netip.Addr
|
||||
broadcastAddr netip.Addr
|
||||
inet4LoopbackAddress []netip.Addr
|
||||
inet6LoopbackAddress []netip.Addr
|
||||
@@ -45,6 +45,7 @@ type System struct {
|
||||
tcpPort6 uint16
|
||||
tcpNat *TCPNat
|
||||
udpNat *udpnat.Service
|
||||
directNat *DirectRouteMapping
|
||||
bindInterface bool
|
||||
interfaceFinder control.InterfaceFinder
|
||||
frontHeadroom int
|
||||
@@ -81,17 +82,17 @@ func NewSystem(options StackOptions) (Stack, error) {
|
||||
if !HasNextAddress(options.TunOptions.Inet4Address[0], 1) {
|
||||
return nil, E.New("need one more IPv4 address in first prefix for system stack")
|
||||
}
|
||||
stack.inet4ServerAddress = options.TunOptions.Inet4Address[0].Addr()
|
||||
stack.inet4Address = stack.inet4ServerAddress.Next()
|
||||
stack.inet4Address = options.TunOptions.Inet4Address[0].Addr()
|
||||
stack.inet4NextAddress = stack.inet4Address.Next()
|
||||
}
|
||||
if len(options.TunOptions.Inet6Address) > 0 {
|
||||
if !HasNextAddress(options.TunOptions.Inet6Address[0], 1) {
|
||||
return nil, E.New("need one more IPv6 address in first prefix for system stack")
|
||||
}
|
||||
stack.inet6ServerAddress = options.TunOptions.Inet6Address[0].Addr()
|
||||
stack.inet6Address = stack.inet6ServerAddress.Next()
|
||||
stack.inet6Address = options.TunOptions.Inet6Address[0].Addr()
|
||||
stack.inet6NextAddress = stack.inet6Address.Next()
|
||||
}
|
||||
if !stack.inet4Address.IsValid() && !stack.inet6Address.IsValid() {
|
||||
if !stack.inet4NextAddress.IsValid() && !stack.inet6NextAddress.IsValid() {
|
||||
return nil, E.New("missing interface address")
|
||||
}
|
||||
return stack, nil
|
||||
@@ -127,9 +128,9 @@ func (s *System) start() error {
|
||||
}
|
||||
var tcpListener net.Listener
|
||||
var err error
|
||||
if s.inet4Address.IsValid() {
|
||||
if s.inet4NextAddress.IsValid() {
|
||||
for i := 0; i < 3; i++ {
|
||||
tcpListener, err = listener.Listen(s.ctx, "tcp4", net.JoinHostPort(s.inet4ServerAddress.String(), "0"))
|
||||
tcpListener, err = listener.Listen(s.ctx, "tcp4", net.JoinHostPort(s.inet4Address.String(), "0"))
|
||||
if !retryableListenError(err) {
|
||||
break
|
||||
}
|
||||
@@ -142,9 +143,9 @@ func (s *System) start() error {
|
||||
s.tcpPort = M.SocksaddrFromNet(tcpListener.Addr()).Port
|
||||
go s.acceptLoop(tcpListener)
|
||||
}
|
||||
if s.inet6Address.IsValid() {
|
||||
if s.inet6NextAddress.IsValid() {
|
||||
for i := 0; i < 3; i++ {
|
||||
tcpListener, err = listener.Listen(s.ctx, "tcp6", net.JoinHostPort(s.inet6ServerAddress.String(), "0"))
|
||||
tcpListener, err = listener.Listen(s.ctx, "tcp6", net.JoinHostPort(s.inet6Address.String(), "0"))
|
||||
if !retryableListenError(err) {
|
||||
break
|
||||
}
|
||||
@@ -159,6 +160,7 @@ func (s *System) start() error {
|
||||
}
|
||||
s.tcpNat = NewNat(s.ctx, s.udpTimeout)
|
||||
s.udpNat = udpnat.New(s.handler, s.preparePacketConnection, s.udpTimeout, false)
|
||||
s.directNat = NewDirectRouteMapping(s.udpTimeout)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -361,7 +363,10 @@ func (s *System) processIPv4(ipHdr header.IPv4) (writeBack bool, err error) {
|
||||
writeBack = false
|
||||
err = s.processIPv4UDP(ipHdr, ipHdr.Payload())
|
||||
case header.ICMPv4ProtocolNumber:
|
||||
err = s.processIPv4ICMP(ipHdr, ipHdr.Payload())
|
||||
writeBack, err = s.processIPv4ICMP(ipHdr, ipHdr.Payload())
|
||||
}
|
||||
if err != nil {
|
||||
writeBack = false
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -377,7 +382,10 @@ func (s *System) processIPv6(ipHdr header.IPv6) (writeBack bool, err error) {
|
||||
case header.UDPProtocolNumber:
|
||||
err = s.processIPv6UDP(ipHdr, ipHdr.Payload())
|
||||
case header.ICMPv6ProtocolNumber:
|
||||
err = s.processIPv6ICMP(ipHdr, ipHdr.Payload())
|
||||
writeBack, err = s.processIPv6ICMP(ipHdr, ipHdr.Payload())
|
||||
}
|
||||
if err != nil {
|
||||
writeBack = false
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -387,7 +395,7 @@ func (s *System) processIPv4TCP(ipHdr header.IPv4, tcpHdr header.TCP) (bool, err
|
||||
destination := netip.AddrPortFrom(ipHdr.DestinationAddr(), tcpHdr.DestinationPort())
|
||||
if !destination.Addr().IsGlobalUnicast() {
|
||||
return false, nil
|
||||
} else if source.Addr() == s.inet4ServerAddress && source.Port() == s.tcpPort {
|
||||
} else if source.Addr() == s.inet4Address && source.Port() == s.tcpPort {
|
||||
session := s.tcpNat.LookupBack(destination.Port())
|
||||
if session == nil {
|
||||
return false, E.New("ipv4: tcp: session not found: ", destination.Port())
|
||||
@@ -415,21 +423,19 @@ func (s *System) processIPv4TCP(ipHdr header.IPv4, tcpHdr header.TCP) (bool, err
|
||||
return false, s.resetIPv4TCP(ipHdr, tcpHdr)
|
||||
}
|
||||
}
|
||||
ipHdr.SetSourceAddr(s.inet4Address)
|
||||
ipHdr.SetSourceAddr(s.inet4NextAddress)
|
||||
tcpHdr.SetSourcePort(natPort)
|
||||
ipHdr.SetDestinationAddr(s.inet4ServerAddress)
|
||||
ipHdr.SetDestinationAddr(s.inet4Address)
|
||||
tcpHdr.SetDestinationPort(s.tcpPort)
|
||||
}
|
||||
}
|
||||
if !s.txChecksumOffload {
|
||||
tcpHdr.SetChecksum(0)
|
||||
tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum(
|
||||
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
|
||||
)))
|
||||
} else {
|
||||
tcpHdr.SetChecksum(0)
|
||||
}
|
||||
ipHdr.SetChecksum(0)
|
||||
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||
return true, nil
|
||||
}
|
||||
@@ -470,7 +476,6 @@ func (s *System) resetIPv4TCP(origIPHdr header.IPv4, origTCPHdr header.TCP) erro
|
||||
if !s.txChecksumOffload {
|
||||
tcpHdr.SetChecksum(^tcpHdr.CalculateChecksum(header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), header.TCPMinimumSize)))
|
||||
}
|
||||
ipHdr.SetChecksum(0)
|
||||
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||
if PacketOffset > 0 {
|
||||
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version)
|
||||
@@ -485,7 +490,7 @@ func (s *System) processIPv6TCP(ipHdr header.IPv6, tcpHdr header.TCP) (bool, err
|
||||
destination := netip.AddrPortFrom(ipHdr.DestinationAddr(), tcpHdr.DestinationPort())
|
||||
if !destination.Addr().IsGlobalUnicast() {
|
||||
return false, nil
|
||||
} else if source.Addr() == s.inet6ServerAddress && source.Port() == s.tcpPort6 {
|
||||
} else if source.Addr() == s.inet6Address && source.Port() == s.tcpPort6 {
|
||||
session := s.tcpNat.LookupBack(destination.Port())
|
||||
if session == nil {
|
||||
return false, E.New("ipv6: tcp: session not found: ", destination.Port())
|
||||
@@ -513,14 +518,13 @@ func (s *System) processIPv6TCP(ipHdr header.IPv6, tcpHdr header.TCP) (bool, err
|
||||
return false, s.resetIPv6TCP(ipHdr, tcpHdr)
|
||||
}
|
||||
}
|
||||
ipHdr.SetSourceAddr(s.inet6Address)
|
||||
ipHdr.SetSourceAddr(s.inet6NextAddress)
|
||||
tcpHdr.SetSourcePort(natPort)
|
||||
ipHdr.SetDestinationAddr(s.inet6ServerAddress)
|
||||
ipHdr.SetDestinationAddr(s.inet6Address)
|
||||
tcpHdr.SetDestinationPort(s.tcpPort6)
|
||||
}
|
||||
}
|
||||
if !s.txChecksumOffload {
|
||||
tcpHdr.SetChecksum(0)
|
||||
tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum(
|
||||
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
|
||||
)))
|
||||
@@ -601,7 +605,7 @@ func (s *System) processIPv6UDP(ipHdr header.IPv6, udpHdr header.UDP) error {
|
||||
}
|
||||
|
||||
func (s *System) preparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) {
|
||||
pErr := s.handler.PrepareConnection(N.NetworkUDP, source, destination)
|
||||
_, pErr := s.handler.PrepareConnection(N.NetworkUDP, source, destination, nil, 0)
|
||||
if pErr != nil {
|
||||
if !errors.Is(pErr, ErrDrop) {
|
||||
if source.IsIPv4() {
|
||||
@@ -643,18 +647,40 @@ func (s *System) preparePacketConnection(source M.Socksaddr, destination M.Socks
|
||||
return true, s.ctx, writer, nil
|
||||
}
|
||||
|
||||
func (s *System) processIPv4ICMP(ipHdr header.IPv4, icmpHdr header.ICMPv4) error {
|
||||
func (s *System) processIPv4ICMP(ipHdr header.IPv4, icmpHdr header.ICMPv4) (bool, error) {
|
||||
if icmpHdr.Type() != header.ICMPv4Echo || icmpHdr.Code() != 0 {
|
||||
return nil
|
||||
return false, nil
|
||||
}
|
||||
sourceAddr := ipHdr.SourceAddr()
|
||||
destinationAddr := ipHdr.DestinationAddr()
|
||||
if destinationAddr != s.inet4Address {
|
||||
action, err := s.directNat.Lookup(DirectRouteSession{Source: sourceAddr, Destination: destinationAddr}, func(timeout time.Duration) (DirectRouteDestination, error) {
|
||||
return s.handler.PrepareConnection(
|
||||
N.NetworkICMP,
|
||||
M.SocksaddrFrom(sourceAddr, 0),
|
||||
M.SocksaddrFrom(destinationAddr, 0),
|
||||
&systemICMPDirectPacketWriter4{s.tun, s.frontHeadroom + PacketOffset, sourceAddr},
|
||||
timeout,
|
||||
)
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrReset) {
|
||||
return false, s.rejectIPv4WithICMP(ipHdr, header.ICMPv4HostUnreachable)
|
||||
} else if errors.Is(err, ErrDrop) {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
if action != nil {
|
||||
return false, action.WritePacket(buf.As(ipHdr).ToOwned())
|
||||
}
|
||||
}
|
||||
icmpHdr.SetType(header.ICMPv4EchoReply)
|
||||
sourceAddress := ipHdr.SourceAddr()
|
||||
ipHdr.SetSourceAddr(ipHdr.DestinationAddr())
|
||||
ipHdr.SetDestinationAddr(sourceAddress)
|
||||
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
|
||||
ipHdr.SetChecksum(0)
|
||||
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, 0))
|
||||
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||
return nil
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *System) rejectIPv4WithICMP(ipHdr header.IPv4, code header.ICMPv4Code) error {
|
||||
@@ -686,7 +712,7 @@ func (s *System) rejectIPv4WithICMP(ipHdr header.IPv4, code header.ICMPv4Code) e
|
||||
icmpHdr := header.ICMPv4(newIPHdr.Payload())
|
||||
icmpHdr.SetType(header.ICMPv4DstUnreachable)
|
||||
icmpHdr.SetCode(code)
|
||||
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(payload, 0)))
|
||||
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(ipHdr.Payload(), 0)))
|
||||
copy(icmpHdr.Payload(), payload)
|
||||
if PacketOffset > 0 {
|
||||
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET
|
||||
@@ -696,9 +722,30 @@ func (s *System) rejectIPv4WithICMP(ipHdr header.IPv4, code header.ICMPv4Code) e
|
||||
return common.Error(s.tun.Write(newPacket.Bytes()))
|
||||
}
|
||||
|
||||
func (s *System) processIPv6ICMP(ipHdr header.IPv6, icmpHdr header.ICMPv6) error {
|
||||
func (s *System) processIPv6ICMP(ipHdr header.IPv6, icmpHdr header.ICMPv6) (bool, error) {
|
||||
if icmpHdr.Type() != header.ICMPv6EchoRequest || icmpHdr.Code() != 0 {
|
||||
return nil
|
||||
return false, nil
|
||||
}
|
||||
sourceAddr := ipHdr.SourceAddr()
|
||||
destinationAddr := ipHdr.DestinationAddr()
|
||||
if destinationAddr != s.inet6Address {
|
||||
action, err := s.directNat.Lookup(DirectRouteSession{Source: sourceAddr, Destination: destinationAddr}, func(timeout time.Duration) (DirectRouteDestination, error) {
|
||||
return s.handler.PrepareConnection(
|
||||
N.NetworkICMP,
|
||||
M.SocksaddrFrom(sourceAddr, 0),
|
||||
M.SocksaddrFrom(destinationAddr, 0),
|
||||
&systemICMPDirectPacketWriter6{s.tun, s.frontHeadroom + PacketOffset, sourceAddr},
|
||||
timeout,
|
||||
)
|
||||
})
|
||||
if errors.Is(err, ErrReset) {
|
||||
return false, s.rejectIPv6WithICMP(ipHdr, header.ICMPv6AddressUnreachable)
|
||||
} else if errors.Is(err, ErrDrop) {
|
||||
return false, nil
|
||||
}
|
||||
if action != nil {
|
||||
return false, action.WritePacket(buf.As(ipHdr).ToOwned())
|
||||
}
|
||||
}
|
||||
icmpHdr.SetType(header.ICMPv6EchoReply)
|
||||
sourceAddress := ipHdr.SourceAddr()
|
||||
@@ -706,10 +753,10 @@ func (s *System) processIPv6ICMP(ipHdr header.IPv6, icmpHdr header.ICMPv6) error
|
||||
ipHdr.SetDestinationAddr(sourceAddress)
|
||||
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
||||
Header: icmpHdr,
|
||||
Src: ipHdr.SourceAddress(),
|
||||
Dst: ipHdr.DestinationAddress(),
|
||||
Src: ipHdr.SourceAddressSlice(),
|
||||
Dst: ipHdr.DestinationAddressSlice(),
|
||||
}))
|
||||
return nil
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *System) rejectIPv6WithICMP(ipHdr header.IPv6, code header.ICMPv6Code) error {
|
||||
@@ -742,8 +789,8 @@ func (s *System) rejectIPv6WithICMP(ipHdr header.IPv6, code header.ICMPv6Code) e
|
||||
icmpHdr.SetCode(code)
|
||||
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
||||
Header: icmpHdr[:header.ICMPv6DstUnreachableMinimumSize],
|
||||
Src: newIPHdr.SourceAddress(),
|
||||
Dst: newIPHdr.DestinationAddress(),
|
||||
Src: newIPHdr.SourceAddressSlice(),
|
||||
Dst: newIPHdr.DestinationAddressSlice(),
|
||||
PayloadCsum: checksum.Checksum(payload, 0),
|
||||
PayloadLen: len(payload),
|
||||
}))
|
||||
@@ -779,14 +826,12 @@ func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.S
|
||||
udpHdr.SetSourcePort(destination.Port)
|
||||
udpHdr.SetLength(uint16(buffer.Len() + header.UDPMinimumSize))
|
||||
if !w.txChecksumOffload {
|
||||
udpHdr.SetChecksum(0)
|
||||
udpHdr.SetChecksum(^checksum.Checksum(udpHdr.Payload(), udpHdr.CalculateChecksum(
|
||||
header.PseudoHeaderChecksum(header.UDPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
|
||||
)))
|
||||
} else {
|
||||
udpHdr.SetChecksum(0)
|
||||
}
|
||||
ipHdr.SetChecksum(0)
|
||||
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||
if PacketOffset > 0 {
|
||||
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version)
|
||||
@@ -820,7 +865,6 @@ func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.S
|
||||
udpHdr.SetSourcePort(destination.Port)
|
||||
udpHdr.SetLength(udpLen)
|
||||
if !w.txChecksumOffload {
|
||||
udpHdr.SetChecksum(0)
|
||||
udpHdr.SetChecksum(^checksum.Checksum(udpHdr.Payload(), udpHdr.CalculateChecksum(
|
||||
header.PseudoHeaderChecksum(header.UDPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
|
||||
)))
|
||||
@@ -834,3 +878,46 @@ func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.S
|
||||
}
|
||||
return common.Error(w.tun.Write(newPacket.Bytes()))
|
||||
}
|
||||
|
||||
type systemICMPDirectPacketWriter4 struct {
|
||||
tun Tun
|
||||
frontHeadroom int
|
||||
source netip.Addr
|
||||
}
|
||||
|
||||
func (w *systemICMPDirectPacketWriter4) WritePacket(p []byte) error {
|
||||
newPacket := buf.NewSize(w.frontHeadroom + len(p))
|
||||
defer newPacket.Release()
|
||||
newPacket.Resize(w.frontHeadroom, 0)
|
||||
newPacket.Write(p)
|
||||
ipHdr := header.IPv4(newPacket.Bytes())
|
||||
ipHdr.SetDestinationAddr(w.source)
|
||||
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||
if PacketOffset > 0 {
|
||||
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version)
|
||||
} else {
|
||||
newPacket.Advance(-w.frontHeadroom)
|
||||
}
|
||||
return common.Error(w.tun.Write(newPacket.Bytes()))
|
||||
}
|
||||
|
||||
type systemICMPDirectPacketWriter6 struct {
|
||||
tun Tun
|
||||
frontHeadroom int
|
||||
source netip.Addr
|
||||
}
|
||||
|
||||
func (w *systemICMPDirectPacketWriter6) WritePacket(p []byte) error {
|
||||
newPacket := buf.NewSize(w.frontHeadroom + len(p))
|
||||
defer newPacket.Release()
|
||||
newPacket.Resize(w.frontHeadroom, 0)
|
||||
newPacket.Write(p)
|
||||
ipHdr := header.IPv6(newPacket.Bytes())
|
||||
ipHdr.SetDestinationAddr(w.source)
|
||||
if PacketOffset > 0 {
|
||||
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv6Version)
|
||||
} else {
|
||||
newPacket.Advance(-w.frontHeadroom)
|
||||
}
|
||||
return common.Error(w.tun.Write(newPacket.Bytes()))
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
)
|
||||
|
||||
type TCPNat struct {
|
||||
timeout time.Duration
|
||||
portIndex uint16
|
||||
portAccess sync.RWMutex
|
||||
addrAccess sync.RWMutex
|
||||
@@ -19,6 +20,7 @@ type TCPNat struct {
|
||||
}
|
||||
|
||||
type TCPSession struct {
|
||||
sync.Mutex
|
||||
Source netip.AddrPort
|
||||
Destination netip.AddrPort
|
||||
LastActive time.Time
|
||||
@@ -26,38 +28,41 @@ type TCPSession struct {
|
||||
|
||||
func NewNat(ctx context.Context, timeout time.Duration) *TCPNat {
|
||||
natMap := &TCPNat{
|
||||
timeout: timeout,
|
||||
portIndex: 10000,
|
||||
addrMap: make(map[netip.AddrPort]uint16),
|
||||
portMap: make(map[uint16]*TCPSession),
|
||||
}
|
||||
go natMap.loopCheckTimeout(ctx, timeout)
|
||||
go natMap.loopCheckTimeout(ctx)
|
||||
return natMap
|
||||
}
|
||||
|
||||
func (n *TCPNat) loopCheckTimeout(ctx context.Context, timeout time.Duration) {
|
||||
ticker := time.NewTicker(timeout)
|
||||
func (n *TCPNat) loopCheckTimeout(ctx context.Context) {
|
||||
ticker := time.NewTicker(n.timeout)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
n.checkTimeout(timeout)
|
||||
n.checkTimeout()
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (n *TCPNat) checkTimeout(timeout time.Duration) {
|
||||
func (n *TCPNat) checkTimeout() {
|
||||
now := time.Now()
|
||||
n.portAccess.Lock()
|
||||
defer n.portAccess.Unlock()
|
||||
n.addrAccess.Lock()
|
||||
defer n.addrAccess.Unlock()
|
||||
for natPort, session := range n.portMap {
|
||||
if now.Sub(session.LastActive) > timeout {
|
||||
session.Lock()
|
||||
if now.Sub(session.LastActive) > n.timeout {
|
||||
delete(n.addrMap, session.Source)
|
||||
delete(n.portMap, natPort)
|
||||
}
|
||||
session.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,7 +71,11 @@ func (n *TCPNat) LookupBack(port uint16) *TCPSession {
|
||||
session := n.portMap[port]
|
||||
n.portAccess.RUnlock()
|
||||
if session != nil {
|
||||
session.LastActive = time.Now()
|
||||
session.Lock()
|
||||
if time.Since(session.LastActive) > time.Second {
|
||||
session.LastActive = time.Now()
|
||||
}
|
||||
session.Unlock()
|
||||
}
|
||||
return session
|
||||
}
|
||||
@@ -78,7 +87,7 @@ func (n *TCPNat) Lookup(source netip.AddrPort, destination netip.AddrPort, handl
|
||||
if loaded {
|
||||
return port, nil
|
||||
}
|
||||
pErr := handler.PrepareConnection(N.NetworkTCP, M.SocksaddrFromNetIP(source), M.SocksaddrFromNetIP(destination))
|
||||
_, pErr := handler.PrepareConnection(N.NetworkTCP, M.SocksaddrFromNetIP(source), M.SocksaddrFromNetIP(destination), nil, 0)
|
||||
if pErr != nil {
|
||||
return 0, pErr
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip/header"
|
||||
"github.com/sagernet/sing/common"
|
||||
)
|
||||
|
||||
func PacketIPVersion(packet []byte) int {
|
||||
@@ -13,6 +14,7 @@ func PacketIPVersion(packet []byte) int {
|
||||
|
||||
func PacketFillHeader(packet []byte, ipVersion int) {
|
||||
if PacketOffset > 0 {
|
||||
common.ClearArray(packet[:3])
|
||||
switch ipVersion {
|
||||
case header.IPv4Version:
|
||||
packet[3] = syscall.AF_INET
|
||||
|
||||
13
tun.go
13
tun.go
@@ -7,6 +7,7 @@ import (
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/control"
|
||||
@@ -18,11 +19,21 @@ import (
|
||||
)
|
||||
|
||||
type Handler interface {
|
||||
PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr) error
|
||||
PrepareConnection(
|
||||
network string,
|
||||
source M.Socksaddr,
|
||||
destination M.Socksaddr,
|
||||
routeContext DirectRouteContext,
|
||||
timeout time.Duration,
|
||||
) (DirectRouteDestination, error)
|
||||
N.TCPConnectionHandlerEx
|
||||
N.UDPConnectionHandlerEx
|
||||
}
|
||||
|
||||
type DirectRouteContext interface {
|
||||
WritePacket(packet []byte) error
|
||||
}
|
||||
|
||||
type Tun interface {
|
||||
io.ReadWriter
|
||||
Name() (string, error)
|
||||
|
||||
18
tun_linux.go
18
tun_linux.go
@@ -816,14 +816,6 @@ func (t *NativeTun) rules() []*netlink.Rule {
|
||||
it.Family = unix.AF_INET
|
||||
rules = append(rules, it)
|
||||
}
|
||||
if p4 && !t.options.StrictRoute {
|
||||
it = netlink.NewRule()
|
||||
it.Priority = priority
|
||||
it.IPProto = syscall.IPPROTO_ICMP
|
||||
it.Goto = nopPriority
|
||||
it.Family = unix.AF_INET
|
||||
rules = append(rules, it)
|
||||
}
|
||||
if p6 {
|
||||
it = netlink.NewRule()
|
||||
it.Priority = priority6
|
||||
@@ -834,16 +826,6 @@ func (t *NativeTun) rules() []*netlink.Rule {
|
||||
it.Family = unix.AF_INET6
|
||||
rules = append(rules, it)
|
||||
}
|
||||
|
||||
if p6 && !t.options.StrictRoute {
|
||||
it = netlink.NewRule()
|
||||
it.Priority = priority6
|
||||
it.IPProto = syscall.IPPROTO_ICMPV6
|
||||
it.Goto = nopPriority
|
||||
it.Family = unix.AF_INET6
|
||||
rules = append(rules, it)
|
||||
priority6++
|
||||
}
|
||||
}
|
||||
if p4 {
|
||||
it = netlink.NewRule()
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
@@ -16,7 +17,6 @@ import (
|
||||
"github.com/sagernet/sing-tun/internal/winsys"
|
||||
"github.com/sagernet/sing-tun/internal/wintun"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/windnsapi"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user