Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ead31a5b8d |
19
.github/renovate.json
vendored
19
.github/renovate.json
vendored
@@ -1,19 +0,0 @@
|
||||
{
|
||||
"$schema": "https://docs.renovatebot.com/renovate-schema.json",
|
||||
"commitMessagePrefix": "[dependencies]",
|
||||
"extends": [
|
||||
"config:base",
|
||||
":disableRateLimiting"
|
||||
],
|
||||
"golang": {
|
||||
"enabled": false
|
||||
},
|
||||
"packageRules": [
|
||||
{
|
||||
"matchManagers": [
|
||||
"github-actions"
|
||||
],
|
||||
"groupName": "github-actions"
|
||||
}
|
||||
]
|
||||
}
|
||||
44
.github/workflows/debug.yml
vendored
Normal file
44
.github/workflows/debug.yml
vendored
Normal file
@@ -0,0 +1,44 @@
|
||||
name: Debug build
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
- '.github/**'
|
||||
- '!.github/workflows/debug.yml'
|
||||
pull_request:
|
||||
branches:
|
||||
- dev
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Debug build
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Get latest go version
|
||||
id: version
|
||||
run: |
|
||||
echo ::set-output name=go_version::$(curl -s https://raw.githubusercontent.com/actions/go-versions/main/versions-manifest.json | grep -oE '"version": "[0-9]{1}.[0-9]{1,}(.[0-9]{1,})?"' | head -1 | cut -d':' -f2 | sed 's/ //g; s/"//g')
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v2
|
||||
with:
|
||||
go-version: ${{ steps.version.outputs.go_version }}
|
||||
- name: Add cache to Go proxy
|
||||
run: |
|
||||
version=`git rev-parse HEAD`
|
||||
mkdir build
|
||||
pushd build
|
||||
go mod init build
|
||||
go get -v github.com/sagernet/sing-tun@$version
|
||||
popd
|
||||
continue-on-error: true
|
||||
- name: Build
|
||||
run: |
|
||||
go build -v .
|
||||
40
.github/workflows/lint.yml
vendored
40
.github/workflows/lint.yml
vendored
@@ -1,40 +0,0 @@
|
||||
name: lint
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
- '.github/**'
|
||||
- '!.github/workflows/lint.yml'
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Build
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ^1.23
|
||||
- name: Cache go module
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/go/pkg/mod
|
||||
key: go-${{ hashFiles('**/go.sum') }}
|
||||
- name: golangci-lint
|
||||
uses: golangci/golangci-lint-action@v6
|
||||
with:
|
||||
version: latest
|
||||
args: .
|
||||
112
.github/workflows/test.yml
vendored
112
.github/workflows/test.yml
vendored
@@ -1,112 +0,0 @@
|
||||
name: test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
paths-ignore:
|
||||
- '**.md'
|
||||
- '.github/**'
|
||||
- '!.github/workflows/debug.yml'
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- dev
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Linux
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ^1.23
|
||||
- name: Build
|
||||
run: |
|
||||
make test
|
||||
build_go120:
|
||||
name: Linux (Go 1.20)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ~1.20
|
||||
continue-on-error: true
|
||||
- name: Build
|
||||
run: |
|
||||
make test
|
||||
build_go121:
|
||||
name: Linux (Go 1.21)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ~1.21
|
||||
continue-on-error: true
|
||||
- name: Build
|
||||
run: |
|
||||
make test
|
||||
build_go122:
|
||||
name: Linux (Go 1.22)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ~1.22
|
||||
continue-on-error: true
|
||||
- name: Build
|
||||
run: |
|
||||
make test
|
||||
build_windows:
|
||||
name: Windows
|
||||
runs-on: windows-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ^1.23
|
||||
continue-on-error: true
|
||||
- name: Build
|
||||
run: |
|
||||
make test
|
||||
build_darwin:
|
||||
name: macOS
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Setup Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: ^1.23
|
||||
continue-on-error: true
|
||||
- name: Build
|
||||
run: |
|
||||
make test
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,3 +1,2 @@
|
||||
/.idea/
|
||||
/vendor/
|
||||
.DS_Store
|
||||
|
||||
@@ -3,18 +3,14 @@ linters:
|
||||
enable:
|
||||
- gofumpt
|
||||
- govet
|
||||
- gci
|
||||
# - gci
|
||||
- staticcheck
|
||||
- paralleltest
|
||||
- ineffassign
|
||||
|
||||
linters-settings:
|
||||
gci:
|
||||
custom-order: true
|
||||
sections:
|
||||
- standard
|
||||
- prefix(github.com/sagernet/)
|
||||
- default
|
||||
|
||||
run:
|
||||
go: "1.23"
|
||||
# gci:
|
||||
# sections:
|
||||
# - standard
|
||||
# - prefix(github.com/sagernet/sing)
|
||||
# - default
|
||||
staticcheck:
|
||||
go: '1.19'
|
||||
|
||||
5
Makefile
5
Makefile
@@ -5,7 +5,6 @@ build:
|
||||
GOOS=linux GOARCH=arm64 go build -v -tags with_gvisor .
|
||||
GOOS=linux GOARCH=386 go build -v -tags with_gvisor .
|
||||
GOOS=linux GOARCH=arm go build -v -tags with_gvisor .
|
||||
GOOS=android GOARCH=arm64 go build -v -tags with_gvisor .
|
||||
GOOS=windows GOARCH=amd64 go build -v -tags with_gvisor .
|
||||
|
||||
fmt:
|
||||
@@ -28,6 +27,4 @@ lint_install:
|
||||
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||
|
||||
test:
|
||||
go build -v .
|
||||
go test -bench=. ./internal/checksum_test
|
||||
#go test -v .
|
||||
go test -v .
|
||||
33
go.mod
33
go.mod
@@ -1,28 +1,21 @@
|
||||
module github.com/sagernet/sing-tun
|
||||
|
||||
go 1.20
|
||||
go 1.18
|
||||
|
||||
require (
|
||||
github.com/go-ole/go-ole v1.3.0
|
||||
github.com/google/btree v1.1.3
|
||||
github.com/sagernet/fswatch v0.1.1
|
||||
github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff
|
||||
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a
|
||||
github.com/sagernet/nftables v0.3.0-beta.4
|
||||
github.com/sagernet/sing v0.6.0-beta.2
|
||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
|
||||
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8
|
||||
golang.org/x/net v0.26.0
|
||||
golang.org/x/sys v0.26.0
|
||||
github.com/fsnotify/fsnotify v1.6.0
|
||||
github.com/go-ole/go-ole v1.2.6
|
||||
github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61
|
||||
github.com/sagernet/gvisor v0.0.0-20230626113001-7d91dec8c61c
|
||||
github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97
|
||||
github.com/sagernet/sing v0.2.7
|
||||
github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9
|
||||
golang.org/x/net v0.11.0
|
||||
golang.org/x/sys v0.9.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/fsnotify/fsnotify v1.7.0 // indirect
|
||||
github.com/google/go-cmp v0.6.0 // indirect
|
||||
github.com/josharian/native v1.1.0 // indirect
|
||||
github.com/mdlayher/netlink v1.7.2 // indirect
|
||||
github.com/mdlayher/socket v0.4.1 // indirect
|
||||
github.com/vishvananda/netns v0.0.4 // indirect
|
||||
golang.org/x/sync v0.7.0 // indirect
|
||||
golang.org/x/time v0.7.0 // indirect
|
||||
github.com/google/btree v1.1.2 // indirect
|
||||
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect
|
||||
golang.org/x/time v0.3.0 // indirect
|
||||
)
|
||||
|
||||
72
go.sum
72
go.sum
@@ -1,43 +1,29 @@
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
|
||||
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
|
||||
github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
|
||||
github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78=
|
||||
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
|
||||
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
|
||||
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
|
||||
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
|
||||
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
|
||||
github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U=
|
||||
github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/sagernet/fswatch v0.1.1 h1:YqID+93B7VRfqIH3PArW/XpJv5H4OLEVWDfProGoRQs=
|
||||
github.com/sagernet/fswatch v0.1.1/go.mod h1:nz85laH0mkQqJfaOrqPpkwtU1znMFNVTpT/5oRsVz/o=
|
||||
github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff h1:mlohw3360Wg1BNGook/UHnISXhUx4Gd/3tVLs5T0nSs=
|
||||
github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff/go.mod h1:ehZwnT2UpmOWAHFL48XdBhnd4Qu4hN2O3Ji0us3ZHMw=
|
||||
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZNjr6sGeT00J8uU7JF4cNUdb44/Duis=
|
||||
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM=
|
||||
github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I=
|
||||
github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8=
|
||||
github.com/sagernet/sing v0.6.0-beta.2 h1:Dcutp3kxrsZes9q3oTiHQhYYjQvDn5rwp1OI9fDLYwQ=
|
||||
github.com/sagernet/sing v0.6.0-beta.2/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
|
||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
|
||||
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
|
||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M=
|
||||
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y=
|
||||
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 h1:yixxcjnhBmY0nkL253HFVIm0JsFHwrHdT3Yh6szTnfY=
|
||||
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8/go.mod h1:jj3sYF3dwk5D+ghuXyeI3r5MFf+NT2An6/9dOA95KSI=
|
||||
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
|
||||
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
|
||||
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
|
||||
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
|
||||
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
|
||||
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
|
||||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||
github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
|
||||
github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
||||
github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61 h1:5+m7c6AkmAylhauulqN/c5dnh8/KssrE9c93TQrXldA=
|
||||
github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61/go.mod h1:QUQ4RRHD6hGGHdFMEtR8T2P6GS6R3D/CXKdaYHKKXms=
|
||||
github.com/sagernet/gvisor v0.0.0-20230626113001-7d91dec8c61c h1:YzXn4fVpRR9SL/l8S/30iYwNZwlBZY7FZvg8eHRNMdc=
|
||||
github.com/sagernet/gvisor v0.0.0-20230626113001-7d91dec8c61c/go.mod h1:1JUiV7nGuf++YFm9eWZ8q2lrwHmhcUGzptMl/vL1+LA=
|
||||
github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97 h1:iL5gZI3uFp0X6EslacyapiRz7LLSJyr4RajF/BhMVyE=
|
||||
github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM=
|
||||
github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY=
|
||||
github.com/sagernet/sing v0.2.7 h1:cOy0FfPS8q7m0aJ51wS7LRQAGc9wF+fWhHtBDj99wy8=
|
||||
github.com/sagernet/sing v0.2.7/go.mod h1:Ta8nHnDLAwqySzKhGoKk4ZIB+vJ3GTKj7UPrWYvM+4w=
|
||||
github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9 h1:rc/CcqLH3lh8n+csdOuDfP+NuykE0U6AeYSJJHKDgSg=
|
||||
github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9/go.mod h1:a/83NAfUXvEuLpmxDssAXxgUgrEy12MId3Wd7OTs76s=
|
||||
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 h1:gga7acRE695APm9hlsSMoOoE65U4/TcqNj90mc69Rlg=
|
||||
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
|
||||
golang.org/x/net v0.11.0 h1:Gi2tvZIJyBtO9SDr1q9h5hEQCp/4L2RQ+ar0qjx2oNU=
|
||||
golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
|
||||
golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
|
||||
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
|
||||
209
gvisor.go
Normal file
209
gvisor.go
Normal file
@@ -0,0 +1,209 @@
|
||||
//go:build with_gvisor
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/gvisor/pkg/tcpip"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/header"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
|
||||
"github.com/sagernet/gvisor/pkg/waiter"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
"github.com/sagernet/sing/common/canceler"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
)
|
||||
|
||||
const WithGVisor = true
|
||||
|
||||
const defaultNIC tcpip.NICID = 1
|
||||
|
||||
type GVisor struct {
|
||||
ctx context.Context
|
||||
tun GVisorTun
|
||||
tunMtu uint32
|
||||
endpointIndependentNat bool
|
||||
udpTimeout int64
|
||||
handler Handler
|
||||
logger logger.Logger
|
||||
stack *stack.Stack
|
||||
endpoint stack.LinkEndpoint
|
||||
}
|
||||
|
||||
type GVisorTun interface {
|
||||
Tun
|
||||
NewEndpoint() (stack.LinkEndpoint, error)
|
||||
}
|
||||
|
||||
func NewGVisor(
|
||||
options StackOptions,
|
||||
) (Stack, error) {
|
||||
gTun, isGTun := options.Tun.(GVisorTun)
|
||||
if !isGTun {
|
||||
return nil, E.New("gVisor stack is unsupported on current platform")
|
||||
}
|
||||
|
||||
gStack := &GVisor{
|
||||
ctx: options.Context,
|
||||
tun: gTun,
|
||||
tunMtu: options.MTU,
|
||||
endpointIndependentNat: options.EndpointIndependentNat,
|
||||
udpTimeout: options.UDPTimeout,
|
||||
handler: options.Handler,
|
||||
logger: options.Logger,
|
||||
}
|
||||
return gStack, nil
|
||||
}
|
||||
|
||||
func (t *GVisor) Start() error {
|
||||
linkEndpoint, err := t.tun.NewEndpoint()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ipStack := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{
|
||||
ipv4.NewProtocol,
|
||||
ipv6.NewProtocol,
|
||||
},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{
|
||||
tcp.NewProtocol,
|
||||
udp.NewProtocol,
|
||||
icmp.NewProtocol4,
|
||||
icmp.NewProtocol6,
|
||||
},
|
||||
})
|
||||
tErr := ipStack.CreateNIC(defaultNIC, linkEndpoint)
|
||||
if tErr != nil {
|
||||
return E.New("create nic: ", wrapStackError(tErr))
|
||||
}
|
||||
ipStack.SetRouteTable([]tcpip.Route{
|
||||
{Destination: header.IPv4EmptySubnet, NIC: defaultNIC},
|
||||
{Destination: header.IPv6EmptySubnet, NIC: defaultNIC},
|
||||
})
|
||||
ipStack.SetSpoofing(defaultNIC, true)
|
||||
ipStack.SetPromiscuousMode(defaultNIC, true)
|
||||
bufSize := 20 * 1024
|
||||
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPReceiveBufferSizeRangeOption{
|
||||
Min: 1,
|
||||
Default: bufSize,
|
||||
Max: bufSize,
|
||||
})
|
||||
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPSendBufferSizeRangeOption{
|
||||
Min: 1,
|
||||
Default: bufSize,
|
||||
Max: bufSize,
|
||||
})
|
||||
sOpt := tcpip.TCPSACKEnabled(true)
|
||||
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
|
||||
mOpt := tcpip.TCPModerateReceiveBufferOption(true)
|
||||
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &mOpt)
|
||||
|
||||
tcpForwarder := tcp.NewForwarder(ipStack, 0, 1024, func(r *tcp.ForwarderRequest) {
|
||||
var wq waiter.Queue
|
||||
handshakeCtx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
wq.Notify(wq.Events())
|
||||
case <-handshakeCtx.Done():
|
||||
}
|
||||
}()
|
||||
endpoint, err := r.CreateEndpoint(&wq)
|
||||
cancel()
|
||||
if err != nil {
|
||||
r.Complete(true)
|
||||
return
|
||||
}
|
||||
r.Complete(false)
|
||||
endpoint.SocketOptions().SetKeepAlive(true)
|
||||
keepAliveIdle := tcpip.KeepaliveIdleOption(15 * time.Second)
|
||||
endpoint.SetSockOpt(&keepAliveIdle)
|
||||
keepAliveInterval := tcpip.KeepaliveIntervalOption(15 * time.Second)
|
||||
endpoint.SetSockOpt(&keepAliveInterval)
|
||||
tcpConn := gonet.NewTCPConn(&wq, endpoint)
|
||||
lAddr := tcpConn.RemoteAddr()
|
||||
rAddr := tcpConn.LocalAddr()
|
||||
if lAddr == nil || rAddr == nil {
|
||||
tcpConn.Close()
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
var metadata M.Metadata
|
||||
metadata.Source = M.SocksaddrFromNet(lAddr)
|
||||
metadata.Destination = M.SocksaddrFromNet(rAddr)
|
||||
hErr := t.handler.NewConnection(t.ctx, &gTCPConn{tcpConn}, metadata)
|
||||
if hErr != nil {
|
||||
endpoint.Abort()
|
||||
}
|
||||
}()
|
||||
})
|
||||
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
|
||||
if !t.endpointIndependentNat {
|
||||
udpForwarder := udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) {
|
||||
var wq waiter.Queue
|
||||
endpoint, err := request.CreateEndpoint(&wq)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
udpConn := gonet.NewUDPConn(ipStack, &wq, endpoint)
|
||||
lAddr := udpConn.RemoteAddr()
|
||||
rAddr := udpConn.LocalAddr()
|
||||
if lAddr == nil || rAddr == nil {
|
||||
endpoint.Abort()
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
var metadata M.Metadata
|
||||
metadata.Source = M.SocksaddrFromNet(lAddr)
|
||||
metadata.Destination = M.SocksaddrFromNet(rAddr)
|
||||
ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewPacketConn(&bufio.UnbindPacketConn{ExtendedConn: bufio.NewExtendedConn(&gUDPConn{udpConn}), Addr: M.SocksaddrFromNet(rAddr)}), time.Duration(t.udpTimeout)*time.Second)
|
||||
hErr := t.handler.NewPacketConnection(ctx, conn, metadata)
|
||||
if hErr != nil {
|
||||
endpoint.Abort()
|
||||
}
|
||||
}()
|
||||
})
|
||||
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
|
||||
} else {
|
||||
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket)
|
||||
}
|
||||
|
||||
t.stack = ipStack
|
||||
t.endpoint = linkEndpoint
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *GVisor) Close() error {
|
||||
t.endpoint.Attach(nil)
|
||||
t.stack.Close()
|
||||
for _, endpoint := range t.stack.CleanupEndpoints() {
|
||||
endpoint.Abort()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func AddressFromAddr(destination netip.Addr) tcpip.Address {
|
||||
if destination.Is6() {
|
||||
return tcpip.AddrFrom16(destination.As16())
|
||||
} else {
|
||||
return tcpip.AddrFrom4(destination.As4())
|
||||
}
|
||||
}
|
||||
|
||||
func AddrFromAddress(address tcpip.Address) netip.Addr {
|
||||
if address.Len() == 16 {
|
||||
return netip.AddrFrom16(address.As16())
|
||||
} else {
|
||||
return netip.AddrFrom4(address.As4())
|
||||
}
|
||||
}
|
||||
72
gvisor_err.go
Normal file
72
gvisor_err.go
Normal file
@@ -0,0 +1,72 @@
|
||||
//go:build with_gvisor
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/gvisor/pkg/tcpip"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
type gTCPConn struct {
|
||||
*gonet.TCPConn
|
||||
}
|
||||
|
||||
func (c *gTCPConn) Upstream() any {
|
||||
return c.TCPConn
|
||||
}
|
||||
|
||||
func (c *gTCPConn) Write(b []byte) (n int, err error) {
|
||||
n, err = c.TCPConn.Write(b)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
err = wrapError(err)
|
||||
return
|
||||
}
|
||||
|
||||
type gUDPConn struct {
|
||||
*gonet.UDPConn
|
||||
}
|
||||
|
||||
func (c *gUDPConn) Read(b []byte) (n int, err error) {
|
||||
n, err = c.UDPConn.Read(b)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
err = wrapError(err)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *gUDPConn) Write(b []byte) (n int, err error) {
|
||||
n, err = c.UDPConn.Write(b)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
err = wrapError(err)
|
||||
return
|
||||
}
|
||||
|
||||
func wrapStackError(err tcpip.Error) error {
|
||||
switch err.(type) {
|
||||
case *tcpip.ErrClosedForSend,
|
||||
*tcpip.ErrClosedForReceive,
|
||||
*tcpip.ErrAborted:
|
||||
return net.ErrClosed
|
||||
}
|
||||
return E.New(err.String())
|
||||
}
|
||||
|
||||
func wrapError(err error) error {
|
||||
if opErr, isOpErr := err.(*net.OpError); isOpErr {
|
||||
switch opErr.Err.Error() {
|
||||
case "endpoint is closed for send",
|
||||
"endpoint is closed for receive",
|
||||
"operation aborted":
|
||||
return net.ErrClosed
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -13,9 +13,3 @@ func NewGVisor(
|
||||
) (Stack, error) {
|
||||
return nil, ErrGVisorNotIncluded
|
||||
}
|
||||
|
||||
func NewMixed(
|
||||
options StackOptions,
|
||||
) (Stack, error) {
|
||||
return nil, ErrGVisorNotIncluded
|
||||
}
|
||||
125
gvisor_udp.go
Normal file
125
gvisor_udp.go
Normal file
@@ -0,0 +1,125 @@
|
||||
//go:build with_gvisor
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"net/netip"
|
||||
|
||||
"github.com/sagernet/gvisor/pkg/buffer"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/checksum"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/header"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/udpnat"
|
||||
)
|
||||
|
||||
type UDPForwarder struct {
|
||||
ctx context.Context
|
||||
stack *stack.Stack
|
||||
udpNat *udpnat.Service[netip.AddrPort]
|
||||
}
|
||||
|
||||
func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, udpTimeout int64) *UDPForwarder {
|
||||
return &UDPForwarder{
|
||||
ctx: ctx,
|
||||
stack: stack,
|
||||
udpNat: udpnat.New[netip.AddrPort](udpTimeout, handler),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
|
||||
var upstreamMetadata M.Metadata
|
||||
upstreamMetadata.Source = M.SocksaddrFrom(AddrFromAddress(id.RemoteAddress), id.RemotePort)
|
||||
upstreamMetadata.Destination = M.SocksaddrFrom(AddrFromAddress(id.LocalAddress), id.LocalPort)
|
||||
var netProto tcpip.NetworkProtocolNumber
|
||||
if upstreamMetadata.Source.IsIPv4() {
|
||||
netProto = header.IPv4ProtocolNumber
|
||||
} else {
|
||||
netProto = header.IPv6ProtocolNumber
|
||||
}
|
||||
f.udpNat.NewPacket(
|
||||
f.ctx,
|
||||
upstreamMetadata.Source.AddrPort(),
|
||||
buf.As(pkt.Data().AsRange().ToSlice()),
|
||||
upstreamMetadata,
|
||||
func(natConn N.PacketConn) N.PacketWriter {
|
||||
return &UDPBackWriter{f.stack, id.RemoteAddress, id.RemotePort, netProto}
|
||||
},
|
||||
)
|
||||
return true
|
||||
}
|
||||
|
||||
type UDPBackWriter struct {
|
||||
stack *stack.Stack
|
||||
source tcpip.Address
|
||||
sourcePort uint16
|
||||
sourceNetwork tcpip.NetworkProtocolNumber
|
||||
}
|
||||
|
||||
func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
if destination.IsIPv4() && w.sourceNetwork == header.IPv6ProtocolNumber {
|
||||
destination = M.SocksaddrFrom(netip.AddrFrom16(destination.Addr.As16()), destination.Port)
|
||||
} else if destination.IsIPv6() && (w.sourceNetwork == header.IPv4AddressSizeBits) {
|
||||
return E.New("send IPv6 packet to IPv4 connection")
|
||||
}
|
||||
|
||||
defer packetBuffer.Release()
|
||||
|
||||
route, err := w.stack.FindRoute(
|
||||
defaultNIC,
|
||||
AddressFromAddr(destination.Addr),
|
||||
w.source,
|
||||
w.sourceNetwork,
|
||||
false,
|
||||
)
|
||||
if err != nil {
|
||||
return wrapStackError(err)
|
||||
}
|
||||
defer route.Release()
|
||||
|
||||
packet := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
ReserveHeaderBytes: header.UDPMinimumSize + int(route.MaxHeaderLength()),
|
||||
Payload: buffer.MakeWithData(packetBuffer.Bytes()),
|
||||
})
|
||||
defer packet.DecRef()
|
||||
|
||||
packet.TransportProtocolNumber = header.UDPProtocolNumber
|
||||
udpHdr := header.UDP(packet.TransportHeader().Push(header.UDPMinimumSize))
|
||||
pLen := uint16(packet.Size())
|
||||
udpHdr.Encode(&header.UDPFields{
|
||||
SrcPort: destination.Port,
|
||||
DstPort: w.sourcePort,
|
||||
Length: pLen,
|
||||
})
|
||||
|
||||
if route.RequiresTXTransportChecksum() && w.sourceNetwork == header.IPv6ProtocolNumber {
|
||||
xsum := udpHdr.CalculateChecksum(checksum.Combine(
|
||||
route.PseudoHeaderChecksum(header.UDPProtocolNumber, pLen),
|
||||
packet.Data().Checksum(),
|
||||
))
|
||||
if xsum != math.MaxUint16 {
|
||||
xsum = ^xsum
|
||||
}
|
||||
udpHdr.SetChecksum(xsum)
|
||||
}
|
||||
|
||||
err = route.WritePacket(stack.NetworkHeaderParams{
|
||||
Protocol: header.UDPProtocolNumber,
|
||||
TTL: route.DefaultTTL(),
|
||||
TOS: 0,
|
||||
}, packet)
|
||||
|
||||
if err != nil {
|
||||
route.Stats().UDP.PacketSendErrors.Increment()
|
||||
return wrapStackError(err)
|
||||
}
|
||||
|
||||
route.Stats().UDP.PacketsSent.Increment()
|
||||
return nil
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
package checksum_test
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
|
||||
"github.com/sagernet/sing-tun/internal/tschecksum"
|
||||
)
|
||||
|
||||
func BenchmarkTsChecksum(b *testing.B) {
|
||||
packet := make([][]byte, 1000)
|
||||
for i := 0; i < 1000; i++ {
|
||||
packet[i] = make([]byte, 1500)
|
||||
rand.Read(packet[i])
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tschecksum.Checksum(packet[i%1000], 0)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGChecksum(b *testing.B) {
|
||||
packet := make([][]byte, 1000)
|
||||
for i := 0; i < 1000; i++ {
|
||||
packet[i] = make([]byte, 1500)
|
||||
rand.Read(packet[i])
|
||||
}
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
checksum.ChecksumDefault(packet[i%1000], 0)
|
||||
}
|
||||
}
|
||||
40
internal/clashtcpip/icmp.go
Normal file
40
internal/clashtcpip/icmp.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package clashtcpip
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
type ICMPType = byte
|
||||
|
||||
const (
|
||||
ICMPTypePingRequest byte = 0x8
|
||||
ICMPTypePingResponse byte = 0x0
|
||||
)
|
||||
|
||||
type ICMPPacket []byte
|
||||
|
||||
func (p ICMPPacket) Type() ICMPType {
|
||||
return p[0]
|
||||
}
|
||||
|
||||
func (p ICMPPacket) SetType(v ICMPType) {
|
||||
p[0] = v
|
||||
}
|
||||
|
||||
func (p ICMPPacket) Code() byte {
|
||||
return p[1]
|
||||
}
|
||||
|
||||
func (p ICMPPacket) Checksum() uint16 {
|
||||
return binary.BigEndian.Uint16(p[2:])
|
||||
}
|
||||
|
||||
func (p ICMPPacket) SetChecksum(sum [2]byte) {
|
||||
p[2] = sum[0]
|
||||
p[3] = sum[1]
|
||||
}
|
||||
|
||||
func (p ICMPPacket) ResetChecksum() {
|
||||
p.SetChecksum(zeroChecksum)
|
||||
p.SetChecksum(Checksum(0, p))
|
||||
}
|
||||
172
internal/clashtcpip/icmpv6.go
Normal file
172
internal/clashtcpip/icmpv6.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package clashtcpip
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
type ICMPv6Packet []byte
|
||||
|
||||
const (
|
||||
ICMPv6HeaderSize = 4
|
||||
|
||||
ICMPv6MinimumSize = 8
|
||||
|
||||
ICMPv6PayloadOffset = 8
|
||||
|
||||
ICMPv6EchoMinimumSize = 8
|
||||
|
||||
ICMPv6ErrorHeaderSize = 8
|
||||
|
||||
ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize
|
||||
|
||||
ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize
|
||||
|
||||
ICMPv6ChecksumOffset = 2
|
||||
|
||||
icmpv6PointerOffset = 4
|
||||
|
||||
icmpv6MTUOffset = 4
|
||||
|
||||
icmpv6IdentOffset = 4
|
||||
|
||||
icmpv6SequenceOffset = 6
|
||||
|
||||
NDPHopLimit = 255
|
||||
)
|
||||
|
||||
type ICMPv6Type byte
|
||||
|
||||
const (
|
||||
ICMPv6DstUnreachable ICMPv6Type = 1
|
||||
ICMPv6PacketTooBig ICMPv6Type = 2
|
||||
ICMPv6TimeExceeded ICMPv6Type = 3
|
||||
ICMPv6ParamProblem ICMPv6Type = 4
|
||||
ICMPv6EchoRequest ICMPv6Type = 128
|
||||
ICMPv6EchoReply ICMPv6Type = 129
|
||||
|
||||
ICMPv6RouterSolicit ICMPv6Type = 133
|
||||
ICMPv6RouterAdvert ICMPv6Type = 134
|
||||
ICMPv6NeighborSolicit ICMPv6Type = 135
|
||||
ICMPv6NeighborAdvert ICMPv6Type = 136
|
||||
ICMPv6RedirectMsg ICMPv6Type = 137
|
||||
|
||||
ICMPv6MulticastListenerQuery ICMPv6Type = 130
|
||||
ICMPv6MulticastListenerReport ICMPv6Type = 131
|
||||
ICMPv6MulticastListenerDone ICMPv6Type = 132
|
||||
)
|
||||
|
||||
func (typ ICMPv6Type) IsErrorType() bool {
|
||||
return typ&0x80 == 0
|
||||
}
|
||||
|
||||
type ICMPv6Code byte
|
||||
|
||||
const (
|
||||
ICMPv6NetworkUnreachable ICMPv6Code = 0
|
||||
ICMPv6Prohibited ICMPv6Code = 1
|
||||
ICMPv6BeyondScope ICMPv6Code = 2
|
||||
ICMPv6AddressUnreachable ICMPv6Code = 3
|
||||
ICMPv6PortUnreachable ICMPv6Code = 4
|
||||
ICMPv6Policy ICMPv6Code = 5
|
||||
ICMPv6RejectRoute ICMPv6Code = 6
|
||||
)
|
||||
|
||||
const (
|
||||
ICMPv6HopLimitExceeded ICMPv6Code = 0
|
||||
ICMPv6ReassemblyTimeout ICMPv6Code = 1
|
||||
)
|
||||
|
||||
const (
|
||||
ICMPv6ErroneousHeader ICMPv6Code = 0
|
||||
|
||||
ICMPv6UnknownHeader ICMPv6Code = 1
|
||||
|
||||
ICMPv6UnknownOption ICMPv6Code = 2
|
||||
)
|
||||
|
||||
const ICMPv6UnusedCode ICMPv6Code = 0
|
||||
|
||||
func (b ICMPv6Packet) Type() ICMPv6Type {
|
||||
return ICMPv6Type(b[0])
|
||||
}
|
||||
|
||||
func (b ICMPv6Packet) SetType(t ICMPv6Type) {
|
||||
b[0] = byte(t)
|
||||
}
|
||||
|
||||
func (b ICMPv6Packet) Code() ICMPv6Code {
|
||||
return ICMPv6Code(b[1])
|
||||
}
|
||||
|
||||
func (b ICMPv6Packet) SetCode(c ICMPv6Code) {
|
||||
b[1] = byte(c)
|
||||
}
|
||||
|
||||
func (b ICMPv6Packet) TypeSpecific() uint32 {
|
||||
return binary.BigEndian.Uint32(b[icmpv6PointerOffset:])
|
||||
}
|
||||
|
||||
func (b ICMPv6Packet) SetTypeSpecific(val uint32) {
|
||||
binary.BigEndian.PutUint32(b[icmpv6PointerOffset:], val)
|
||||
}
|
||||
|
||||
func (b ICMPv6Packet) Checksum() uint16 {
|
||||
return binary.BigEndian.Uint16(b[ICMPv6ChecksumOffset:])
|
||||
}
|
||||
|
||||
func (b ICMPv6Packet) SetChecksum(sum [2]byte) {
|
||||
_ = b[ICMPv6ChecksumOffset+1]
|
||||
b[ICMPv6ChecksumOffset] = sum[0]
|
||||
b[ICMPv6ChecksumOffset+1] = sum[1]
|
||||
}
|
||||
|
||||
func (ICMPv6Packet) SourcePort() uint16 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (ICMPv6Packet) DestinationPort() uint16 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (ICMPv6Packet) SetSourcePort(uint16) {
|
||||
}
|
||||
|
||||
func (ICMPv6Packet) SetDestinationPort(uint16) {
|
||||
}
|
||||
|
||||
func (b ICMPv6Packet) MTU() uint32 {
|
||||
return binary.BigEndian.Uint32(b[icmpv6MTUOffset:])
|
||||
}
|
||||
|
||||
func (b ICMPv6Packet) SetMTU(mtu uint32) {
|
||||
binary.BigEndian.PutUint32(b[icmpv6MTUOffset:], mtu)
|
||||
}
|
||||
|
||||
func (b ICMPv6Packet) Ident() uint16 {
|
||||
return binary.BigEndian.Uint16(b[icmpv6IdentOffset:])
|
||||
}
|
||||
|
||||
func (b ICMPv6Packet) SetIdent(ident uint16) {
|
||||
binary.BigEndian.PutUint16(b[icmpv6IdentOffset:], ident)
|
||||
}
|
||||
|
||||
func (b ICMPv6Packet) Sequence() uint16 {
|
||||
return binary.BigEndian.Uint16(b[icmpv6SequenceOffset:])
|
||||
}
|
||||
|
||||
func (b ICMPv6Packet) SetSequence(sequence uint16) {
|
||||
binary.BigEndian.PutUint16(b[icmpv6SequenceOffset:], sequence)
|
||||
}
|
||||
|
||||
func (b ICMPv6Packet) MessageBody() []byte {
|
||||
return b[ICMPv6HeaderSize:]
|
||||
}
|
||||
|
||||
func (b ICMPv6Packet) Payload() []byte {
|
||||
return b[ICMPv6PayloadOffset:]
|
||||
}
|
||||
|
||||
func (b ICMPv6Packet) ResetChecksum(psum uint32) {
|
||||
b.SetChecksum(zeroChecksum)
|
||||
b.SetChecksum(Checksum(psum, b))
|
||||
}
|
||||
209
internal/clashtcpip/ip.go
Normal file
209
internal/clashtcpip/ip.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package clashtcpip
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
type IPProtocol = byte
|
||||
|
||||
type IP interface {
|
||||
Payload() []byte
|
||||
SourceIP() netip.Addr
|
||||
DestinationIP() netip.Addr
|
||||
SetSourceIP(ip netip.Addr)
|
||||
SetDestinationIP(ip netip.Addr)
|
||||
Protocol() IPProtocol
|
||||
DecTimeToLive()
|
||||
ResetChecksum()
|
||||
PseudoSum() uint32
|
||||
}
|
||||
|
||||
// IPProtocol type
|
||||
const (
|
||||
ICMP IPProtocol = 0x01
|
||||
TCP IPProtocol = 0x06
|
||||
UDP IPProtocol = 0x11
|
||||
ICMPv6 IPProtocol = 0x3a
|
||||
)
|
||||
|
||||
const (
|
||||
FlagDontFragment = 1 << 1
|
||||
FlagMoreFragment = 1 << 2
|
||||
)
|
||||
|
||||
const (
|
||||
IPv4HeaderSize = 20
|
||||
|
||||
IPv4Version = 4
|
||||
|
||||
IPv4OptionsOffset = 20
|
||||
IPv4PacketMinLength = IPv4OptionsOffset
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidLength = errors.New("invalid packet length")
|
||||
ErrInvalidIPVersion = errors.New("invalid ip version")
|
||||
ErrInvalidChecksum = errors.New("invalid checksum")
|
||||
)
|
||||
|
||||
type IPv4Packet []byte
|
||||
|
||||
func (p IPv4Packet) TotalLen() uint16 {
|
||||
return binary.BigEndian.Uint16(p[2:])
|
||||
}
|
||||
|
||||
func (p IPv4Packet) SetTotalLength(length uint16) {
|
||||
binary.BigEndian.PutUint16(p[2:], length)
|
||||
}
|
||||
|
||||
func (p IPv4Packet) HeaderLen() uint16 {
|
||||
return uint16(p[0]&0xf) * 4
|
||||
}
|
||||
|
||||
func (p IPv4Packet) SetHeaderLen(length uint16) {
|
||||
p[0] &= 0xF0
|
||||
p[0] |= byte(length / 4)
|
||||
}
|
||||
|
||||
func (p IPv4Packet) TypeOfService() byte {
|
||||
return p[1]
|
||||
}
|
||||
|
||||
func (p IPv4Packet) SetTypeOfService(tos byte) {
|
||||
p[1] = tos
|
||||
}
|
||||
|
||||
func (p IPv4Packet) Identification() uint16 {
|
||||
return binary.BigEndian.Uint16(p[4:])
|
||||
}
|
||||
|
||||
func (p IPv4Packet) SetIdentification(id uint16) {
|
||||
binary.BigEndian.PutUint16(p[4:], id)
|
||||
}
|
||||
|
||||
func (p IPv4Packet) FragmentOffset() uint16 {
|
||||
return binary.BigEndian.Uint16([]byte{p[6] & 0x7, p[7]}) * 8
|
||||
}
|
||||
|
||||
func (p IPv4Packet) SetFragmentOffset(offset uint32) {
|
||||
flags := p.Flags()
|
||||
binary.BigEndian.PutUint16(p[6:], uint16(offset/8))
|
||||
p.SetFlags(flags)
|
||||
}
|
||||
|
||||
func (p IPv4Packet) DataLen() uint16 {
|
||||
return p.TotalLen() - p.HeaderLen()
|
||||
}
|
||||
|
||||
func (p IPv4Packet) Payload() []byte {
|
||||
return p[p.HeaderLen():p.TotalLen()]
|
||||
}
|
||||
|
||||
func (p IPv4Packet) Protocol() IPProtocol {
|
||||
return p[9]
|
||||
}
|
||||
|
||||
func (p IPv4Packet) SetProtocol(protocol IPProtocol) {
|
||||
p[9] = protocol
|
||||
}
|
||||
|
||||
func (p IPv4Packet) Flags() byte {
|
||||
return p[6] >> 5
|
||||
}
|
||||
|
||||
func (p IPv4Packet) SetFlags(flags byte) {
|
||||
p[6] &= 0x1F
|
||||
p[6] |= flags << 5
|
||||
}
|
||||
|
||||
func (p IPv4Packet) SourceIP() netip.Addr {
|
||||
return netip.AddrFrom4([4]byte{p[12], p[13], p[14], p[15]})
|
||||
}
|
||||
|
||||
func (p IPv4Packet) SetSourceIP(ip netip.Addr) {
|
||||
if ip.Is4() {
|
||||
copy(p[12:16], ip.AsSlice())
|
||||
}
|
||||
}
|
||||
|
||||
func (p IPv4Packet) DestinationIP() netip.Addr {
|
||||
return netip.AddrFrom4([4]byte{p[16], p[17], p[18], p[19]})
|
||||
}
|
||||
|
||||
func (p IPv4Packet) SetDestinationIP(ip netip.Addr) {
|
||||
if ip.Is4() {
|
||||
copy(p[16:20], ip.AsSlice())
|
||||
}
|
||||
}
|
||||
|
||||
func (p IPv4Packet) Checksum() uint16 {
|
||||
return binary.BigEndian.Uint16(p[10:])
|
||||
}
|
||||
|
||||
func (p IPv4Packet) SetChecksum(sum [2]byte) {
|
||||
p[10] = sum[0]
|
||||
p[11] = sum[1]
|
||||
}
|
||||
|
||||
func (p IPv4Packet) TimeToLive() uint8 {
|
||||
return p[8]
|
||||
}
|
||||
|
||||
func (p IPv4Packet) SetTimeToLive(ttl uint8) {
|
||||
p[8] = ttl
|
||||
}
|
||||
|
||||
func (p IPv4Packet) DecTimeToLive() {
|
||||
p[8] = p[8] - uint8(1)
|
||||
}
|
||||
|
||||
func (p IPv4Packet) ResetChecksum() {
|
||||
p.SetChecksum(zeroChecksum)
|
||||
p.SetChecksum(Checksum(0, p[:p.HeaderLen()]))
|
||||
}
|
||||
|
||||
// PseudoSum for tcp checksum
|
||||
func (p IPv4Packet) PseudoSum() uint32 {
|
||||
sum := Sum(p[12:20])
|
||||
sum += uint32(p.Protocol())
|
||||
sum += uint32(p.DataLen())
|
||||
return sum
|
||||
}
|
||||
|
||||
func (p IPv4Packet) Valid() bool {
|
||||
return len(p) >= IPv4HeaderSize && p.TotalLen() >= p.HeaderLen() && uint16(len(p)) >= p.TotalLen()
|
||||
}
|
||||
|
||||
func (p IPv4Packet) Verify() error {
|
||||
if len(p) < IPv4PacketMinLength {
|
||||
return ErrInvalidLength
|
||||
}
|
||||
|
||||
checksum := []byte{p[10], p[11]}
|
||||
headerLength := uint16(p[0]&0xF) * 4
|
||||
packetLength := binary.BigEndian.Uint16(p[2:])
|
||||
|
||||
if p[0]>>4 != 4 {
|
||||
return ErrInvalidIPVersion
|
||||
}
|
||||
|
||||
if uint16(len(p)) < packetLength || packetLength < headerLength {
|
||||
return ErrInvalidLength
|
||||
}
|
||||
|
||||
p[10] = 0
|
||||
p[11] = 0
|
||||
defer copy(p[10:12], checksum)
|
||||
|
||||
answer := Checksum(0, p[:headerLength])
|
||||
|
||||
if answer[0] != checksum[0] || answer[1] != checksum[1] {
|
||||
return ErrInvalidChecksum
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ IP = (*IPv4Packet)(nil)
|
||||
141
internal/clashtcpip/ipv6.go
Normal file
141
internal/clashtcpip/ipv6.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package clashtcpip
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
const (
|
||||
versTCFL = 0
|
||||
|
||||
IPv6PayloadLenOffset = 4
|
||||
|
||||
IPv6NextHeaderOffset = 6
|
||||
hopLimit = 7
|
||||
v6SrcAddr = 8
|
||||
v6DstAddr = v6SrcAddr + IPv6AddressSize
|
||||
|
||||
IPv6FixedHeaderSize = v6DstAddr + IPv6AddressSize
|
||||
)
|
||||
|
||||
const (
|
||||
versIHL = 0
|
||||
tos = 1
|
||||
ipVersionShift = 4
|
||||
ipIHLMask = 0x0f
|
||||
IPv4IHLStride = 4
|
||||
)
|
||||
|
||||
type IPv6Packet []byte
|
||||
|
||||
const (
|
||||
IPv6MinimumSize = IPv6FixedHeaderSize
|
||||
|
||||
IPv6AddressSize = 16
|
||||
|
||||
IPv6Version = 6
|
||||
|
||||
IPv6MinimumMTU = 1280
|
||||
)
|
||||
|
||||
func (b IPv6Packet) PayloadLength() uint16 {
|
||||
return binary.BigEndian.Uint16(b[IPv6PayloadLenOffset:])
|
||||
}
|
||||
|
||||
func (b IPv6Packet) HopLimit() uint8 {
|
||||
return b[hopLimit]
|
||||
}
|
||||
|
||||
func (b IPv6Packet) NextHeader() byte {
|
||||
return b[IPv6NextHeaderOffset]
|
||||
}
|
||||
|
||||
func (b IPv6Packet) Protocol() IPProtocol {
|
||||
return b.NextHeader()
|
||||
}
|
||||
|
||||
func (b IPv6Packet) Payload() []byte {
|
||||
return b[IPv6MinimumSize:][:b.PayloadLength()]
|
||||
}
|
||||
|
||||
func (b IPv6Packet) SourceIP() netip.Addr {
|
||||
addr, _ := netip.AddrFromSlice(b[v6SrcAddr:][:IPv6AddressSize])
|
||||
return addr
|
||||
}
|
||||
|
||||
func (b IPv6Packet) DestinationIP() netip.Addr {
|
||||
addr, _ := netip.AddrFromSlice(b[v6DstAddr:][:IPv6AddressSize])
|
||||
return addr
|
||||
}
|
||||
|
||||
func (IPv6Packet) Checksum() uint16 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (b IPv6Packet) TOS() (uint8, uint32) {
|
||||
v := binary.BigEndian.Uint32(b[versTCFL:])
|
||||
return uint8(v >> 20), v & 0xfffff
|
||||
}
|
||||
|
||||
func (b IPv6Packet) SetTOS(t uint8, l uint32) {
|
||||
vtf := (6 << 28) | (uint32(t) << 20) | (l & 0xfffff)
|
||||
binary.BigEndian.PutUint32(b[versTCFL:], vtf)
|
||||
}
|
||||
|
||||
func (b IPv6Packet) SetPayloadLength(payloadLength uint16) {
|
||||
binary.BigEndian.PutUint16(b[IPv6PayloadLenOffset:], payloadLength)
|
||||
}
|
||||
|
||||
func (b IPv6Packet) SetSourceIP(addr netip.Addr) {
|
||||
if addr.Is6() {
|
||||
copy(b[v6SrcAddr:][:IPv6AddressSize], addr.AsSlice())
|
||||
}
|
||||
}
|
||||
|
||||
func (b IPv6Packet) SetDestinationIP(addr netip.Addr) {
|
||||
if addr.Is6() {
|
||||
copy(b[v6DstAddr:][:IPv6AddressSize], addr.AsSlice())
|
||||
}
|
||||
}
|
||||
|
||||
func (b IPv6Packet) SetHopLimit(v uint8) {
|
||||
b[hopLimit] = v
|
||||
}
|
||||
|
||||
func (b IPv6Packet) SetNextHeader(v byte) {
|
||||
b[IPv6NextHeaderOffset] = v
|
||||
}
|
||||
|
||||
func (b IPv6Packet) SetProtocol(p IPProtocol) {
|
||||
b.SetNextHeader(p)
|
||||
}
|
||||
|
||||
func (b IPv6Packet) DecTimeToLive() {
|
||||
b[hopLimit] = b[hopLimit] - uint8(1)
|
||||
}
|
||||
|
||||
func (IPv6Packet) SetChecksum(uint16) {
|
||||
}
|
||||
|
||||
func (IPv6Packet) ResetChecksum() {
|
||||
}
|
||||
|
||||
func (b IPv6Packet) PseudoSum() uint32 {
|
||||
sum := Sum(b[v6SrcAddr:IPv6FixedHeaderSize])
|
||||
sum += uint32(b.Protocol())
|
||||
sum += uint32(b.PayloadLength())
|
||||
return sum
|
||||
}
|
||||
|
||||
func (b IPv6Packet) Valid() bool {
|
||||
return len(b) >= IPv6MinimumSize && len(b) >= int(b.PayloadLength())+IPv6MinimumSize
|
||||
}
|
||||
|
||||
func IPVersion(b []byte) int {
|
||||
if len(b) < versIHL+1 {
|
||||
return -1
|
||||
}
|
||||
return int(b[versIHL] >> ipVersionShift)
|
||||
}
|
||||
|
||||
var _ IP = (*IPv6Packet)(nil)
|
||||
90
internal/clashtcpip/tcp.go
Normal file
90
internal/clashtcpip/tcp.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package clashtcpip
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
)
|
||||
|
||||
const (
|
||||
TCPFin uint16 = 1 << 0
|
||||
TCPSyn uint16 = 1 << 1
|
||||
TCPRst uint16 = 1 << 2
|
||||
TCPPuh uint16 = 1 << 3
|
||||
TCPAck uint16 = 1 << 4
|
||||
TCPUrg uint16 = 1 << 5
|
||||
TCPEce uint16 = 1 << 6
|
||||
TCPEwr uint16 = 1 << 7
|
||||
TCPNs uint16 = 1 << 8
|
||||
)
|
||||
|
||||
const TCPHeaderSize = 20
|
||||
|
||||
type TCPPacket []byte
|
||||
|
||||
func (p TCPPacket) SourcePort() uint16 {
|
||||
return binary.BigEndian.Uint16(p)
|
||||
}
|
||||
|
||||
func (p TCPPacket) SetSourcePort(port uint16) {
|
||||
binary.BigEndian.PutUint16(p, port)
|
||||
}
|
||||
|
||||
func (p TCPPacket) DestinationPort() uint16 {
|
||||
return binary.BigEndian.Uint16(p[2:])
|
||||
}
|
||||
|
||||
func (p TCPPacket) SetDestinationPort(port uint16) {
|
||||
binary.BigEndian.PutUint16(p[2:], port)
|
||||
}
|
||||
|
||||
func (p TCPPacket) Flags() uint16 {
|
||||
return uint16(p[13] | (p[12] & 0x1))
|
||||
}
|
||||
|
||||
func (p TCPPacket) Checksum() uint16 {
|
||||
return binary.BigEndian.Uint16(p[16:])
|
||||
}
|
||||
|
||||
func (p TCPPacket) SetChecksum(sum [2]byte) {
|
||||
p[16] = sum[0]
|
||||
p[17] = sum[1]
|
||||
}
|
||||
|
||||
func (p TCPPacket) ResetChecksum(psum uint32) {
|
||||
p.SetChecksum(zeroChecksum)
|
||||
p.SetChecksum(Checksum(psum, p))
|
||||
}
|
||||
|
||||
func (p TCPPacket) Valid() bool {
|
||||
return len(p) >= TCPHeaderSize
|
||||
}
|
||||
|
||||
func (p TCPPacket) Verify(sourceAddress net.IP, targetAddress net.IP) error {
|
||||
var checksum [2]byte
|
||||
checksum[0] = p[16]
|
||||
checksum[1] = p[17]
|
||||
|
||||
// reset checksum
|
||||
p[16] = 0
|
||||
p[17] = 0
|
||||
|
||||
// restore checksum
|
||||
defer func() {
|
||||
p[16] = checksum[0]
|
||||
p[17] = checksum[1]
|
||||
}()
|
||||
|
||||
// check checksum
|
||||
s := uint32(0)
|
||||
s += Sum(sourceAddress)
|
||||
s += Sum(targetAddress)
|
||||
s += uint32(TCP)
|
||||
s += uint32(len(p))
|
||||
|
||||
check := Checksum(s, p)
|
||||
if checksum[0] != check[0] || checksum[1] != check[1] {
|
||||
return ErrInvalidChecksum
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
24
internal/clashtcpip/tcpip.go
Normal file
24
internal/clashtcpip/tcpip.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package clashtcpip
|
||||
|
||||
var zeroChecksum = [2]byte{0x00, 0x00}
|
||||
|
||||
var SumFnc = SumCompat
|
||||
|
||||
func Sum(b []byte) uint32 {
|
||||
return SumFnc(b)
|
||||
}
|
||||
|
||||
// Checksum for Internet Protocol family headers
|
||||
func Checksum(sum uint32, b []byte) (answer [2]byte) {
|
||||
sum += Sum(b)
|
||||
sum = (sum >> 16) + (sum & 0xffff)
|
||||
sum += sum >> 16
|
||||
sum = ^sum
|
||||
answer[0] = byte(sum >> 8)
|
||||
answer[1] = byte(sum)
|
||||
return
|
||||
}
|
||||
|
||||
func SetIPv4(packet []byte) {
|
||||
packet[0] = (packet[0] & 0x0f) | (4 << 4)
|
||||
}
|
||||
26
internal/clashtcpip/tcpip_amd64.go
Normal file
26
internal/clashtcpip/tcpip_amd64.go
Normal file
@@ -0,0 +1,26 @@
|
||||
//go:build !noasm
|
||||
|
||||
package clashtcpip
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/cpu"
|
||||
)
|
||||
|
||||
//go:noescape
|
||||
func sumAsmAvx2(data unsafe.Pointer, length uintptr) uintptr
|
||||
|
||||
func SumAVX2(data []byte) uint32 {
|
||||
if len(data) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
return uint32(sumAsmAvx2(unsafe.Pointer(&data[0]), uintptr(len(data))))
|
||||
}
|
||||
|
||||
func init() {
|
||||
if cpu.X86.HasAVX2 {
|
||||
SumFnc = SumAVX2
|
||||
}
|
||||
}
|
||||
140
internal/clashtcpip/tcpip_amd64.s
Normal file
140
internal/clashtcpip/tcpip_amd64.s
Normal file
@@ -0,0 +1,140 @@
|
||||
#include "textflag.h"
|
||||
|
||||
DATA endian_swap_mask<>+0(SB)/8, $0x607040502030001
|
||||
DATA endian_swap_mask<>+8(SB)/8, $0xE0F0C0D0A0B0809
|
||||
DATA endian_swap_mask<>+16(SB)/8, $0x607040502030001
|
||||
DATA endian_swap_mask<>+24(SB)/8, $0xE0F0C0D0A0B0809
|
||||
GLOBL endian_swap_mask<>(SB), RODATA, $32
|
||||
|
||||
// func sumAsmAvx2(data unsafe.Pointer, length uintptr) uintptr
|
||||
//
|
||||
// args (8 bytes aligned):
|
||||
// data unsafe.Pointer - 8 bytes - 0 offset
|
||||
// length uintptr - 8 bytes - 8 offset
|
||||
// result uintptr - 8 bytes - 16 offset
|
||||
#define PDATA AX
|
||||
#define LENGTH CX
|
||||
#define RESULT BX
|
||||
TEXT ·sumAsmAvx2(SB),NOSPLIT,$0-24
|
||||
MOVQ data+0(FP), PDATA
|
||||
MOVQ length+8(FP), LENGTH
|
||||
XORQ RESULT, RESULT
|
||||
|
||||
#define VSUM Y0
|
||||
#define ENDIAN_SWAP_MASK Y1
|
||||
BEGIN:
|
||||
VMOVDQU endian_swap_mask<>(SB), ENDIAN_SWAP_MASK
|
||||
VPXOR VSUM, VSUM, VSUM
|
||||
|
||||
#define LOADED_0 Y2
|
||||
#define LOADED_1 Y3
|
||||
#define LOADED_2 Y4
|
||||
#define LOADED_3 Y5
|
||||
BATCH_64:
|
||||
CMPQ LENGTH, $64
|
||||
JB BATCH_32
|
||||
VPMOVZXWD (PDATA), LOADED_0
|
||||
VPMOVZXWD 16(PDATA), LOADED_1
|
||||
VPMOVZXWD 32(PDATA), LOADED_2
|
||||
VPMOVZXWD 48(PDATA), LOADED_3
|
||||
VPSHUFB ENDIAN_SWAP_MASK, LOADED_0, LOADED_0
|
||||
VPSHUFB ENDIAN_SWAP_MASK, LOADED_1, LOADED_1
|
||||
VPSHUFB ENDIAN_SWAP_MASK, LOADED_2, LOADED_2
|
||||
VPSHUFB ENDIAN_SWAP_MASK, LOADED_3, LOADED_3
|
||||
VPADDD LOADED_0, VSUM, VSUM
|
||||
VPADDD LOADED_1, VSUM, VSUM
|
||||
VPADDD LOADED_2, VSUM, VSUM
|
||||
VPADDD LOADED_3, VSUM, VSUM
|
||||
ADDQ $-64, LENGTH
|
||||
ADDQ $64, PDATA
|
||||
JMP BATCH_64
|
||||
#undef LOADED_0
|
||||
#undef LOADED_1
|
||||
#undef LOADED_2
|
||||
#undef LOADED_3
|
||||
|
||||
#define LOADED_0 Y2
|
||||
#define LOADED_1 Y3
|
||||
BATCH_32:
|
||||
CMPQ LENGTH, $32
|
||||
JB BATCH_16
|
||||
VPMOVZXWD (PDATA), LOADED_0
|
||||
VPMOVZXWD 16(PDATA), LOADED_1
|
||||
VPSHUFB ENDIAN_SWAP_MASK, LOADED_0, LOADED_0
|
||||
VPSHUFB ENDIAN_SWAP_MASK, LOADED_1, LOADED_1
|
||||
VPADDD LOADED_0, VSUM, VSUM
|
||||
VPADDD LOADED_1, VSUM, VSUM
|
||||
ADDQ $-32, LENGTH
|
||||
ADDQ $32, PDATA
|
||||
JMP BATCH_32
|
||||
#undef LOADED_0
|
||||
#undef LOADED_1
|
||||
|
||||
#define LOADED Y2
|
||||
BATCH_16:
|
||||
CMPQ LENGTH, $16
|
||||
JB COLLECT
|
||||
VPMOVZXWD (PDATA), LOADED
|
||||
VPSHUFB ENDIAN_SWAP_MASK, LOADED, LOADED
|
||||
VPADDD LOADED, VSUM, VSUM
|
||||
ADDQ $-16, LENGTH
|
||||
ADDQ $16, PDATA
|
||||
JMP BATCH_16
|
||||
#undef LOADED
|
||||
|
||||
#define EXTRACTED Y2
|
||||
#define EXTRACTED_128 X2
|
||||
#define TEMP_64 DX
|
||||
COLLECT:
|
||||
VEXTRACTI128 $0, VSUM, EXTRACTED_128
|
||||
VPEXTRD $0, EXTRACTED_128, TEMP_64
|
||||
ADDL TEMP_64, RESULT
|
||||
VPEXTRD $1, EXTRACTED_128, TEMP_64
|
||||
ADDL TEMP_64, RESULT
|
||||
VPEXTRD $2, EXTRACTED_128, TEMP_64
|
||||
ADDL TEMP_64, RESULT
|
||||
VPEXTRD $3, EXTRACTED_128, TEMP_64
|
||||
ADDL TEMP_64, RESULT
|
||||
VEXTRACTI128 $1, VSUM, EXTRACTED_128
|
||||
VPEXTRD $0, EXTRACTED_128, TEMP_64
|
||||
ADDL TEMP_64, RESULT
|
||||
VPEXTRD $1, EXTRACTED_128, TEMP_64
|
||||
ADDL TEMP_64, RESULT
|
||||
VPEXTRD $2, EXTRACTED_128, TEMP_64
|
||||
ADDL TEMP_64, RESULT
|
||||
VPEXTRD $3, EXTRACTED_128, TEMP_64
|
||||
ADDL TEMP_64, RESULT
|
||||
#undef EXTRACTED
|
||||
#undef EXTRACTED_128
|
||||
#undef TEMP_64
|
||||
|
||||
#define TEMP DX
|
||||
#define TEMP2 SI
|
||||
BATCH_2:
|
||||
CMPQ LENGTH, $2
|
||||
JB BATCH_1
|
||||
XORQ TEMP, TEMP
|
||||
MOVW (PDATA), TEMP
|
||||
MOVQ TEMP, TEMP2
|
||||
SHRW $8, TEMP2
|
||||
SHLW $8, TEMP
|
||||
ORW TEMP2, TEMP
|
||||
ADDL TEMP, RESULT
|
||||
ADDQ $-2, LENGTH
|
||||
ADDQ $2, PDATA
|
||||
JMP BATCH_2
|
||||
#undef TEMP
|
||||
|
||||
#define TEMP DX
|
||||
BATCH_1:
|
||||
CMPQ LENGTH, $0
|
||||
JZ RETURN
|
||||
XORQ TEMP, TEMP
|
||||
MOVB (PDATA), TEMP
|
||||
SHLW $8, TEMP
|
||||
ADDL TEMP, RESULT
|
||||
#undef TEMP
|
||||
|
||||
RETURN:
|
||||
MOVQ RESULT, result+16(FP)
|
||||
RET
|
||||
51
internal/clashtcpip/tcpip_amd64_test.go
Normal file
51
internal/clashtcpip/tcpip_amd64_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package clashtcpip
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/sys/cpu"
|
||||
)
|
||||
|
||||
func Test_SumAVX2(t *testing.T) {
|
||||
if !cpu.X86.HasAVX2 {
|
||||
t.Skipf("AVX2 unavailable")
|
||||
}
|
||||
|
||||
bytes := make([]byte, chunkSize)
|
||||
|
||||
for size := 0; size <= chunkSize; size++ {
|
||||
for count := 0; count < chunkCount; count++ {
|
||||
_, err := rand.Reader.Read(bytes[:size])
|
||||
if err != nil {
|
||||
t.Skipf("Rand read failed: %v", err)
|
||||
}
|
||||
|
||||
compat := SumCompat(bytes[:size])
|
||||
avx := SumAVX2(bytes[:size])
|
||||
|
||||
if compat != avx {
|
||||
t.Errorf("Sum of length=%d mismatched", size)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Benchmark_SumAVX2(b *testing.B) {
|
||||
if !cpu.X86.HasAVX2 {
|
||||
b.Skipf("AVX2 unavailable")
|
||||
}
|
||||
|
||||
bytes := make([]byte, chunkSize)
|
||||
|
||||
_, err := rand.Reader.Read(bytes)
|
||||
if err != nil {
|
||||
b.Skipf("Rand read failed: %v", err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
SumAVX2(bytes)
|
||||
}
|
||||
}
|
||||
24
internal/clashtcpip/tcpip_arm64.go
Normal file
24
internal/clashtcpip/tcpip_arm64.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package clashtcpip
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/cpu"
|
||||
)
|
||||
|
||||
//go:noescape
|
||||
func sumAsmNeon(data unsafe.Pointer, length uintptr) uintptr
|
||||
|
||||
func SumNeon(data []byte) uint32 {
|
||||
if len(data) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
return uint32(sumAsmNeon(unsafe.Pointer(&data[0]), uintptr(len(data))))
|
||||
}
|
||||
|
||||
func init() {
|
||||
if cpu.ARM64.HasASIMD {
|
||||
SumFnc = SumNeon
|
||||
}
|
||||
}
|
||||
118
internal/clashtcpip/tcpip_arm64.s
Normal file
118
internal/clashtcpip/tcpip_arm64.s
Normal file
@@ -0,0 +1,118 @@
|
||||
#include "textflag.h"
|
||||
|
||||
// func sumAsmNeon(data unsafe.Pointer, length uintptr) uintptr
|
||||
//
|
||||
// args (8 bytes aligned):
|
||||
// data unsafe.Pointer - 8 bytes - 0 offset
|
||||
// length uintptr - 8 bytes - 8 offset
|
||||
// result uintptr - 8 bytes - 16 offset
|
||||
#define PDATA R0
|
||||
#define LENGTH R1
|
||||
#define RESULT R2
|
||||
#define VSUM V0
|
||||
TEXT ·sumAsmNeon(SB),NOSPLIT,$0-24
|
||||
MOVD data+0(FP), PDATA
|
||||
MOVD length+8(FP), LENGTH
|
||||
MOVD $0, RESULT
|
||||
VMOVQ $0, $0, VSUM
|
||||
|
||||
#define LOADED_0 V1
|
||||
#define LOADED_1 V2
|
||||
#define LOADED_2 V3
|
||||
#define LOADED_3 V4
|
||||
BATCH_32:
|
||||
CMP $32, LENGTH
|
||||
BLO BATCH_16
|
||||
VLD1 (PDATA), [LOADED_0.B8, LOADED_1.B8, LOADED_2.B8, LOADED_3.B8]
|
||||
VREV16 LOADED_0.B8, LOADED_0.B8
|
||||
VREV16 LOADED_1.B8, LOADED_1.B8
|
||||
VREV16 LOADED_2.B8, LOADED_2.B8
|
||||
VREV16 LOADED_3.B8, LOADED_3.B8
|
||||
VUSHLL $0, LOADED_0.H4, LOADED_0.S4
|
||||
VUSHLL $0, LOADED_1.H4, LOADED_1.S4
|
||||
VUSHLL $0, LOADED_2.H4, LOADED_2.S4
|
||||
VUSHLL $0, LOADED_3.H4, LOADED_3.S4
|
||||
VADD LOADED_0.S4, VSUM.S4, VSUM.S4
|
||||
VADD LOADED_1.S4, VSUM.S4, VSUM.S4
|
||||
VADD LOADED_2.S4, VSUM.S4, VSUM.S4
|
||||
VADD LOADED_3.S4, VSUM.S4, VSUM.S4
|
||||
ADD $-32, LENGTH
|
||||
ADD $32, PDATA
|
||||
B BATCH_32
|
||||
#undef LOADED_0
|
||||
#undef LOADED_1
|
||||
#undef LOADED_2
|
||||
#undef LOADED_3
|
||||
|
||||
#define LOADED_0 V1
|
||||
#define LOADED_1 V2
|
||||
BATCH_16:
|
||||
CMP $16, LENGTH
|
||||
BLO BATCH_8
|
||||
VLD1 (PDATA), [LOADED_0.B8, LOADED_1.B8]
|
||||
VREV16 LOADED_0.B8, LOADED_0.B8
|
||||
VREV16 LOADED_1.B8, LOADED_1.B8
|
||||
VUSHLL $0, LOADED_0.H4, LOADED_0.S4
|
||||
VUSHLL $0, LOADED_1.H4, LOADED_1.S4
|
||||
VADD LOADED_0.S4, VSUM.S4, VSUM.S4
|
||||
VADD LOADED_1.S4, VSUM.S4, VSUM.S4
|
||||
ADD $-16, LENGTH
|
||||
ADD $16, PDATA
|
||||
B BATCH_16
|
||||
#undef LOADED_0
|
||||
#undef LOADED_1
|
||||
|
||||
#define LOADED_0 V1
|
||||
BATCH_8:
|
||||
CMP $8, LENGTH
|
||||
BLO BATCH_2
|
||||
VLD1 (PDATA), [LOADED_0.B8]
|
||||
VREV16 LOADED_0.B8, LOADED_0.B8
|
||||
VUSHLL $0, LOADED_0.H4, LOADED_0.S4
|
||||
VADD LOADED_0.S4, VSUM.S4, VSUM.S4
|
||||
ADD $-8, LENGTH
|
||||
ADD $8, PDATA
|
||||
B BATCH_8
|
||||
#undef LOADED_0
|
||||
|
||||
#define LOADED_L R3
|
||||
#define LOADED_H R4
|
||||
BATCH_2:
|
||||
CMP $2, LENGTH
|
||||
BLO BATCH_1
|
||||
MOVBU (PDATA), LOADED_H
|
||||
MOVBU 1(PDATA), LOADED_L
|
||||
LSL $8, LOADED_H
|
||||
ORR LOADED_H, LOADED_L, LOADED_L
|
||||
ADD LOADED_L, RESULT, RESULT
|
||||
ADD $2, PDATA
|
||||
ADD $-2, LENGTH
|
||||
B BATCH_2
|
||||
#undef LOADED_H
|
||||
#undef LOADED_L
|
||||
|
||||
#define LOADED R3
|
||||
BATCH_1:
|
||||
CMP $1, LENGTH
|
||||
BLO COLLECT
|
||||
MOVBU (PDATA), LOADED
|
||||
LSL $8, LOADED
|
||||
ADD LOADED, RESULT, RESULT
|
||||
|
||||
#define EXTRACTED R3
|
||||
COLLECT:
|
||||
VMOV VSUM.S[0], EXTRACTED
|
||||
ADD EXTRACTED, RESULT
|
||||
VMOV VSUM.S[1], EXTRACTED
|
||||
ADD EXTRACTED, RESULT
|
||||
VMOV VSUM.S[2], EXTRACTED
|
||||
ADD EXTRACTED, RESULT
|
||||
VMOV VSUM.S[3], EXTRACTED
|
||||
ADD EXTRACTED, RESULT
|
||||
#undef VSUM
|
||||
#undef PDATA
|
||||
#undef LENGTH
|
||||
|
||||
RETURN:
|
||||
MOVD RESULT, result+16(FP)
|
||||
RET
|
||||
51
internal/clashtcpip/tcpip_arm64_test.go
Normal file
51
internal/clashtcpip/tcpip_arm64_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package clashtcpip
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/sys/cpu"
|
||||
)
|
||||
|
||||
func Test_SumNeon(t *testing.T) {
|
||||
if !cpu.ARM64.HasASIMD {
|
||||
t.Skipf("Neon unavailable")
|
||||
}
|
||||
|
||||
bytes := make([]byte, chunkSize)
|
||||
|
||||
for size := 0; size <= chunkSize; size++ {
|
||||
for count := 0; count < chunkCount; count++ {
|
||||
_, err := rand.Reader.Read(bytes[:size])
|
||||
if err != nil {
|
||||
t.Skipf("Rand read failed: %v", err)
|
||||
}
|
||||
|
||||
compat := SumCompat(bytes[:size])
|
||||
neon := SumNeon(bytes[:size])
|
||||
|
||||
if compat != neon {
|
||||
t.Errorf("Sum of length=%d mismatched", size)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Benchmark_SumNeon(b *testing.B) {
|
||||
if !cpu.ARM64.HasASIMD {
|
||||
b.Skipf("Neon unavailable")
|
||||
}
|
||||
|
||||
bytes := make([]byte, chunkSize)
|
||||
|
||||
_, err := rand.Reader.Read(bytes)
|
||||
if err != nil {
|
||||
b.Skipf("Rand read failed: %v", err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
SumNeon(bytes)
|
||||
}
|
||||
}
|
||||
14
internal/clashtcpip/tcpip_compat.go
Normal file
14
internal/clashtcpip/tcpip_compat.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package clashtcpip
|
||||
|
||||
func SumCompat(b []byte) (sum uint32) {
|
||||
n := len(b)
|
||||
if n&1 != 0 {
|
||||
n--
|
||||
sum += uint32(b[n]) << 8
|
||||
}
|
||||
|
||||
for i := 0; i < n; i += 2 {
|
||||
sum += (uint32(b[i]) << 8) | uint32(b[i+1])
|
||||
}
|
||||
return
|
||||
}
|
||||
26
internal/clashtcpip/tcpip_compat_test.go
Normal file
26
internal/clashtcpip/tcpip_compat_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package clashtcpip
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const (
|
||||
chunkSize = 9000
|
||||
chunkCount = 10
|
||||
)
|
||||
|
||||
func Benchmark_SumCompat(b *testing.B) {
|
||||
bytes := make([]byte, chunkSize)
|
||||
|
||||
_, err := rand.Reader.Read(bytes)
|
||||
if err != nil {
|
||||
b.Skipf("Rand read failed: %v", err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
SumCompat(bytes)
|
||||
}
|
||||
}
|
||||
55
internal/clashtcpip/udp.go
Normal file
55
internal/clashtcpip/udp.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package clashtcpip
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
const UDPHeaderSize = 8
|
||||
|
||||
type UDPPacket []byte
|
||||
|
||||
func (p UDPPacket) Length() uint16 {
|
||||
return binary.BigEndian.Uint16(p[4:])
|
||||
}
|
||||
|
||||
func (p UDPPacket) SetLength(length uint16) {
|
||||
binary.BigEndian.PutUint16(p[4:], length)
|
||||
}
|
||||
|
||||
func (p UDPPacket) SourcePort() uint16 {
|
||||
return binary.BigEndian.Uint16(p)
|
||||
}
|
||||
|
||||
func (p UDPPacket) SetSourcePort(port uint16) {
|
||||
binary.BigEndian.PutUint16(p, port)
|
||||
}
|
||||
|
||||
func (p UDPPacket) DestinationPort() uint16 {
|
||||
return binary.BigEndian.Uint16(p[2:])
|
||||
}
|
||||
|
||||
func (p UDPPacket) SetDestinationPort(port uint16) {
|
||||
binary.BigEndian.PutUint16(p[2:], port)
|
||||
}
|
||||
|
||||
func (p UDPPacket) Payload() []byte {
|
||||
return p[UDPHeaderSize:p.Length()]
|
||||
}
|
||||
|
||||
func (p UDPPacket) Checksum() uint16 {
|
||||
return binary.BigEndian.Uint16(p[6:])
|
||||
}
|
||||
|
||||
func (p UDPPacket) SetChecksum(sum [2]byte) {
|
||||
p[6] = sum[0]
|
||||
p[7] = sum[1]
|
||||
}
|
||||
|
||||
func (p UDPPacket) ResetChecksum(psum uint32) {
|
||||
p.SetChecksum(zeroChecksum)
|
||||
p.SetChecksum(Checksum(psum, p))
|
||||
}
|
||||
|
||||
func (p UDPPacket) Valid() bool {
|
||||
return len(p) >= UDPHeaderSize && uint16(len(p)) >= p.Length()
|
||||
}
|
||||
@@ -1,4 +0,0 @@
|
||||
# gtcpip
|
||||
|
||||
Minimal tcpip package kanged from gvisor
|
||||
Version 20241007.0
|
||||
@@ -1,45 +0,0 @@
|
||||
// Copyright 2018 The gVisor Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package checksum provides the implementation of the encoding and decoding of
|
||||
// network protocol headers.
|
||||
package checksum
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
// Size is the size of a checksum.
|
||||
//
|
||||
// The checksum is held in a uint16 which is 2 bytes.
|
||||
const Size = 2
|
||||
|
||||
// Put puts the checksum in the provided byte slice.
|
||||
func Put(b []byte, xsum uint16) {
|
||||
binary.BigEndian.PutUint16(b, xsum)
|
||||
}
|
||||
|
||||
// Combine combines the two uint16 to form their checksum. This is done
|
||||
// by adding them and the carry.
|
||||
//
|
||||
// Note that checksum a must have been computed on an even number of bytes.
|
||||
func Combine(a, b uint16) uint16 {
|
||||
v := uint32(a) + uint32(b)
|
||||
return uint16(v + v>>16)
|
||||
}
|
||||
|
||||
func ChecksumDefault(buf []byte, initial uint16) uint16 {
|
||||
s, _ := calculateChecksum(buf, false, initial)
|
||||
return s
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
//go:build !amd64
|
||||
|
||||
package checksum
|
||||
|
||||
// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
|
||||
// given byte array. This function uses an optimized version of the checksum
|
||||
// algorithm.
|
||||
//
|
||||
// The initial checksum must have been computed on an even number of bytes.
|
||||
func Checksum(buf []byte, initial uint16) uint16 {
|
||||
return ChecksumDefault(buf, initial)
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
//go:build amd64
|
||||
|
||||
package checksum
|
||||
|
||||
import "github.com/sagernet/sing-tun/internal/tschecksum"
|
||||
|
||||
func Checksum(buf []byte, initial uint16) uint16 {
|
||||
return tschecksum.Checksum(buf, initial)
|
||||
}
|
||||
@@ -1,182 +0,0 @@
|
||||
// Copyright 2023 The gVisor Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package checksum
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"math/bits"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// Note: odd indicates whether initial is a partial checksum over an odd number
|
||||
// of bytes.
|
||||
func calculateChecksum(buf []byte, odd bool, initial uint16) (uint16, bool) {
|
||||
// Use a larger-than-uint16 accumulator to benefit from parallel summation
|
||||
// as described in RFC 1071 1.2.C.
|
||||
acc := uint64(initial)
|
||||
|
||||
// Handle an odd number of previously-summed bytes, and get the return
|
||||
// value for odd.
|
||||
if odd {
|
||||
acc += uint64(buf[0])
|
||||
buf = buf[1:]
|
||||
}
|
||||
odd = len(buf)&1 != 0
|
||||
|
||||
// Aligning &buf[0] below is much simpler if len(buf) >= 8; special-case
|
||||
// smaller bufs.
|
||||
if len(buf) < 8 {
|
||||
if len(buf) >= 4 {
|
||||
acc += (uint64(buf[0]) << 8) + uint64(buf[1])
|
||||
acc += (uint64(buf[2]) << 8) + uint64(buf[3])
|
||||
buf = buf[4:]
|
||||
}
|
||||
if len(buf) >= 2 {
|
||||
acc += (uint64(buf[0]) << 8) + uint64(buf[1])
|
||||
buf = buf[2:]
|
||||
}
|
||||
if len(buf) >= 1 {
|
||||
acc += uint64(buf[0]) << 8
|
||||
// buf = buf[1:] is skipped because it's unused and nogo will
|
||||
// complain.
|
||||
}
|
||||
return reduce(acc), odd
|
||||
}
|
||||
|
||||
// On little-endian architectures, multi-byte loads from buf will load
|
||||
// bytes in the wrong order. Rather than byte-swap after each load (slow),
|
||||
// we byte-swap the accumulator before summing any bytes and byte-swap it
|
||||
// back before returning, which still produces the correct result as
|
||||
// described in RFC 1071 1.2.B "Byte Order Independence".
|
||||
//
|
||||
// acc is at most a uint16 + a uint8, so its upper 32 bits must be 0s. We
|
||||
// preserve this property by byte-swapping only the lower 32 bits of acc,
|
||||
// so that additions to acc performed during alignment can't overflow.
|
||||
acc = uint64(bswapIfLittleEndian32(uint32(acc)))
|
||||
|
||||
// Align &buf[0] to an 8-byte boundary.
|
||||
bswapped := false
|
||||
if sliceAddr(buf)&1 != 0 {
|
||||
// Compute the rest of the partial checksum with bytes swapped, and
|
||||
// swap back before returning; see the last paragraph of
|
||||
// RFC 1071 1.2.B.
|
||||
acc = uint64(bits.ReverseBytes32(uint32(acc)))
|
||||
bswapped = true
|
||||
// No `<< 8` here due to the byte swap we just did.
|
||||
acc += uint64(bswapIfLittleEndian16(uint16(buf[0])))
|
||||
buf = buf[1:]
|
||||
}
|
||||
if sliceAddr(buf)&2 != 0 {
|
||||
acc += uint64(*(*uint16)(unsafe.Pointer(&buf[0])))
|
||||
buf = buf[2:]
|
||||
}
|
||||
if sliceAddr(buf)&4 != 0 {
|
||||
acc += uint64(*(*uint32)(unsafe.Pointer(&buf[0])))
|
||||
buf = buf[4:]
|
||||
}
|
||||
|
||||
// Sum 64 bytes at a time. Beyond this point, additions to acc may
|
||||
// overflow, so we have to handle carrying.
|
||||
for len(buf) >= 64 {
|
||||
var carry uint64
|
||||
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0)
|
||||
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[8])), carry)
|
||||
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[16])), carry)
|
||||
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[24])), carry)
|
||||
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[32])), carry)
|
||||
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[40])), carry)
|
||||
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[48])), carry)
|
||||
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[56])), carry)
|
||||
acc, _ = bits.Add64(acc, 0, carry)
|
||||
buf = buf[64:]
|
||||
}
|
||||
|
||||
// Sum the remaining 0-63 bytes.
|
||||
if len(buf) >= 32 {
|
||||
var carry uint64
|
||||
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0)
|
||||
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[8])), carry)
|
||||
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[16])), carry)
|
||||
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[24])), carry)
|
||||
acc, _ = bits.Add64(acc, 0, carry)
|
||||
buf = buf[32:]
|
||||
}
|
||||
if len(buf) >= 16 {
|
||||
var carry uint64
|
||||
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0)
|
||||
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[8])), carry)
|
||||
acc, _ = bits.Add64(acc, 0, carry)
|
||||
buf = buf[16:]
|
||||
}
|
||||
if len(buf) >= 8 {
|
||||
var carry uint64
|
||||
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0)
|
||||
acc, _ = bits.Add64(acc, 0, carry)
|
||||
buf = buf[8:]
|
||||
}
|
||||
if len(buf) >= 4 {
|
||||
var carry uint64
|
||||
acc, carry = bits.Add64(acc, uint64(*(*uint32)(unsafe.Pointer(&buf[0]))), 0)
|
||||
acc, _ = bits.Add64(acc, 0, carry)
|
||||
buf = buf[4:]
|
||||
}
|
||||
if len(buf) >= 2 {
|
||||
var carry uint64
|
||||
acc, carry = bits.Add64(acc, uint64(*(*uint16)(unsafe.Pointer(&buf[0]))), 0)
|
||||
acc, _ = bits.Add64(acc, 0, carry)
|
||||
buf = buf[2:]
|
||||
}
|
||||
if len(buf) >= 1 {
|
||||
// bswapIfBigEndian16(buf[0]) == bswapIfLittleEndian16(buf[0]<<8).
|
||||
var carry uint64
|
||||
acc, carry = bits.Add64(acc, uint64(bswapIfBigEndian16(uint16(buf[0]))), 0)
|
||||
acc, _ = bits.Add64(acc, 0, carry)
|
||||
// buf = buf[1:] is skipped because it's unused and nogo will complain.
|
||||
}
|
||||
|
||||
// Reduce the checksum to 16 bits and undo byte swaps before returning.
|
||||
acc16 := bswapIfLittleEndian16(reduce(acc))
|
||||
if bswapped {
|
||||
acc16 = bits.ReverseBytes16(acc16)
|
||||
}
|
||||
return acc16, odd
|
||||
}
|
||||
|
||||
func reduce(acc uint64) uint16 {
|
||||
// Ideally we would do:
|
||||
// return uint16(acc>>48) +' uint16(acc>>32) +' uint16(acc>>16) +' uint16(acc)
|
||||
// for more instruction-level parallelism; however, there is no
|
||||
// bits.Add16().
|
||||
acc = (acc >> 32) + (acc & 0xffff_ffff) // at most 0x1_ffff_fffe
|
||||
acc32 := uint32(acc>>32 + acc) // at most 0xffff_ffff
|
||||
acc32 = (acc32 >> 16) + (acc32 & 0xffff) // at most 0x1_fffe
|
||||
return uint16(acc32>>16 + acc32) // at most 0xffff
|
||||
}
|
||||
|
||||
func bswapIfLittleEndian32(val uint32) uint32 {
|
||||
return binary.BigEndian.Uint32((*[4]byte)(unsafe.Pointer(&val))[:])
|
||||
}
|
||||
|
||||
func bswapIfLittleEndian16(val uint16) uint16 {
|
||||
return binary.BigEndian.Uint16((*[2]byte)(unsafe.Pointer(&val))[:])
|
||||
}
|
||||
|
||||
func bswapIfBigEndian16(val uint16) uint16 {
|
||||
return binary.LittleEndian.Uint16((*[2]byte)(unsafe.Pointer(&val))[:])
|
||||
}
|
||||
|
||||
func sliceAddr(buf []byte) uintptr {
|
||||
return uintptr(unsafe.Pointer(unsafe.SliceData(buf)))
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
// Copyright 2021 The gVisor Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tcpip
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Error represents an error in the netstack error space.
|
||||
//
|
||||
// The error interface is intentionally omitted to avoid loss of type
|
||||
// information that would occur if these errors were passed as error.
|
||||
type Error interface {
|
||||
isError()
|
||||
|
||||
// IgnoreStats indicates whether this error should be included in failure
|
||||
// counts in tcpip.Stats structs.
|
||||
IgnoreStats() bool
|
||||
|
||||
fmt.Stringer
|
||||
}
|
||||
|
||||
// ErrBadAddress indicates a bad address was provided.
|
||||
//
|
||||
// +stateify savable
|
||||
type ErrBadAddress struct{}
|
||||
|
||||
func (*ErrBadAddress) isError() {}
|
||||
|
||||
// IgnoreStats implements Error.
|
||||
func (*ErrBadAddress) IgnoreStats() bool {
|
||||
return false
|
||||
}
|
||||
func (*ErrBadAddress) String() string { return "bad address" }
|
||||
@@ -1,107 +0,0 @@
|
||||
// Copyright 2018 The gVisor Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package header provides the implementation of the encoding and decoding of
|
||||
// network protocol headers.
|
||||
package header
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip"
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
|
||||
)
|
||||
|
||||
// PseudoHeaderChecksum calculates the pseudo-header checksum for the given
|
||||
// destination protocol and network address. Pseudo-headers are needed by
|
||||
// transport layers when calculating their own checksum.
|
||||
func PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, srcAddr []byte, dstAddr []byte, totalLen uint16) uint16 {
|
||||
xsum := checksum.Checksum(srcAddr, 0)
|
||||
xsum = checksum.Checksum(dstAddr, xsum)
|
||||
|
||||
// Add the length portion of the checksum to the pseudo-checksum.
|
||||
var tmp [2]byte
|
||||
binary.BigEndian.PutUint16(tmp[:], totalLen)
|
||||
xsum = checksum.Checksum(tmp[:], xsum)
|
||||
|
||||
return checksum.Checksum([]byte{0, uint8(protocol)}, xsum)
|
||||
}
|
||||
|
||||
// checksumUpdate2ByteAlignedUint16 updates a uint16 value in a calculated
|
||||
// checksum.
|
||||
//
|
||||
// The value MUST begin at a 2-byte boundary in the original buffer.
|
||||
func checksumUpdate2ByteAlignedUint16(xsum, old, new uint16) uint16 {
|
||||
// As per RFC 1071 page 4,
|
||||
// (4) Incremental Update
|
||||
//
|
||||
// ...
|
||||
//
|
||||
// To update the checksum, simply add the differences of the
|
||||
// sixteen bit integers that have been changed. To see why this
|
||||
// works, observe that every 16-bit integer has an additive inverse
|
||||
// and that addition is associative. From this it follows that
|
||||
// given the original value m, the new value m', and the old
|
||||
// checksum C, the new checksum C' is:
|
||||
//
|
||||
// C' = C + (-m) + m' = C + (m' - m)
|
||||
if old == new {
|
||||
return xsum
|
||||
}
|
||||
return checksum.Combine(xsum, checksum.Combine(new, ^old))
|
||||
}
|
||||
|
||||
// checksumUpdate2ByteAlignedAddress updates an address in a calculated
|
||||
// checksum.
|
||||
//
|
||||
// The addresses must have the same length and must contain an even number
|
||||
// of bytes. The address MUST begin at a 2-byte boundary in the original buffer.
|
||||
func checksumUpdate2ByteAlignedAddress(xsum uint16, old, new tcpip.Address) uint16 {
|
||||
const uint16Bytes = 2
|
||||
|
||||
if old.BitLen() != new.BitLen() {
|
||||
panic(fmt.Sprintf("buffer lengths are different; old = %d, new = %d", old.BitLen()/8, new.BitLen()/8))
|
||||
}
|
||||
|
||||
if oldBytes := old.BitLen() % 16; oldBytes != 0 {
|
||||
panic(fmt.Sprintf("buffer has an odd number of bytes; got = %d", oldBytes))
|
||||
}
|
||||
|
||||
oldAddr := old.AsSlice()
|
||||
newAddr := new.AsSlice()
|
||||
|
||||
// As per RFC 1071 page 4,
|
||||
// (4) Incremental Update
|
||||
//
|
||||
// ...
|
||||
//
|
||||
// To update the checksum, simply add the differences of the
|
||||
// sixteen bit integers that have been changed. To see why this
|
||||
// works, observe that every 16-bit integer has an additive inverse
|
||||
// and that addition is associative. From this it follows that
|
||||
// given the original value m, the new value m', and the old
|
||||
// checksum C, the new checksum C' is:
|
||||
//
|
||||
// C' = C + (-m) + m' = C + (m' - m)
|
||||
for len(oldAddr) != 0 {
|
||||
// Convert the 2 byte sequences to uint16 values then apply the increment
|
||||
// update.
|
||||
xsum = checksumUpdate2ByteAlignedUint16(xsum, (uint16(oldAddr[0])<<8)+uint16(oldAddr[1]), (uint16(newAddr[0])<<8)+uint16(newAddr[1]))
|
||||
oldAddr = oldAddr[uint16Bytes:]
|
||||
newAddr = newAddr[uint16Bytes:]
|
||||
}
|
||||
|
||||
return xsum
|
||||
}
|
||||
@@ -1,192 +0,0 @@
|
||||
// Copyright 2018 The gVisor Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package header
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip"
|
||||
)
|
||||
|
||||
const (
|
||||
dstMAC = 0
|
||||
srcMAC = 6
|
||||
ethType = 12
|
||||
)
|
||||
|
||||
// EthernetFields contains the fields of an ethernet frame header. It is used to
|
||||
// describe the fields of a frame that needs to be encoded.
|
||||
type EthernetFields struct {
|
||||
// SrcAddr is the "MAC source" field of an ethernet frame header.
|
||||
SrcAddr tcpip.LinkAddress
|
||||
|
||||
// DstAddr is the "MAC destination" field of an ethernet frame header.
|
||||
DstAddr tcpip.LinkAddress
|
||||
|
||||
// Type is the "ethertype" field of an ethernet frame header.
|
||||
Type tcpip.NetworkProtocolNumber
|
||||
}
|
||||
|
||||
// Ethernet represents an ethernet frame header stored in a byte array.
|
||||
type Ethernet []byte
|
||||
|
||||
const (
|
||||
// EthernetMinimumSize is the minimum size of a valid ethernet frame.
|
||||
EthernetMinimumSize = 14
|
||||
|
||||
// EthernetMaximumSize is the maximum size of a valid ethernet frame.
|
||||
EthernetMaximumSize = 18
|
||||
|
||||
// EthernetAddressSize is the size, in bytes, of an ethernet address.
|
||||
EthernetAddressSize = 6
|
||||
|
||||
// UnspecifiedEthernetAddress is the unspecified ethernet address
|
||||
// (all bits set to 0).
|
||||
UnspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")
|
||||
|
||||
// EthernetBroadcastAddress is an ethernet address that addresses every node
|
||||
// on a local link.
|
||||
EthernetBroadcastAddress = tcpip.LinkAddress("\xff\xff\xff\xff\xff\xff")
|
||||
|
||||
// unicastMulticastFlagMask is the mask of the least significant bit in
|
||||
// the first octet (in network byte order) of an ethernet address that
|
||||
// determines whether the ethernet address is a unicast or multicast. If
|
||||
// the masked bit is a 1, then the address is a multicast, unicast
|
||||
// otherwise.
|
||||
//
|
||||
// See the IEEE Std 802-2001 document for more details. Specifically,
|
||||
// section 9.2.1 of http://ieee802.org/secmail/pdfocSP2xXA6d.pdf:
|
||||
// "A 48-bit universal address consists of two parts. The first 24 bits
|
||||
// correspond to the OUI as assigned by the IEEE, expect that the
|
||||
// assignee may set the LSB of the first octet to 1 for group addresses
|
||||
// or set it to 0 for individual addresses."
|
||||
unicastMulticastFlagMask = 1
|
||||
|
||||
// unicastMulticastFlagByteIdx is the byte that holds the
|
||||
// unicast/multicast flag. See unicastMulticastFlagMask.
|
||||
unicastMulticastFlagByteIdx = 0
|
||||
)
|
||||
|
||||
const (
|
||||
// EthernetProtocolAll is a catch-all for all protocols carried inside
|
||||
// an ethernet frame. It is mainly used to create packet sockets that
|
||||
// capture all traffic.
|
||||
EthernetProtocolAll tcpip.NetworkProtocolNumber = 0x0003
|
||||
|
||||
// EthernetProtocolPUP is the PARC Universal Packet protocol ethertype.
|
||||
EthernetProtocolPUP tcpip.NetworkProtocolNumber = 0x0200
|
||||
)
|
||||
|
||||
// Ethertypes holds the protocol numbers describing the payload of an ethernet
|
||||
// frame. These types aren't necessarily supported by netstack, but can be used
|
||||
// to catch all traffic of a type via packet endpoints.
|
||||
var Ethertypes = []tcpip.NetworkProtocolNumber{
|
||||
EthernetProtocolAll,
|
||||
EthernetProtocolPUP,
|
||||
}
|
||||
|
||||
// SourceAddress returns the "MAC source" field of the ethernet frame header.
|
||||
func (b Ethernet) SourceAddress() tcpip.LinkAddress {
|
||||
return tcpip.LinkAddress(b[srcMAC:][:EthernetAddressSize])
|
||||
}
|
||||
|
||||
// DestinationAddress returns the "MAC destination" field of the ethernet frame
|
||||
// header.
|
||||
func (b Ethernet) DestinationAddress() tcpip.LinkAddress {
|
||||
return tcpip.LinkAddress(b[dstMAC:][:EthernetAddressSize])
|
||||
}
|
||||
|
||||
// Type returns the "ethertype" field of the ethernet frame header.
|
||||
func (b Ethernet) Type() tcpip.NetworkProtocolNumber {
|
||||
return tcpip.NetworkProtocolNumber(binary.BigEndian.Uint16(b[ethType:]))
|
||||
}
|
||||
|
||||
// Encode encodes all the fields of the ethernet frame header.
|
||||
func (b Ethernet) Encode(e *EthernetFields) {
|
||||
binary.BigEndian.PutUint16(b[ethType:], uint16(e.Type))
|
||||
copy(b[srcMAC:][:EthernetAddressSize], e.SrcAddr)
|
||||
copy(b[dstMAC:][:EthernetAddressSize], e.DstAddr)
|
||||
}
|
||||
|
||||
// IsMulticastEthernetAddress returns true if the address is a multicast
|
||||
// ethernet address.
|
||||
func IsMulticastEthernetAddress(addr tcpip.LinkAddress) bool {
|
||||
if len(addr) != EthernetAddressSize {
|
||||
return false
|
||||
}
|
||||
|
||||
return addr[unicastMulticastFlagByteIdx]&unicastMulticastFlagMask != 0
|
||||
}
|
||||
|
||||
// IsValidUnicastEthernetAddress returns true if the address is a unicast
|
||||
// ethernet address.
|
||||
func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool {
|
||||
if len(addr) != EthernetAddressSize {
|
||||
return false
|
||||
}
|
||||
|
||||
if addr == UnspecifiedEthernetAddress {
|
||||
return false
|
||||
}
|
||||
|
||||
if addr[unicastMulticastFlagByteIdx]&unicastMulticastFlagMask != 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// EthernetAddressFromMulticastIPv4Address returns a multicast Ethernet address
|
||||
// for a multicast IPv4 address.
|
||||
//
|
||||
// addr MUST be a multicast IPv4 address.
|
||||
func EthernetAddressFromMulticastIPv4Address(addr tcpip.Address) tcpip.LinkAddress {
|
||||
var linkAddrBytes [EthernetAddressSize]byte
|
||||
// RFC 1112 Host Extensions for IP Multicasting
|
||||
//
|
||||
// 6.4. Extensions to an Ethernet Local Network Module:
|
||||
//
|
||||
// An IP host group address is mapped to an Ethernet multicast
|
||||
// address by placing the low-order 23-bits of the IP address
|
||||
// into the low-order 23 bits of the Ethernet multicast address
|
||||
// 01-00-5E-00-00-00 (hex).
|
||||
addrBytes := addr.As4()
|
||||
linkAddrBytes[0] = 0x1
|
||||
linkAddrBytes[2] = 0x5e
|
||||
linkAddrBytes[3] = addrBytes[1] & 0x7F
|
||||
copy(linkAddrBytes[4:], addrBytes[IPv4AddressSize-2:])
|
||||
return tcpip.LinkAddress(linkAddrBytes[:])
|
||||
}
|
||||
|
||||
// EthernetAddressFromMulticastIPv6Address returns a multicast Ethernet address
|
||||
// for a multicast IPv6 address.
|
||||
//
|
||||
// addr MUST be a multicast IPv6 address.
|
||||
func EthernetAddressFromMulticastIPv6Address(addr tcpip.Address) tcpip.LinkAddress {
|
||||
// RFC 2464 Transmission of IPv6 Packets over Ethernet Networks
|
||||
//
|
||||
// 7. Address Mapping -- Multicast
|
||||
//
|
||||
// An IPv6 packet with a multicast destination address DST,
|
||||
// consisting of the sixteen octets DST[1] through DST[16], is
|
||||
// transmitted to the Ethernet multicast address whose first
|
||||
// two octets are the value 3333 hexadecimal and whose last
|
||||
// four octets are the last four octets of DST.
|
||||
addrBytes := addr.As16()
|
||||
linkAddrBytes := []byte(addrBytes[IPv6AddressSize-EthernetAddressSize:])
|
||||
linkAddrBytes[0] = 0x33
|
||||
linkAddrBytes[1] = 0x33
|
||||
return tcpip.LinkAddress(linkAddrBytes[:])
|
||||
}
|
||||
@@ -1,228 +0,0 @@
|
||||
// Copyright 2018 The gVisor Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package header
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip"
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
|
||||
)
|
||||
|
||||
// ICMPv4 represents an ICMPv4 header stored in a byte array.
|
||||
type ICMPv4 []byte
|
||||
|
||||
const (
|
||||
// ICMPv4PayloadOffset defines the start of ICMP payload.
|
||||
ICMPv4PayloadOffset = 8
|
||||
|
||||
// ICMPv4MinimumSize is the minimum size of a valid ICMP packet.
|
||||
ICMPv4MinimumSize = 8
|
||||
|
||||
// ICMPv4MinimumErrorPayloadSize Is the smallest number of bytes of an
|
||||
// errant packet's transport layer that an ICMP error type packet should
|
||||
// attempt to send as per RFC 792 (see each type) and RFC 1122
|
||||
// section 3.2.2 which states:
|
||||
// Every ICMP error message includes the Internet header and at
|
||||
// least the first 8 data octets of the datagram that triggered
|
||||
// the error; more than 8 octets MAY be sent; this header and data
|
||||
// MUST be unchanged from the received datagram.
|
||||
//
|
||||
// RFC 792 shows:
|
||||
// 0 1 2 3
|
||||
// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
|
||||
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
// | Type | Code | Checksum |
|
||||
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
// | unused |
|
||||
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
// | Internet Header + 64 bits of Original Data Datagram |
|
||||
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
ICMPv4MinimumErrorPayloadSize = 8
|
||||
|
||||
// ICMPv4ProtocolNumber is the ICMP transport protocol number.
|
||||
ICMPv4ProtocolNumber tcpip.TransportProtocolNumber = 1
|
||||
|
||||
// icmpv4ChecksumOffset is the offset of the checksum field
|
||||
// in an ICMPv4 message.
|
||||
icmpv4ChecksumOffset = 2
|
||||
|
||||
// icmpv4MTUOffset is the offset of the MTU field
|
||||
// in an ICMPv4FragmentationNeeded message.
|
||||
icmpv4MTUOffset = 6
|
||||
|
||||
// icmpv4IdentOffset is the offset of the ident field
|
||||
// in an ICMPv4EchoRequest/Reply message.
|
||||
icmpv4IdentOffset = 4
|
||||
|
||||
// icmpv4PointerOffset is the offset of the pointer field
|
||||
// in an ICMPv4ParamProblem message.
|
||||
icmpv4PointerOffset = 4
|
||||
|
||||
// icmpv4SequenceOffset is the offset of the sequence field
|
||||
// in an ICMPv4EchoRequest/Reply message.
|
||||
icmpv4SequenceOffset = 6
|
||||
)
|
||||
|
||||
// ICMPv4Type is the ICMP type field described in RFC 792.
|
||||
type ICMPv4Type byte
|
||||
|
||||
// ICMPv4Code is the ICMP code field described in RFC 792.
|
||||
type ICMPv4Code byte
|
||||
|
||||
// Typical values of ICMPv4Type defined in RFC 792.
|
||||
const (
|
||||
ICMPv4EchoReply ICMPv4Type = 0
|
||||
ICMPv4DstUnreachable ICMPv4Type = 3
|
||||
ICMPv4SrcQuench ICMPv4Type = 4
|
||||
ICMPv4Redirect ICMPv4Type = 5
|
||||
ICMPv4Echo ICMPv4Type = 8
|
||||
ICMPv4TimeExceeded ICMPv4Type = 11
|
||||
ICMPv4ParamProblem ICMPv4Type = 12
|
||||
ICMPv4Timestamp ICMPv4Type = 13
|
||||
ICMPv4TimestampReply ICMPv4Type = 14
|
||||
ICMPv4InfoRequest ICMPv4Type = 15
|
||||
ICMPv4InfoReply ICMPv4Type = 16
|
||||
)
|
||||
|
||||
// ICMP codes for ICMPv4 Time Exceeded messages as defined in RFC 792.
|
||||
const (
|
||||
ICMPv4TTLExceeded ICMPv4Code = 0
|
||||
ICMPv4ReassemblyTimeout ICMPv4Code = 1
|
||||
)
|
||||
|
||||
// ICMP codes for ICMPv4 Destination Unreachable messages as defined in RFC 792,
|
||||
// RFC 1122 section 3.2.2.1 and RFC 1812 section 5.2.7.1.
|
||||
const (
|
||||
ICMPv4NetUnreachable ICMPv4Code = 0
|
||||
ICMPv4HostUnreachable ICMPv4Code = 1
|
||||
ICMPv4ProtoUnreachable ICMPv4Code = 2
|
||||
ICMPv4PortUnreachable ICMPv4Code = 3
|
||||
ICMPv4FragmentationNeeded ICMPv4Code = 4
|
||||
ICMPv4SourceRouteFailed ICMPv4Code = 5
|
||||
ICMPv4DestinationNetworkUnknown ICMPv4Code = 6
|
||||
ICMPv4DestinationHostUnknown ICMPv4Code = 7
|
||||
ICMPv4SourceHostIsolated ICMPv4Code = 8
|
||||
ICMPv4NetProhibited ICMPv4Code = 9
|
||||
ICMPv4HostProhibited ICMPv4Code = 10
|
||||
ICMPv4NetUnreachableForTos ICMPv4Code = 11
|
||||
ICMPv4HostUnreachableForTos ICMPv4Code = 12
|
||||
ICMPv4AdminProhibited ICMPv4Code = 13
|
||||
ICMPv4HostPrecedenceViolation ICMPv4Code = 14
|
||||
ICMPv4PrecedenceCutInEffect ICMPv4Code = 15
|
||||
)
|
||||
|
||||
// ICMPv4UnusedCode is a code to use in ICMP messages where no code is needed.
|
||||
const ICMPv4UnusedCode ICMPv4Code = 0
|
||||
|
||||
// Type is the ICMP type field.
|
||||
func (b ICMPv4) Type() ICMPv4Type { return ICMPv4Type(b[0]) }
|
||||
|
||||
// SetType sets the ICMP type field.
|
||||
func (b ICMPv4) SetType(t ICMPv4Type) { b[0] = byte(t) }
|
||||
|
||||
// Code is the ICMP code field. Its meaning depends on the value of Type.
|
||||
func (b ICMPv4) Code() ICMPv4Code { return ICMPv4Code(b[1]) }
|
||||
|
||||
// SetCode sets the ICMP code field.
|
||||
func (b ICMPv4) SetCode(c ICMPv4Code) { b[1] = byte(c) }
|
||||
|
||||
// Pointer returns the pointer field in a Parameter Problem packet.
|
||||
func (b ICMPv4) Pointer() byte { return b[icmpv4PointerOffset] }
|
||||
|
||||
// SetPointer sets the pointer field in a Parameter Problem packet.
|
||||
func (b ICMPv4) SetPointer(c byte) { b[icmpv4PointerOffset] = c }
|
||||
|
||||
// Checksum is the ICMP checksum field.
|
||||
func (b ICMPv4) Checksum() uint16 {
|
||||
return binary.BigEndian.Uint16(b[icmpv4ChecksumOffset:])
|
||||
}
|
||||
|
||||
// SetChecksum sets the ICMP checksum field.
|
||||
func (b ICMPv4) SetChecksum(cs uint16) {
|
||||
checksum.Put(b[icmpv4ChecksumOffset:], cs)
|
||||
}
|
||||
|
||||
// SourcePort implements Transport.SourcePort.
|
||||
func (ICMPv4) SourcePort() uint16 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// DestinationPort implements Transport.DestinationPort.
|
||||
func (ICMPv4) DestinationPort() uint16 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// SetSourcePort implements Transport.SetSourcePort.
|
||||
func (ICMPv4) SetSourcePort(uint16) {
|
||||
}
|
||||
|
||||
// SetDestinationPort implements Transport.SetDestinationPort.
|
||||
func (ICMPv4) SetDestinationPort(uint16) {
|
||||
}
|
||||
|
||||
// Payload implements Transport.Payload.
|
||||
func (b ICMPv4) Payload() []byte {
|
||||
return b[ICMPv4PayloadOffset:]
|
||||
}
|
||||
|
||||
// MTU retrieves the MTU field from an ICMPv4 message.
|
||||
func (b ICMPv4) MTU() uint16 {
|
||||
return binary.BigEndian.Uint16(b[icmpv4MTUOffset:])
|
||||
}
|
||||
|
||||
// SetMTU sets the MTU field from an ICMPv4 message.
|
||||
func (b ICMPv4) SetMTU(mtu uint16) {
|
||||
binary.BigEndian.PutUint16(b[icmpv4MTUOffset:], mtu)
|
||||
}
|
||||
|
||||
// Ident retrieves the Ident field from an ICMPv4 message.
|
||||
func (b ICMPv4) Ident() uint16 {
|
||||
return binary.BigEndian.Uint16(b[icmpv4IdentOffset:])
|
||||
}
|
||||
|
||||
// SetIdent sets the Ident field from an ICMPv4 message.
|
||||
func (b ICMPv4) SetIdent(ident uint16) {
|
||||
binary.BigEndian.PutUint16(b[icmpv4IdentOffset:], ident)
|
||||
}
|
||||
|
||||
// SetIdentWithChecksumUpdate sets the Ident field and updates the checksum.
|
||||
func (b ICMPv4) SetIdentWithChecksumUpdate(new uint16) {
|
||||
old := b.Ident()
|
||||
b.SetIdent(new)
|
||||
b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
|
||||
}
|
||||
|
||||
// Sequence retrieves the Sequence field from an ICMPv4 message.
|
||||
func (b ICMPv4) Sequence() uint16 {
|
||||
return binary.BigEndian.Uint16(b[icmpv4SequenceOffset:])
|
||||
}
|
||||
|
||||
// SetSequence sets the Sequence field from an ICMPv4 message.
|
||||
func (b ICMPv4) SetSequence(sequence uint16) {
|
||||
binary.BigEndian.PutUint16(b[icmpv4SequenceOffset:], sequence)
|
||||
}
|
||||
|
||||
// ICMPv4Checksum calculates the ICMP checksum over the provided ICMP header,
|
||||
// and payload.
|
||||
func ICMPv4Checksum(h ICMPv4, payloadCsum uint16) uint16 {
|
||||
xsum := payloadCsum
|
||||
|
||||
// h[2:4] is the checksum itself, skip it to avoid checksumming the checksum.
|
||||
xsum = checksum.Checksum(h[:2], xsum)
|
||||
xsum = checksum.Checksum(h[4:], xsum)
|
||||
|
||||
return ^xsum
|
||||
}
|
||||
@@ -1,304 +0,0 @@
|
||||
// Copyright 2018 The gVisor Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package header
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip"
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
|
||||
)
|
||||
|
||||
// ICMPv6 represents an ICMPv6 header stored in a byte array.
|
||||
type ICMPv6 []byte
|
||||
|
||||
const (
|
||||
// ICMPv6HeaderSize is the size of the ICMPv6 header. That is, the
|
||||
// sum of the size of the ICMPv6 Type, Code and Checksum fields, as
|
||||
// per RFC 4443 section 2.1. After the ICMPv6 header, the ICMPv6
|
||||
// message body begins.
|
||||
ICMPv6HeaderSize = 4
|
||||
|
||||
// ICMPv6MinimumSize is the minimum size of a valid ICMP packet.
|
||||
ICMPv6MinimumSize = 8
|
||||
|
||||
// ICMPv6PayloadOffset is the offset of the payload in an
|
||||
// ICMP packet.
|
||||
ICMPv6PayloadOffset = 8
|
||||
|
||||
// ICMPv6ProtocolNumber is the ICMP transport protocol number.
|
||||
ICMPv6ProtocolNumber tcpip.TransportProtocolNumber = 58
|
||||
|
||||
// ICMPv6NeighborSolicitMinimumSize is the minimum size of a
|
||||
// neighbor solicitation packet.
|
||||
ICMPv6NeighborSolicitMinimumSize = ICMPv6HeaderSize + NDPNSMinimumSize
|
||||
|
||||
// ICMPv6NeighborAdvertMinimumSize is the minimum size of a
|
||||
// neighbor advertisement packet.
|
||||
ICMPv6NeighborAdvertMinimumSize = ICMPv6HeaderSize + NDPNAMinimumSize
|
||||
|
||||
// ICMPv6EchoMinimumSize is the minimum size of a valid echo packet.
|
||||
ICMPv6EchoMinimumSize = 8
|
||||
|
||||
// ICMPv6ErrorHeaderSize is the size of an ICMP error packet header,
|
||||
// as per RFC 4443, Appendix A, item 4 and the errata.
|
||||
// ... all ICMP error messages shall have exactly
|
||||
// 32 bits of type-specific data, so that receivers can reliably find
|
||||
// the embedded invoking packet even when they don't recognize the
|
||||
// ICMP message Type.
|
||||
ICMPv6ErrorHeaderSize = 8
|
||||
|
||||
// ICMPv6DstUnreachableMinimumSize is the minimum size of a valid ICMP
|
||||
// destination unreachable packet.
|
||||
ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize
|
||||
|
||||
// ICMPv6PacketTooBigMinimumSize is the minimum size of a valid ICMP
|
||||
// packet-too-big packet.
|
||||
ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize
|
||||
|
||||
// ICMPv6ChecksumOffset is the offset of the checksum field
|
||||
// in an ICMPv6 message.
|
||||
ICMPv6ChecksumOffset = 2
|
||||
|
||||
// icmpv6PointerOffset is the offset of the pointer
|
||||
// in an ICMPv6 Parameter problem message.
|
||||
icmpv6PointerOffset = 4
|
||||
|
||||
// icmpv6MTUOffset is the offset of the MTU field in an ICMPv6
|
||||
// PacketTooBig message.
|
||||
icmpv6MTUOffset = 4
|
||||
|
||||
// icmpv6IdentOffset is the offset of the ident field
|
||||
// in a ICMPv6 Echo Request/Reply message.
|
||||
icmpv6IdentOffset = 4
|
||||
|
||||
// icmpv6SequenceOffset is the offset of the sequence field
|
||||
// in a ICMPv6 Echo Request/Reply message.
|
||||
icmpv6SequenceOffset = 6
|
||||
|
||||
// NDPHopLimit is the expected IP hop limit value of 255 for received
|
||||
// NDP packets, as per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1,
|
||||
// 7.1.2 and 8.1. If the hop limit value is not 255, nodes MUST silently
|
||||
// drop the NDP packet. All outgoing NDP packets must use this value for
|
||||
// its IP hop limit field.
|
||||
NDPHopLimit = 255
|
||||
)
|
||||
|
||||
// ICMPv6Type is the ICMP type field described in RFC 4443.
|
||||
type ICMPv6Type byte
|
||||
|
||||
// Values for use in the Type field of ICMPv6 packet from RFC 4433.
|
||||
const (
|
||||
ICMPv6DstUnreachable ICMPv6Type = 1
|
||||
ICMPv6PacketTooBig ICMPv6Type = 2
|
||||
ICMPv6TimeExceeded ICMPv6Type = 3
|
||||
ICMPv6ParamProblem ICMPv6Type = 4
|
||||
ICMPv6EchoRequest ICMPv6Type = 128
|
||||
ICMPv6EchoReply ICMPv6Type = 129
|
||||
|
||||
// Neighbor Discovery Protocol (NDP) messages, see RFC 4861.
|
||||
|
||||
ICMPv6RouterSolicit ICMPv6Type = 133
|
||||
ICMPv6RouterAdvert ICMPv6Type = 134
|
||||
ICMPv6NeighborSolicit ICMPv6Type = 135
|
||||
ICMPv6NeighborAdvert ICMPv6Type = 136
|
||||
ICMPv6RedirectMsg ICMPv6Type = 137
|
||||
|
||||
// Multicast Listener Discovery (MLD) messages, see RFC 2710.
|
||||
|
||||
ICMPv6MulticastListenerQuery ICMPv6Type = 130
|
||||
ICMPv6MulticastListenerReport ICMPv6Type = 131
|
||||
ICMPv6MulticastListenerDone ICMPv6Type = 132
|
||||
|
||||
// Multicast Listener Discovert Version 2 (MLDv2) messages, see RFC 3810.
|
||||
|
||||
ICMPv6MulticastListenerV2Report ICMPv6Type = 143
|
||||
)
|
||||
|
||||
// IsErrorType returns true if the receiver is an ICMP error type.
|
||||
func (typ ICMPv6Type) IsErrorType() bool {
|
||||
// Per RFC 4443 section 2.1:
|
||||
// ICMPv6 messages are grouped into two classes: error messages and
|
||||
// informational messages. Error messages are identified as such by a
|
||||
// zero in the high-order bit of their message Type field values. Thus,
|
||||
// error messages have message types from 0 to 127; informational
|
||||
// messages have message types from 128 to 255.
|
||||
return typ&0x80 == 0
|
||||
}
|
||||
|
||||
// ICMPv6Code is the ICMP Code field described in RFC 4443.
|
||||
type ICMPv6Code byte
|
||||
|
||||
// ICMP codes used with Destination Unreachable (Type 1). As per RFC 4443
|
||||
// section 3.1.
|
||||
const (
|
||||
ICMPv6NetworkUnreachable ICMPv6Code = 0
|
||||
ICMPv6Prohibited ICMPv6Code = 1
|
||||
ICMPv6BeyondScope ICMPv6Code = 2
|
||||
ICMPv6AddressUnreachable ICMPv6Code = 3
|
||||
ICMPv6PortUnreachable ICMPv6Code = 4
|
||||
ICMPv6Policy ICMPv6Code = 5
|
||||
ICMPv6RejectRoute ICMPv6Code = 6
|
||||
)
|
||||
|
||||
// ICMP codes used with Time Exceeded (Type 3). As per RFC 4443 section 3.3.
|
||||
const (
|
||||
ICMPv6HopLimitExceeded ICMPv6Code = 0
|
||||
ICMPv6ReassemblyTimeout ICMPv6Code = 1
|
||||
)
|
||||
|
||||
// ICMP codes used with Parameter Problem (Type 4). As per RFC 4443 section 3.4.
|
||||
const (
|
||||
// ICMPv6ErroneousHeader indicates an erroneous header field was encountered.
|
||||
ICMPv6ErroneousHeader ICMPv6Code = 0
|
||||
|
||||
// ICMPv6UnknownHeader indicates an unrecognized Next Header type encountered.
|
||||
ICMPv6UnknownHeader ICMPv6Code = 1
|
||||
|
||||
// ICMPv6UnknownOption indicates an unrecognized IPv6 option was encountered.
|
||||
ICMPv6UnknownOption ICMPv6Code = 2
|
||||
)
|
||||
|
||||
// ICMPv6UnusedCode is the code value used with ICMPv6 messages which don't use
|
||||
// the code field. (Types not mentioned above.)
|
||||
const ICMPv6UnusedCode ICMPv6Code = 0
|
||||
|
||||
// Type is the ICMP type field.
|
||||
func (b ICMPv6) Type() ICMPv6Type { return ICMPv6Type(b[0]) }
|
||||
|
||||
// SetType sets the ICMP type field.
|
||||
func (b ICMPv6) SetType(t ICMPv6Type) { b[0] = byte(t) }
|
||||
|
||||
// Code is the ICMP code field. Its meaning depends on the value of Type.
|
||||
func (b ICMPv6) Code() ICMPv6Code { return ICMPv6Code(b[1]) }
|
||||
|
||||
// SetCode sets the ICMP code field.
|
||||
func (b ICMPv6) SetCode(c ICMPv6Code) { b[1] = byte(c) }
|
||||
|
||||
// TypeSpecific returns the type specific data field.
|
||||
func (b ICMPv6) TypeSpecific() uint32 {
|
||||
return binary.BigEndian.Uint32(b[icmpv6PointerOffset:])
|
||||
}
|
||||
|
||||
// SetTypeSpecific sets the type specific data field.
|
||||
func (b ICMPv6) SetTypeSpecific(val uint32) {
|
||||
binary.BigEndian.PutUint32(b[icmpv6PointerOffset:], val)
|
||||
}
|
||||
|
||||
// Checksum is the ICMP checksum field.
|
||||
func (b ICMPv6) Checksum() uint16 {
|
||||
return binary.BigEndian.Uint16(b[ICMPv6ChecksumOffset:])
|
||||
}
|
||||
|
||||
// SetChecksum sets the ICMP checksum field.
|
||||
func (b ICMPv6) SetChecksum(cs uint16) {
|
||||
checksum.Put(b[ICMPv6ChecksumOffset:], cs)
|
||||
}
|
||||
|
||||
// SourcePort implements Transport.SourcePort.
|
||||
func (ICMPv6) SourcePort() uint16 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// DestinationPort implements Transport.DestinationPort.
|
||||
func (ICMPv6) DestinationPort() uint16 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// SetSourcePort implements Transport.SetSourcePort.
|
||||
func (ICMPv6) SetSourcePort(uint16) {
|
||||
}
|
||||
|
||||
// SetDestinationPort implements Transport.SetDestinationPort.
|
||||
func (ICMPv6) SetDestinationPort(uint16) {
|
||||
}
|
||||
|
||||
// MTU retrieves the MTU field from an ICMPv6 message.
|
||||
func (b ICMPv6) MTU() uint32 {
|
||||
return binary.BigEndian.Uint32(b[icmpv6MTUOffset:])
|
||||
}
|
||||
|
||||
// SetMTU sets the MTU field from an ICMPv6 message.
|
||||
func (b ICMPv6) SetMTU(mtu uint32) {
|
||||
binary.BigEndian.PutUint32(b[icmpv6MTUOffset:], mtu)
|
||||
}
|
||||
|
||||
// Ident retrieves the Ident field from an ICMPv6 message.
|
||||
func (b ICMPv6) Ident() uint16 {
|
||||
return binary.BigEndian.Uint16(b[icmpv6IdentOffset:])
|
||||
}
|
||||
|
||||
// SetIdent sets the Ident field from an ICMPv6 message.
|
||||
func (b ICMPv6) SetIdent(ident uint16) {
|
||||
binary.BigEndian.PutUint16(b[icmpv6IdentOffset:], ident)
|
||||
}
|
||||
|
||||
// SetIdentWithChecksumUpdate sets the Ident field and updates the checksum.
|
||||
func (b ICMPv6) SetIdentWithChecksumUpdate(new uint16) {
|
||||
old := b.Ident()
|
||||
b.SetIdent(new)
|
||||
b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
|
||||
}
|
||||
|
||||
// Sequence retrieves the Sequence field from an ICMPv6 message.
|
||||
func (b ICMPv6) Sequence() uint16 {
|
||||
return binary.BigEndian.Uint16(b[icmpv6SequenceOffset:])
|
||||
}
|
||||
|
||||
// SetSequence sets the Sequence field from an ICMPv6 message.
|
||||
func (b ICMPv6) SetSequence(sequence uint16) {
|
||||
binary.BigEndian.PutUint16(b[icmpv6SequenceOffset:], sequence)
|
||||
}
|
||||
|
||||
// MessageBody returns the message body as defined by RFC 4443 section 2.1; the
|
||||
// portion of the ICMPv6 buffer after the first ICMPv6HeaderSize bytes.
|
||||
func (b ICMPv6) MessageBody() []byte {
|
||||
return b[ICMPv6HeaderSize:]
|
||||
}
|
||||
|
||||
// Payload implements Transport.Payload.
|
||||
func (b ICMPv6) Payload() []byte {
|
||||
return b[ICMPv6PayloadOffset:]
|
||||
}
|
||||
|
||||
// ICMPv6ChecksumParams contains parameters to calculate ICMPv6 checksum.
|
||||
type ICMPv6ChecksumParams struct {
|
||||
Header ICMPv6
|
||||
Src tcpip.Address
|
||||
Dst tcpip.Address
|
||||
PayloadCsum uint16
|
||||
PayloadLen int
|
||||
}
|
||||
|
||||
// ICMPv6Checksum calculates the ICMP checksum over the provided ICMPv6 header,
|
||||
// IPv6 src/dst addresses and the payload.
|
||||
func ICMPv6Checksum(params ICMPv6ChecksumParams) uint16 {
|
||||
h := params.Header
|
||||
|
||||
xsum := PseudoHeaderChecksum(ICMPv6ProtocolNumber, params.Src.AsSlice(), params.Dst.AsSlice(), uint16(len(h)+params.PayloadLen))
|
||||
xsum = checksum.Combine(xsum, params.PayloadCsum)
|
||||
|
||||
// h[2:4] is the checksum itself, skip it to avoid checksumming the checksum.
|
||||
xsum = checksum.Checksum(h[:2], xsum)
|
||||
xsum = checksum.Checksum(h[4:], xsum)
|
||||
|
||||
return ^xsum
|
||||
}
|
||||
|
||||
// UpdateChecksumPseudoHeaderAddress updates the checksum to reflect an
|
||||
// updated address in the pseudo header.
|
||||
func (b ICMPv6) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address) {
|
||||
b.SetChecksum(^checksumUpdate2ByteAlignedAddress(^b.Checksum(), old, new))
|
||||
}
|
||||
@@ -1,136 +0,0 @@
|
||||
// Copyright 2018 The gVisor Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package header
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
tcpip "github.com/sagernet/sing-tun/internal/gtcpip"
|
||||
)
|
||||
|
||||
const (
|
||||
// MaxIPPacketSize is the maximum supported IP packet size, excluding
|
||||
// jumbograms. The maximum IPv4 packet size is 64k-1 (total size must fit
|
||||
// in 16 bits). For IPv6, the payload max size (excluding jumbograms) is
|
||||
// 64k-1 (also needs to fit in 16 bits). So we use 64k - 1 + 2 * m, where
|
||||
// m is the minimum IPv6 header size; we leave room for some potential
|
||||
// IP options.
|
||||
MaxIPPacketSize = 0xffff + 2*IPv6MinimumSize
|
||||
)
|
||||
|
||||
// Transport offers generic methods to query and/or update the fields of the
|
||||
// header of a transport protocol buffer.
|
||||
type Transport interface {
|
||||
// SourcePort returns the value of the "source port" field.
|
||||
SourcePort() uint16
|
||||
|
||||
// Destination returns the value of the "destination port" field.
|
||||
DestinationPort() uint16
|
||||
|
||||
// Checksum returns the value of the "checksum" field.
|
||||
Checksum() uint16
|
||||
|
||||
// SetSourcePort sets the value of the "source port" field.
|
||||
SetSourcePort(uint16)
|
||||
|
||||
// SetDestinationPort sets the value of the "destination port" field.
|
||||
SetDestinationPort(uint16)
|
||||
|
||||
// SetChecksum sets the value of the "checksum" field.
|
||||
SetChecksum(uint16)
|
||||
|
||||
// Payload returns the data carried in the transport buffer.
|
||||
Payload() []byte
|
||||
}
|
||||
|
||||
// ChecksummableTransport is a Transport that supports checksumming.
|
||||
type ChecksummableTransport interface {
|
||||
Transport
|
||||
|
||||
// SetSourcePortWithChecksumUpdate sets the source port and updates
|
||||
// the checksum.
|
||||
//
|
||||
// The receiver's checksum must be a fully calculated checksum.
|
||||
SetSourcePortWithChecksumUpdate(port uint16)
|
||||
|
||||
// SetDestinationPortWithChecksumUpdate sets the destination port and updates
|
||||
// the checksum.
|
||||
//
|
||||
// The receiver's checksum must be a fully calculated checksum.
|
||||
SetDestinationPortWithChecksumUpdate(port uint16)
|
||||
|
||||
// UpdateChecksumPseudoHeaderAddress updates the checksum to reflect an
|
||||
// updated address in the pseudo header.
|
||||
//
|
||||
// If fullChecksum is true, the receiver's checksum field is assumed to hold a
|
||||
// fully calculated checksum. Otherwise, it is assumed to hold a partially
|
||||
// calculated checksum which only reflects the pseudo header.
|
||||
UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool)
|
||||
}
|
||||
|
||||
// Network offers generic methods to query and/or update the fields of the
|
||||
// header of a network protocol buffer.
|
||||
type Network interface {
|
||||
// SourceAddress returns the value of the "source address" field.
|
||||
SourceAddress() tcpip.Address
|
||||
|
||||
// DestinationAddress returns the value of the "destination address"
|
||||
// field.
|
||||
DestinationAddress() tcpip.Address
|
||||
|
||||
DestinationAddr() netip.Addr
|
||||
|
||||
// Checksum returns the value of the "checksum" field.
|
||||
Checksum() uint16
|
||||
|
||||
// SetSourceAddress sets the value of the "source address" field.
|
||||
SetSourceAddress(tcpip.Address)
|
||||
|
||||
// SetDestinationAddress sets the value of the "destination address"
|
||||
// field.
|
||||
SetDestinationAddress(tcpip.Address)
|
||||
|
||||
SetDestinationAddr(addr netip.Addr)
|
||||
|
||||
// SetChecksum sets the value of the "checksum" field.
|
||||
SetChecksum(uint16)
|
||||
|
||||
// TransportProtocol returns the number of the transport protocol
|
||||
// stored in the payload.
|
||||
TransportProtocol() tcpip.TransportProtocolNumber
|
||||
|
||||
// Payload returns a byte slice containing the payload of the network
|
||||
// packet.
|
||||
Payload() []byte
|
||||
|
||||
// TOS returns the values of the "type of service" and "flow label" fields.
|
||||
TOS() (uint8, uint32)
|
||||
|
||||
// SetTOS sets the values of the "type of service" and "flow label" fields.
|
||||
SetTOS(t uint8, l uint32)
|
||||
}
|
||||
|
||||
// ChecksummableNetwork is a Network that supports checksumming.
|
||||
type ChecksummableNetwork interface {
|
||||
Network
|
||||
|
||||
// SetSourceAddressAndChecksum sets the source address and updates the
|
||||
// checksum to reflect the new address.
|
||||
SetSourceAddressWithChecksumUpdate(tcpip.Address)
|
||||
|
||||
// SetDestinationAddressAndChecksum sets the destination address and
|
||||
// updates the checksum to reflect the new address.
|
||||
SetDestinationAddressWithChecksumUpdate(tcpip.Address)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,578 +0,0 @@
|
||||
// Copyright 2018 The gVisor Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package header
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip"
|
||||
)
|
||||
|
||||
const (
|
||||
versTCFL = 0
|
||||
// IPv6PayloadLenOffset is the offset of the PayloadLength field in
|
||||
// IPv6 header.
|
||||
IPv6PayloadLenOffset = 4
|
||||
// IPv6NextHeaderOffset is the offset of the NextHeader field in
|
||||
// IPv6 header.
|
||||
IPv6NextHeaderOffset = 6
|
||||
hopLimit = 7
|
||||
v6SrcAddr = 8
|
||||
v6DstAddr = v6SrcAddr + IPv6AddressSize
|
||||
|
||||
// IPv6FixedHeaderSize is the size of the fixed header.
|
||||
IPv6FixedHeaderSize = v6DstAddr + IPv6AddressSize
|
||||
)
|
||||
|
||||
// IPv6Fields contains the fields of an IPv6 packet. It is used to describe the
|
||||
// fields of a packet that needs to be encoded.
|
||||
type IPv6Fields struct {
|
||||
// TrafficClass is the "traffic class" field of an IPv6 packet.
|
||||
TrafficClass uint8
|
||||
|
||||
// FlowLabel is the "flow label" field of an IPv6 packet.
|
||||
FlowLabel uint32
|
||||
|
||||
// PayloadLength is the "payload length" field of an IPv6 packet, including
|
||||
// the length of all extension headers.
|
||||
PayloadLength uint16
|
||||
|
||||
// TransportProtocol is the transport layer protocol number. Serialized in the
|
||||
// last "next header" field of the IPv6 header + extension headers.
|
||||
TransportProtocol tcpip.TransportProtocolNumber
|
||||
|
||||
// HopLimit is the "Hop Limit" field of an IPv6 packet.
|
||||
HopLimit uint8
|
||||
|
||||
// SrcAddr is the "source ip address" of an IPv6 packet.
|
||||
SrcAddr netip.Addr
|
||||
|
||||
// DstAddr is the "destination ip address" of an IPv6 packet.
|
||||
DstAddr netip.Addr
|
||||
|
||||
// ExtensionHeaders are the extension headers following the IPv6 header.
|
||||
ExtensionHeaders IPv6ExtHdrSerializer
|
||||
}
|
||||
|
||||
// IPv6 represents an ipv6 header stored in a byte array.
|
||||
// Most of the methods of IPv6 access to the underlying slice without
|
||||
// checking the boundaries and could panic because of 'index out of range'.
|
||||
// Always call IsValid() to validate an instance of IPv6 before using other methods.
|
||||
type IPv6 []byte
|
||||
|
||||
const (
|
||||
// IPv6MinimumSize is the minimum size of a valid IPv6 packet.
|
||||
IPv6MinimumSize = IPv6FixedHeaderSize
|
||||
|
||||
// IPv6AddressSize is the size, in bytes, of an IPv6 address.
|
||||
IPv6AddressSize = 16
|
||||
|
||||
// IPv6AddressSizeBits is the size, in bits, of an IPv6 address.
|
||||
IPv6AddressSizeBits = 128
|
||||
|
||||
// IPv6MaximumPayloadSize is the maximum size of a valid IPv6 payload per
|
||||
// RFC 8200 Section 4.5.
|
||||
IPv6MaximumPayloadSize = 65535
|
||||
|
||||
// IPv6ProtocolNumber is IPv6's network protocol number.
|
||||
IPv6ProtocolNumber tcpip.NetworkProtocolNumber = 0x86dd
|
||||
|
||||
// IPv6Version is the version of the ipv6 protocol.
|
||||
IPv6Version = 6
|
||||
|
||||
// IIDSize is the size of an interface identifier (IID), in bytes, as
|
||||
// defined by RFC 4291 section 2.5.1.
|
||||
IIDSize = 8
|
||||
|
||||
// IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 8200,
|
||||
// section 5:
|
||||
// IPv6 requires that every link in the Internet have an MTU of 1280 octets
|
||||
// or greater. This is known as the IPv6 minimum link MTU.
|
||||
IPv6MinimumMTU = 1280
|
||||
|
||||
// IIDOffsetInIPv6Address is the offset, in bytes, from the start
|
||||
// of an IPv6 address to the beginning of the interface identifier
|
||||
// (IID) for auto-generated addresses. That is, all bytes before
|
||||
// the IIDOffsetInIPv6Address-th byte are the prefix bytes, and all
|
||||
// bytes including and after the IIDOffsetInIPv6Address-th byte are
|
||||
// for the IID.
|
||||
IIDOffsetInIPv6Address = 8
|
||||
|
||||
// OpaqueIIDSecretKeyMinBytes is the recommended minimum number of bytes
|
||||
// for the secret key used to generate an opaque interface identifier as
|
||||
// outlined by RFC 7217.
|
||||
OpaqueIIDSecretKeyMinBytes = 16
|
||||
|
||||
// ipv6MulticastAddressScopeByteIdx is the byte where the scope (scop) field
|
||||
// is located within a multicast IPv6 address, as per RFC 4291 section 2.7.
|
||||
ipv6MulticastAddressScopeByteIdx = 1
|
||||
|
||||
// ipv6MulticastAddressScopeMask is the mask for the scope (scop) field,
|
||||
// within the byte holding the field, as per RFC 4291 section 2.7.
|
||||
ipv6MulticastAddressScopeMask = 0xF
|
||||
)
|
||||
|
||||
var (
|
||||
// IPv6AllNodesMulticastAddress is a link-local multicast group that
|
||||
// all IPv6 nodes MUST join, as per RFC 4291, section 2.8. Packets
|
||||
// destined to this address will reach all nodes on a link.
|
||||
//
|
||||
// The address is ff02::1.
|
||||
IPv6AllNodesMulticastAddress = tcpip.AddrFrom16([16]byte{0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01})
|
||||
|
||||
// IPv6AllRoutersInterfaceLocalMulticastAddress is an interface-local
|
||||
// multicast group that all IPv6 routers MUST join, as per RFC 4291, section
|
||||
// 2.8. Packets destined to this address will reach the router on an
|
||||
// interface.
|
||||
//
|
||||
// The address is ff01::2.
|
||||
IPv6AllRoutersInterfaceLocalMulticastAddress = tcpip.AddrFrom16([16]byte{0xff, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02})
|
||||
|
||||
// IPv6AllRoutersLinkLocalMulticastAddress is a link-local multicast group
|
||||
// that all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets
|
||||
// destined to this address will reach all routers on a link.
|
||||
//
|
||||
// The address is ff02::2.
|
||||
IPv6AllRoutersLinkLocalMulticastAddress = tcpip.AddrFrom16([16]byte{0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02})
|
||||
|
||||
// IPv6AllRoutersSiteLocalMulticastAddress is a site-local multicast group
|
||||
// that all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets
|
||||
// destined to this address will reach all routers in a site.
|
||||
//
|
||||
// The address is ff05::2.
|
||||
IPv6AllRoutersSiteLocalMulticastAddress = tcpip.AddrFrom16([16]byte{0xff, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02})
|
||||
|
||||
// IPv6Loopback is the IPv6 Loopback address.
|
||||
IPv6Loopback = tcpip.AddrFrom16([16]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01})
|
||||
|
||||
// IPv6Any is the non-routable IPv6 "any" meta address. It is also
|
||||
// known as the unspecified address.
|
||||
IPv6Any = tcpip.AddrFrom16([16]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})
|
||||
)
|
||||
|
||||
// IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the
|
||||
// catch-all or wildcard subnet. That is, all IPv6 addresses are considered to
|
||||
// be contained within this subnet.
|
||||
var IPv6EmptySubnet = tcpip.AddressWithPrefix{
|
||||
Address: IPv6Any,
|
||||
PrefixLen: 0,
|
||||
}.Subnet()
|
||||
|
||||
// IPv4MappedIPv6Subnet is the prefix for an IPv4 mapped IPv6 address as defined
|
||||
// by RFC 4291 section 2.5.5.
|
||||
var IPv4MappedIPv6Subnet = tcpip.AddressWithPrefix{
|
||||
Address: tcpip.AddrFrom16([16]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00}),
|
||||
PrefixLen: 96,
|
||||
}.Subnet()
|
||||
|
||||
// IPv6LinkLocalPrefix is the prefix for IPv6 link-local addresses, as defined
|
||||
// by RFC 4291 section 2.5.6.
|
||||
//
|
||||
// The prefix is fe80::/64
|
||||
var IPv6LinkLocalPrefix = tcpip.AddressWithPrefix{
|
||||
Address: tcpip.AddrFrom16([16]byte{0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
|
||||
PrefixLen: 64,
|
||||
}
|
||||
|
||||
// PayloadLength returns the value of the "payload length" field of the ipv6
|
||||
// header.
|
||||
func (b IPv6) PayloadLength() uint16 {
|
||||
return binary.BigEndian.Uint16(b[IPv6PayloadLenOffset:])
|
||||
}
|
||||
|
||||
// HopLimit returns the value of the "Hop Limit" field of the ipv6 header.
|
||||
func (b IPv6) HopLimit() uint8 {
|
||||
return b[hopLimit]
|
||||
}
|
||||
|
||||
// NextHeader returns the value of the "next header" field of the ipv6 header.
|
||||
func (b IPv6) NextHeader() uint8 {
|
||||
return b[IPv6NextHeaderOffset]
|
||||
}
|
||||
|
||||
// TransportProtocol implements Network.TransportProtocol.
|
||||
func (b IPv6) TransportProtocol() tcpip.TransportProtocolNumber {
|
||||
return tcpip.TransportProtocolNumber(b.NextHeader())
|
||||
}
|
||||
|
||||
// Payload implements Network.Payload.
|
||||
func (b IPv6) Payload() []byte {
|
||||
return b[IPv6MinimumSize:][:b.PayloadLength()]
|
||||
}
|
||||
|
||||
// SourceAddress returns the "source address" field of the ipv6 header.
|
||||
func (b IPv6) SourceAddress() tcpip.Address {
|
||||
return tcpip.AddrFrom16([16]byte(b[v6SrcAddr:][:IPv6AddressSize]))
|
||||
}
|
||||
|
||||
// DestinationAddress returns the "destination address" field of the ipv6
|
||||
// header.
|
||||
func (b IPv6) DestinationAddress() tcpip.Address {
|
||||
return tcpip.AddrFrom16([16]byte(b[v6DstAddr:][:IPv6AddressSize]))
|
||||
}
|
||||
|
||||
// SourceAddressSlice returns the "source address" field of the ipv6 header as a
|
||||
// byte slice.
|
||||
func (b IPv6) SourceAddressSlice() []byte {
|
||||
return []byte(b[v6SrcAddr:][:IPv6AddressSize])
|
||||
}
|
||||
|
||||
// DestinationAddressSlice returns the "destination address" field of the ipv6
|
||||
// header as a byte slice.
|
||||
func (b IPv6) DestinationAddressSlice() []byte {
|
||||
return []byte(b[v6DstAddr:][:IPv6AddressSize])
|
||||
}
|
||||
|
||||
// Checksum implements Network.Checksum. Given that IPv6 doesn't have a
|
||||
// checksum, it just returns 0.
|
||||
func (IPv6) Checksum() uint16 {
|
||||
return 0
|
||||
}
|
||||
|
||||
// TOS returns the "traffic class" and "flow label" fields of the ipv6 header.
|
||||
func (b IPv6) TOS() (uint8, uint32) {
|
||||
v := binary.BigEndian.Uint32(b[versTCFL:])
|
||||
return uint8(v >> 20), v & 0xfffff
|
||||
}
|
||||
|
||||
// SetTOS sets the "traffic class" and "flow label" fields of the ipv6 header.
|
||||
func (b IPv6) SetTOS(t uint8, l uint32) {
|
||||
vtf := (6 << 28) | (uint32(t) << 20) | (l & 0xfffff)
|
||||
binary.BigEndian.PutUint32(b[versTCFL:], vtf)
|
||||
}
|
||||
|
||||
// SetPayloadLength sets the "payload length" field of the ipv6 header.
|
||||
func (b IPv6) SetPayloadLength(payloadLength uint16) {
|
||||
binary.BigEndian.PutUint16(b[IPv6PayloadLenOffset:], payloadLength)
|
||||
}
|
||||
|
||||
// SetSourceAddress sets the "source address" field of the ipv6 header.
|
||||
func (b IPv6) SetSourceAddress(addr tcpip.Address) {
|
||||
copy(b[v6SrcAddr:][:IPv6AddressSize], addr.AsSlice())
|
||||
}
|
||||
|
||||
// SetDestinationAddress sets the "destination address" field of the ipv6
|
||||
// header.
|
||||
func (b IPv6) SetDestinationAddress(addr tcpip.Address) {
|
||||
copy(b[v6DstAddr:][:IPv6AddressSize], addr.AsSlice())
|
||||
}
|
||||
|
||||
// SetHopLimit sets the value of the "Hop Limit" field.
|
||||
func (b IPv6) SetHopLimit(v uint8) {
|
||||
b[hopLimit] = v
|
||||
}
|
||||
|
||||
// SetNextHeader sets the value of the "next header" field of the ipv6 header.
|
||||
func (b IPv6) SetNextHeader(v uint8) {
|
||||
b[IPv6NextHeaderOffset] = v
|
||||
}
|
||||
|
||||
// SetChecksum implements Network.SetChecksum. Given that IPv6 doesn't have a
|
||||
// checksum, it is empty.
|
||||
func (IPv6) SetChecksum(uint16) {
|
||||
}
|
||||
|
||||
// Encode encodes all the fields of the ipv6 header.
|
||||
func (b IPv6) Encode(i *IPv6Fields) {
|
||||
extHdr := b[IPv6MinimumSize:]
|
||||
b.SetTOS(i.TrafficClass, i.FlowLabel)
|
||||
b.SetPayloadLength(i.PayloadLength)
|
||||
b[hopLimit] = i.HopLimit
|
||||
b.SetSourceAddr(i.SrcAddr)
|
||||
b.SetDestinationAddr(i.DstAddr)
|
||||
nextHeader, _ := i.ExtensionHeaders.Serialize(i.TransportProtocol, extHdr)
|
||||
b[IPv6NextHeaderOffset] = nextHeader
|
||||
}
|
||||
|
||||
// IsValid performs basic validation on the packet.
|
||||
func (b IPv6) IsValid(pktSize int) bool {
|
||||
if len(b) < IPv6MinimumSize {
|
||||
return false
|
||||
}
|
||||
|
||||
dlen := int(b.PayloadLength())
|
||||
if dlen > pktSize-IPv6MinimumSize {
|
||||
return false
|
||||
}
|
||||
|
||||
if IPVersion(b) != IPv6Version {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// IsV4MappedAddress determines if the provided address is an IPv4 mapped
|
||||
// address by checking if its prefix is 0:0:0:0:0:ffff::/96.
|
||||
func IsV4MappedAddress(addr tcpip.Address) bool {
|
||||
if addr.BitLen() != IPv6AddressSizeBits {
|
||||
return false
|
||||
}
|
||||
|
||||
return IPv4MappedIPv6Subnet.Contains(addr)
|
||||
}
|
||||
|
||||
// IsV6MulticastAddress determines if the provided address is an IPv6
|
||||
// multicast address (anything starting with FF).
|
||||
func IsV6MulticastAddress(addr tcpip.Address) bool {
|
||||
if addr.BitLen() != IPv6AddressSizeBits {
|
||||
return false
|
||||
}
|
||||
return addr.As16()[0] == 0xff
|
||||
}
|
||||
|
||||
// IsV6UnicastAddress determines if the provided address is a valid IPv6
|
||||
// unicast (and specified) address. That is, IsV6UnicastAddress returns
|
||||
// true if addr contains IPv6AddressSize bytes, is not the unspecified
|
||||
// address and is not a multicast address.
|
||||
func IsV6UnicastAddress(addr tcpip.Address) bool {
|
||||
if addr.BitLen() != IPv6AddressSizeBits {
|
||||
return false
|
||||
}
|
||||
|
||||
// Must not be unspecified
|
||||
if addr == IPv6Any {
|
||||
return false
|
||||
}
|
||||
|
||||
// Return if not a multicast.
|
||||
return addr.As16()[0] != 0xff
|
||||
}
|
||||
|
||||
var solicitedNodeMulticastPrefix = [13]byte{0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0xff}
|
||||
|
||||
// SolicitedNodeAddr computes the solicited-node multicast address. This is
|
||||
// used for NDP. Described in RFC 4291. The argument must be a full-length IPv6
|
||||
// address.
|
||||
func SolicitedNodeAddr(addr tcpip.Address) tcpip.Address {
|
||||
addrBytes := addr.As16()
|
||||
return tcpip.AddrFrom16([16]byte(append(solicitedNodeMulticastPrefix[:], addrBytes[len(addrBytes)-3:]...)))
|
||||
}
|
||||
|
||||
// IsSolicitedNodeAddr determines whether the address is a solicited-node
|
||||
// multicast address.
|
||||
func IsSolicitedNodeAddr(addr tcpip.Address) bool {
|
||||
addrBytes := addr.As16()
|
||||
return solicitedNodeMulticastPrefix == [13]byte(addrBytes[:len(addrBytes)-3])
|
||||
}
|
||||
|
||||
// EthernetAdddressToModifiedEUI64IntoBuf populates buf with a modified EUI-64
|
||||
// from a 48-bit Ethernet/MAC address, as per RFC 4291 section 2.5.1.
|
||||
//
|
||||
// buf MUST be at least 8 bytes.
|
||||
func EthernetAdddressToModifiedEUI64IntoBuf(linkAddr tcpip.LinkAddress, buf []byte) {
|
||||
buf[0] = linkAddr[0] ^ 2
|
||||
buf[1] = linkAddr[1]
|
||||
buf[2] = linkAddr[2]
|
||||
buf[3] = 0xFF
|
||||
buf[4] = 0xFE
|
||||
buf[5] = linkAddr[3]
|
||||
buf[6] = linkAddr[4]
|
||||
buf[7] = linkAddr[5]
|
||||
}
|
||||
|
||||
// EthernetAddressToModifiedEUI64 computes a modified EUI-64 from a 48-bit
|
||||
// Ethernet/MAC address, as per RFC 4291 section 2.5.1.
|
||||
func EthernetAddressToModifiedEUI64(linkAddr tcpip.LinkAddress) [IIDSize]byte {
|
||||
var buf [IIDSize]byte
|
||||
EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, buf[:])
|
||||
return buf
|
||||
}
|
||||
|
||||
// LinkLocalAddr computes the default IPv6 link-local address from a link-layer
|
||||
// (MAC) address.
|
||||
func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address {
|
||||
// Convert a 48-bit MAC to a modified EUI-64 and then prepend the
|
||||
// link-local header, FE80::.
|
||||
//
|
||||
// The conversion is very nearly:
|
||||
// aa:bb:cc:dd:ee:ff => FE80::Aabb:ccFF:FEdd:eeff
|
||||
// Note the capital A. The conversion aa->Aa involves a bit flip.
|
||||
lladdrb := [IPv6AddressSize]byte{
|
||||
0: 0xFE,
|
||||
1: 0x80,
|
||||
}
|
||||
EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, lladdrb[IIDOffsetInIPv6Address:])
|
||||
return tcpip.AddrFrom16(lladdrb)
|
||||
}
|
||||
|
||||
// IsV6LinkLocalUnicastAddress returns true iff the provided address is an IPv6
|
||||
// link-local unicast address, as defined by RFC 4291 section 2.5.6.
|
||||
func IsV6LinkLocalUnicastAddress(addr tcpip.Address) bool {
|
||||
if addr.BitLen() != IPv6AddressSizeBits {
|
||||
return false
|
||||
}
|
||||
addrBytes := addr.As16()
|
||||
return addrBytes[0] == 0xfe && (addrBytes[1]&0xc0) == 0x80
|
||||
}
|
||||
|
||||
// IsV6LoopbackAddress returns true iff the provided address is an IPv6 loopback
|
||||
// address, as defined by RFC 4291 section 2.5.3.
|
||||
func IsV6LoopbackAddress(addr tcpip.Address) bool {
|
||||
return addr == IPv6Loopback
|
||||
}
|
||||
|
||||
// IsV6LinkLocalMulticastAddress returns true iff the provided address is an
|
||||
// IPv6 link-local multicast address, as defined by RFC 4291 section 2.7.
|
||||
func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool {
|
||||
return IsV6MulticastAddress(addr) && V6MulticastScope(addr) == IPv6LinkLocalMulticastScope
|
||||
}
|
||||
|
||||
// AppendOpaqueInterfaceIdentifier appends a 64 bit opaque interface identifier
|
||||
// (IID) to buf as outlined by RFC 7217 and returns the extended buffer.
|
||||
//
|
||||
// The opaque IID is generated from the cryptographic hash of the concatenation
|
||||
// of the prefix, NIC's name, DAD counter (DAD retry counter) and the secret
|
||||
// key. The secret key SHOULD be at least OpaqueIIDSecretKeyMinBytes bytes and
|
||||
// MUST be generated to a pseudo-random number. See RFC 4086 for randomness
|
||||
// requirements for security.
|
||||
//
|
||||
// If buf has enough capacity for the IID (IIDSize bytes), a new underlying
|
||||
// array for the buffer will not be allocated.
|
||||
func AppendOpaqueInterfaceIdentifier(buf []byte, prefix tcpip.Subnet, nicName string, dadCounter uint8, secretKey []byte) []byte {
|
||||
// As per RFC 7217 section 5, the opaque identifier can be generated as a
|
||||
// cryptographic hash of the concatenation of each of the function parameters.
|
||||
// Note, we omit the optional Network_ID field.
|
||||
h := sha256.New()
|
||||
// h.Write never returns an error.
|
||||
prefixID := prefix.ID()
|
||||
h.Write([]byte(prefixID.AsSlice()[:IIDOffsetInIPv6Address]))
|
||||
h.Write([]byte(nicName))
|
||||
h.Write([]byte{dadCounter})
|
||||
h.Write(secretKey)
|
||||
|
||||
var sumBuf [sha256.Size]byte
|
||||
sum := h.Sum(sumBuf[:0])
|
||||
|
||||
return append(buf, sum[:IIDSize]...)
|
||||
}
|
||||
|
||||
// LinkLocalAddrWithOpaqueIID computes the default IPv6 link-local address with
|
||||
// an opaque IID.
|
||||
func LinkLocalAddrWithOpaqueIID(nicName string, dadCounter uint8, secretKey []byte) tcpip.Address {
|
||||
lladdrb := [IPv6AddressSize]byte{
|
||||
0: 0xFE,
|
||||
1: 0x80,
|
||||
}
|
||||
|
||||
return tcpip.AddrFrom16([16]byte(AppendOpaqueInterfaceIdentifier(lladdrb[:IIDOffsetInIPv6Address], IPv6LinkLocalPrefix.Subnet(), nicName, dadCounter, secretKey)))
|
||||
}
|
||||
|
||||
// IPv6AddressScope is the scope of an IPv6 address.
|
||||
type IPv6AddressScope int
|
||||
|
||||
const (
|
||||
// LinkLocalScope indicates a link-local address.
|
||||
LinkLocalScope IPv6AddressScope = iota
|
||||
|
||||
// GlobalScope indicates a global address.
|
||||
GlobalScope
|
||||
)
|
||||
|
||||
// ScopeForIPv6Address returns the scope for an IPv6 address.
|
||||
func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, tcpip.Error) {
|
||||
if addr.BitLen() != IPv6AddressSizeBits {
|
||||
return GlobalScope, &tcpip.ErrBadAddress{}
|
||||
}
|
||||
|
||||
switch {
|
||||
case IsV6LinkLocalMulticastAddress(addr):
|
||||
return LinkLocalScope, nil
|
||||
|
||||
case IsV6LinkLocalUnicastAddress(addr):
|
||||
return LinkLocalScope, nil
|
||||
|
||||
default:
|
||||
return GlobalScope, nil
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateTempIPv6SLAACAddr generates a temporary SLAAC IPv6 address for an
|
||||
// associated stable/permanent SLAAC address.
|
||||
//
|
||||
// GenerateTempIPv6SLAACAddr will update the temporary IID history value to be
|
||||
// used when generating a new temporary IID.
|
||||
//
|
||||
// Panics if tempIIDHistory is not at least IIDSize bytes.
|
||||
func GenerateTempIPv6SLAACAddr(tempIIDHistory []byte, stableAddr tcpip.Address) tcpip.AddressWithPrefix {
|
||||
addrBytes := stableAddr.As16()
|
||||
h := sha256.New()
|
||||
h.Write(tempIIDHistory)
|
||||
h.Write(addrBytes[IIDOffsetInIPv6Address:])
|
||||
var sumBuf [sha256.Size]byte
|
||||
sum := h.Sum(sumBuf[:0])
|
||||
|
||||
// The rightmost 64 bits of sum are saved for the next iteration.
|
||||
if n := copy(tempIIDHistory, sum[sha256.Size-IIDSize:]); n != IIDSize {
|
||||
panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IIDSize))
|
||||
}
|
||||
|
||||
// The leftmost 64 bits of sum is used as the IID.
|
||||
if n := copy(addrBytes[IIDOffsetInIPv6Address:], sum); n != IIDSize {
|
||||
panic(fmt.Sprintf("copied %d IID bytes, expected %d bytes", n, IIDSize))
|
||||
}
|
||||
|
||||
return tcpip.AddressWithPrefix{
|
||||
Address: tcpip.AddrFrom16(addrBytes),
|
||||
PrefixLen: IIDOffsetInIPv6Address * 8,
|
||||
}
|
||||
}
|
||||
|
||||
// IPv6MulticastScope is the scope of a multicast IPv6 address, as defined by
|
||||
// RFC 7346 section 2.
|
||||
type IPv6MulticastScope uint8
|
||||
|
||||
// The various values for IPv6 multicast scopes, as per RFC 7346 section 2:
|
||||
//
|
||||
// +------+--------------------------+-------------------------+
|
||||
// | scop | NAME | REFERENCE |
|
||||
// +------+--------------------------+-------------------------+
|
||||
// | 0 | Reserved | [RFC4291], RFC 7346 |
|
||||
// | 1 | Interface-Local scope | [RFC4291], RFC 7346 |
|
||||
// | 2 | Link-Local scope | [RFC4291], RFC 7346 |
|
||||
// | 3 | Realm-Local scope | [RFC4291], RFC 7346 |
|
||||
// | 4 | Admin-Local scope | [RFC4291], RFC 7346 |
|
||||
// | 5 | Site-Local scope | [RFC4291], RFC 7346 |
|
||||
// | 6 | Unassigned | |
|
||||
// | 7 | Unassigned | |
|
||||
// | 8 | Organization-Local scope | [RFC4291], RFC 7346 |
|
||||
// | 9 | Unassigned | |
|
||||
// | A | Unassigned | |
|
||||
// | B | Unassigned | |
|
||||
// | C | Unassigned | |
|
||||
// | D | Unassigned | |
|
||||
// | E | Global scope | [RFC4291], RFC 7346 |
|
||||
// | F | Reserved | [RFC4291], RFC 7346 |
|
||||
// +------+--------------------------+-------------------------+
|
||||
const (
|
||||
IPv6Reserved0MulticastScope = IPv6MulticastScope(0x0)
|
||||
IPv6InterfaceLocalMulticastScope = IPv6MulticastScope(0x1)
|
||||
IPv6LinkLocalMulticastScope = IPv6MulticastScope(0x2)
|
||||
IPv6RealmLocalMulticastScope = IPv6MulticastScope(0x3)
|
||||
IPv6AdminLocalMulticastScope = IPv6MulticastScope(0x4)
|
||||
IPv6SiteLocalMulticastScope = IPv6MulticastScope(0x5)
|
||||
IPv6OrganizationLocalMulticastScope = IPv6MulticastScope(0x8)
|
||||
IPv6GlobalMulticastScope = IPv6MulticastScope(0xE)
|
||||
IPv6ReservedFMulticastScope = IPv6MulticastScope(0xF)
|
||||
)
|
||||
|
||||
// V6MulticastScope returns the scope of a multicast address.
|
||||
func V6MulticastScope(addr tcpip.Address) IPv6MulticastScope {
|
||||
addrBytes := addr.As16()
|
||||
return IPv6MulticastScope(addrBytes[ipv6MulticastAddressScopeByteIdx] & ipv6MulticastAddressScopeMask)
|
||||
}
|
||||
@@ -1,507 +0,0 @@
|
||||
// Copyright 2020 The gVisor Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package header
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip"
|
||||
"github.com/sagernet/sing/common"
|
||||
)
|
||||
|
||||
// IPv6ExtensionHeaderIdentifier is an IPv6 extension header identifier.
|
||||
type IPv6ExtensionHeaderIdentifier uint8
|
||||
|
||||
const (
|
||||
// IPv6HopByHopOptionsExtHdrIdentifier is the header identifier of a Hop by
|
||||
// Hop Options extension header, as per RFC 8200 section 4.3.
|
||||
IPv6HopByHopOptionsExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 0
|
||||
|
||||
// IPv6RoutingExtHdrIdentifier is the header identifier of a Routing extension
|
||||
// header, as per RFC 8200 section 4.4.
|
||||
IPv6RoutingExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 43
|
||||
|
||||
// IPv6FragmentExtHdrIdentifier is the header identifier of a Fragment
|
||||
// extension header, as per RFC 8200 section 4.5.
|
||||
IPv6FragmentExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 44
|
||||
|
||||
// IPv6DestinationOptionsExtHdrIdentifier is the header identifier of a
|
||||
// Destination Options extension header, as per RFC 8200 section 4.6.
|
||||
IPv6DestinationOptionsExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 60
|
||||
|
||||
// IPv6NoNextHeaderIdentifier is the header identifier used to signify the end
|
||||
// of an IPv6 payload, as per RFC 8200 section 4.7.
|
||||
IPv6NoNextHeaderIdentifier IPv6ExtensionHeaderIdentifier = 59
|
||||
|
||||
// IPv6UnknownExtHdrIdentifier is reserved by IANA.
|
||||
// https://www.iana.org/assignments/ipv6-parameters/ipv6-parameters.xhtml#extension-header
|
||||
// "254 Use for experimentation and testing [RFC3692][RFC4727]"
|
||||
IPv6UnknownExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 254
|
||||
)
|
||||
|
||||
const (
|
||||
// ipv6UnknownExtHdrOptionActionMask is the mask of the action to take when
|
||||
// a node encounters an unrecognized option.
|
||||
ipv6UnknownExtHdrOptionActionMask = 192
|
||||
|
||||
// ipv6UnknownExtHdrOptionActionShift is the least significant bits to discard
|
||||
// from the action value for an unrecognized option identifier.
|
||||
ipv6UnknownExtHdrOptionActionShift = 6
|
||||
|
||||
// ipv6RoutingExtHdrSegmentsLeftIdx is the index to the Segments Left field
|
||||
// within an IPv6RoutingExtHdr.
|
||||
ipv6RoutingExtHdrSegmentsLeftIdx = 1
|
||||
|
||||
// IPv6FragmentExtHdrLength is the length of an IPv6 extension header, in
|
||||
// bytes.
|
||||
IPv6FragmentExtHdrLength = 8
|
||||
|
||||
// ipv6FragmentExtHdrFragmentOffsetOffset is the offset to the start of the
|
||||
// Fragment Offset field within an IPv6FragmentExtHdr.
|
||||
ipv6FragmentExtHdrFragmentOffsetOffset = 0
|
||||
|
||||
// ipv6FragmentExtHdrFragmentOffsetShift is the bit offset of the Fragment
|
||||
// Offset field within an IPv6FragmentExtHdr.
|
||||
ipv6FragmentExtHdrFragmentOffsetShift = 3
|
||||
|
||||
// ipv6FragmentExtHdrFlagsIdx is the index to the flags field within an
|
||||
// IPv6FragmentExtHdr.
|
||||
ipv6FragmentExtHdrFlagsIdx = 1
|
||||
|
||||
// ipv6FragmentExtHdrMFlagMask is the mask of the More (M) flag within the
|
||||
// flags field of an IPv6FragmentExtHdr.
|
||||
ipv6FragmentExtHdrMFlagMask = 1
|
||||
|
||||
// ipv6FragmentExtHdrIdentificationOffset is the offset to the Identification
|
||||
// field within an IPv6FragmentExtHdr.
|
||||
ipv6FragmentExtHdrIdentificationOffset = 2
|
||||
|
||||
// ipv6ExtHdrLenBytesPerUnit is the unit size of an extension header's length
|
||||
// field. That is, given a Length field of 2, the extension header expects
|
||||
// 16 bytes following the first 8 bytes (see ipv6ExtHdrLenBytesExcluded for
|
||||
// details about the first 8 bytes' exclusion from the Length field).
|
||||
ipv6ExtHdrLenBytesPerUnit = 8
|
||||
|
||||
// ipv6ExtHdrLenBytesExcluded is the number of bytes excluded from an
|
||||
// extension header's Length field following the Length field.
|
||||
//
|
||||
// The Length field excludes the first 8 bytes, but the Next Header and Length
|
||||
// field take up the first 2 of the 8 bytes so we expect (at minimum) 6 bytes
|
||||
// after the Length field.
|
||||
//
|
||||
// This ensures that every extension header is at least 8 bytes.
|
||||
ipv6ExtHdrLenBytesExcluded = 6
|
||||
|
||||
// IPv6FragmentExtHdrFragmentOffsetBytesPerUnit is the unit size of a Fragment
|
||||
// extension header's Fragment Offset field. That is, given a Fragment Offset
|
||||
// of 2, the extension header is indicating that the fragment's payload
|
||||
// starts at the 16th byte in the reassembled packet.
|
||||
IPv6FragmentExtHdrFragmentOffsetBytesPerUnit = 8
|
||||
)
|
||||
|
||||
// padIPv6OptionsLength returns the total length for IPv6 options of length l
|
||||
// considering the 8-octet alignment as stated in RFC 8200 Section 4.2.
|
||||
func padIPv6OptionsLength(length int) int {
|
||||
return (length + ipv6ExtHdrLenBytesPerUnit - 1) & ^(ipv6ExtHdrLenBytesPerUnit - 1)
|
||||
}
|
||||
|
||||
// padIPv6Option fills b with the appropriate padding options depending on its
|
||||
// length.
|
||||
func padIPv6Option(b []byte) {
|
||||
switch len(b) {
|
||||
case 0: // No padding needed.
|
||||
case 1: // Pad with Pad1.
|
||||
b[ipv6ExtHdrOptionTypeOffset] = uint8(ipv6Pad1ExtHdrOptionIdentifier)
|
||||
default: // Pad with PadN.
|
||||
s := b[ipv6ExtHdrOptionPayloadOffset:]
|
||||
common.ClearArray(s)
|
||||
b[ipv6ExtHdrOptionTypeOffset] = uint8(ipv6PadNExtHdrOptionIdentifier)
|
||||
b[ipv6ExtHdrOptionLengthOffset] = uint8(len(s))
|
||||
}
|
||||
}
|
||||
|
||||
// ipv6OptionsAlignmentPadding returns the number of padding bytes needed to
|
||||
// serialize an option at headerOffset with alignment requirements
|
||||
// [align]n + alignOffset.
|
||||
func ipv6OptionsAlignmentPadding(headerOffset int, align int, alignOffset int) int {
|
||||
padLen := headerOffset - alignOffset
|
||||
return ((padLen + align - 1) & ^(align - 1)) - padLen
|
||||
}
|
||||
|
||||
// IPv6OptionUnknownAction is the action that must be taken if the processing
|
||||
// IPv6 node does not recognize the option, as outlined in RFC 8200 section 4.2.
|
||||
type IPv6OptionUnknownAction int
|
||||
|
||||
const (
|
||||
// IPv6OptionUnknownActionSkip indicates that the unrecognized option must
|
||||
// be skipped and the node should continue processing the header.
|
||||
IPv6OptionUnknownActionSkip IPv6OptionUnknownAction = 0
|
||||
|
||||
// IPv6OptionUnknownActionDiscard indicates that the packet must be silently
|
||||
// discarded.
|
||||
IPv6OptionUnknownActionDiscard IPv6OptionUnknownAction = 1
|
||||
|
||||
// IPv6OptionUnknownActionDiscardSendICMP indicates that the packet must be
|
||||
// discarded and the node must send an ICMP Parameter Problem, Code 2, message
|
||||
// to the packet's source, regardless of whether or not the packet's
|
||||
// Destination was a multicast address.
|
||||
IPv6OptionUnknownActionDiscardSendICMP IPv6OptionUnknownAction = 2
|
||||
|
||||
// IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest indicates that the
|
||||
// packet must be discarded and the node must send an ICMP Parameter Problem,
|
||||
// Code 2, message to the packet's source only if the packet's Destination was
|
||||
// not a multicast address.
|
||||
IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest IPv6OptionUnknownAction = 3
|
||||
)
|
||||
|
||||
// IPv6ExtHdrOption is implemented by the various IPv6 extension header options.
|
||||
type IPv6ExtHdrOption interface {
|
||||
// UnknownAction returns the action to take in response to an unrecognized
|
||||
// option.
|
||||
UnknownAction() IPv6OptionUnknownAction
|
||||
|
||||
// isIPv6ExtHdrOption is used to "lock" this interface so it is not
|
||||
// implemented by other packages.
|
||||
isIPv6ExtHdrOption()
|
||||
}
|
||||
|
||||
// IPv6ExtHdrOptionIdentifier is an IPv6 extension header option identifier.
|
||||
type IPv6ExtHdrOptionIdentifier uint8
|
||||
|
||||
const (
|
||||
// ipv6Pad1ExtHdrOptionIdentifier is the identifier for a padding option that
|
||||
// provides 1 byte padding, as outlined in RFC 8200 section 4.2.
|
||||
ipv6Pad1ExtHdrOptionIdentifier IPv6ExtHdrOptionIdentifier = 0
|
||||
|
||||
// ipv6PadNExtHdrOptionIdentifier is the identifier for a padding option that
|
||||
// provides variable length byte padding, as outlined in RFC 8200 section 4.2.
|
||||
ipv6PadNExtHdrOptionIdentifier IPv6ExtHdrOptionIdentifier = 1
|
||||
|
||||
// ipv6RouterAlertHopByHopOptionIdentifier is the identifier for the Router
|
||||
// Alert Hop by Hop option as defined in RFC 2711 section 2.1.
|
||||
ipv6RouterAlertHopByHopOptionIdentifier IPv6ExtHdrOptionIdentifier = 5
|
||||
|
||||
// ipv6ExtHdrOptionTypeOffset is the option type offset in an extension header
|
||||
// option as defined in RFC 8200 section 4.2.
|
||||
ipv6ExtHdrOptionTypeOffset = 0
|
||||
|
||||
// ipv6ExtHdrOptionLengthOffset is the option length offset in an extension
|
||||
// header option as defined in RFC 8200 section 4.2.
|
||||
ipv6ExtHdrOptionLengthOffset = 1
|
||||
|
||||
// ipv6ExtHdrOptionPayloadOffset is the option payload offset in an extension
|
||||
// header option as defined in RFC 8200 section 4.2.
|
||||
ipv6ExtHdrOptionPayloadOffset = 2
|
||||
)
|
||||
|
||||
// ipv6UnknownActionFromIdentifier maps an extension header option's
|
||||
// identifier's high bits to the action to take when the identifier is unknown.
|
||||
func ipv6UnknownActionFromIdentifier(id IPv6ExtHdrOptionIdentifier) IPv6OptionUnknownAction {
|
||||
return IPv6OptionUnknownAction((id & ipv6UnknownExtHdrOptionActionMask) >> ipv6UnknownExtHdrOptionActionShift)
|
||||
}
|
||||
|
||||
// ErrMalformedIPv6ExtHdrOption indicates that an IPv6 extension header option
|
||||
// is malformed.
|
||||
var ErrMalformedIPv6ExtHdrOption = errors.New("malformed IPv6 extension header option")
|
||||
|
||||
// IPv6FragmentExtHdr is a buffer holding the Fragment extension header specific
|
||||
// data as outlined in RFC 8200 section 4.5.
|
||||
//
|
||||
// Note, the buffer does not include the Next Header and Reserved fields.
|
||||
type IPv6FragmentExtHdr [6]byte
|
||||
|
||||
// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
|
||||
func (IPv6FragmentExtHdr) isIPv6PayloadHeader() {}
|
||||
|
||||
// Release implements IPv6PayloadHeader.Release.
|
||||
func (IPv6FragmentExtHdr) Release() {}
|
||||
|
||||
// FragmentOffset returns the Fragment Offset field.
|
||||
//
|
||||
// This value indicates where the buffer following the Fragment extension header
|
||||
// starts in the target (reassembled) packet.
|
||||
func (b IPv6FragmentExtHdr) FragmentOffset() uint16 {
|
||||
return binary.BigEndian.Uint16(b[ipv6FragmentExtHdrFragmentOffsetOffset:]) >> ipv6FragmentExtHdrFragmentOffsetShift
|
||||
}
|
||||
|
||||
// More returns the More (M) flag.
|
||||
//
|
||||
// This indicates whether any fragments are expected to succeed b.
|
||||
func (b IPv6FragmentExtHdr) More() bool {
|
||||
return b[ipv6FragmentExtHdrFlagsIdx]&ipv6FragmentExtHdrMFlagMask != 0
|
||||
}
|
||||
|
||||
// ID returns the Identification field.
|
||||
//
|
||||
// This value is used to uniquely identify the packet, between a
|
||||
// source and destination.
|
||||
func (b IPv6FragmentExtHdr) ID() uint32 {
|
||||
return binary.BigEndian.Uint32(b[ipv6FragmentExtHdrIdentificationOffset:])
|
||||
}
|
||||
|
||||
// IsAtomic returns whether the fragment header indicates an atomic fragment. An
|
||||
// atomic fragment is a fragment that contains all the data required to
|
||||
// reassemble a full packet.
|
||||
func (b IPv6FragmentExtHdr) IsAtomic() bool {
|
||||
return !b.More() && b.FragmentOffset() == 0
|
||||
}
|
||||
|
||||
// IPv6SerializableExtHdr provides serialization for IPv6 extension
|
||||
// headers.
|
||||
type IPv6SerializableExtHdr interface {
|
||||
// identifier returns the assigned IPv6 header identifier for this extension
|
||||
// header.
|
||||
identifier() IPv6ExtensionHeaderIdentifier
|
||||
|
||||
// length returns the total serialized length in bytes of this extension
|
||||
// header, including the common next header and length fields.
|
||||
length() int
|
||||
|
||||
// serializeInto serializes the receiver into the provided byte
|
||||
// buffer and with the provided nextHeader value.
|
||||
//
|
||||
// Note, the caller MUST provide a byte buffer with size of at least
|
||||
// length. Implementers of this function may assume that the byte buffer
|
||||
// is of sufficient size. serializeInto MAY panic if the provided byte
|
||||
// buffer is not of sufficient size.
|
||||
//
|
||||
// serializeInto returns the number of bytes that was used to serialize the
|
||||
// receiver. Implementers must only use the number of bytes required to
|
||||
// serialize the receiver. Callers MAY provide a larger buffer than required
|
||||
// to serialize into.
|
||||
serializeInto(nextHeader uint8, b []byte) int
|
||||
}
|
||||
|
||||
var _ IPv6SerializableExtHdr = (*IPv6SerializableHopByHopExtHdr)(nil)
|
||||
|
||||
// IPv6SerializableHopByHopExtHdr implements serialization of the Hop by Hop
|
||||
// options extension header.
|
||||
type IPv6SerializableHopByHopExtHdr []IPv6SerializableHopByHopOption
|
||||
|
||||
const (
|
||||
// ipv6HopByHopExtHdrNextHeaderOffset is the offset of the next header field
|
||||
// in a hop by hop extension header as defined in RFC 8200 section 4.3.
|
||||
ipv6HopByHopExtHdrNextHeaderOffset = 0
|
||||
|
||||
// ipv6HopByHopExtHdrLengthOffset is the offset of the length field in a hop
|
||||
// by hop extension header as defined in RFC 8200 section 4.3.
|
||||
ipv6HopByHopExtHdrLengthOffset = 1
|
||||
|
||||
// ipv6HopByHopExtHdrPayloadOffset is the offset of the options in a hop by
|
||||
// hop extension header as defined in RFC 8200 section 4.3.
|
||||
ipv6HopByHopExtHdrOptionsOffset = 2
|
||||
|
||||
// ipv6HopByHopExtHdrUnaccountedLenWords is the implicit number of 8-octet
|
||||
// words in a hop by hop extension header's length field, as stated in RFC
|
||||
// 8200 section 4.3:
|
||||
// Length of the Hop-by-Hop Options header in 8-octet units,
|
||||
// not including the first 8 octets.
|
||||
ipv6HopByHopExtHdrUnaccountedLenWords = 1
|
||||
)
|
||||
|
||||
// identifier implements IPv6SerializableExtHdr.
|
||||
func (IPv6SerializableHopByHopExtHdr) identifier() IPv6ExtensionHeaderIdentifier {
|
||||
return IPv6HopByHopOptionsExtHdrIdentifier
|
||||
}
|
||||
|
||||
// length implements IPv6SerializableExtHdr.
|
||||
func (h IPv6SerializableHopByHopExtHdr) length() int {
|
||||
var total int
|
||||
for _, opt := range h {
|
||||
align, alignOffset := opt.alignment()
|
||||
total += ipv6OptionsAlignmentPadding(total, align, alignOffset)
|
||||
total += ipv6ExtHdrOptionPayloadOffset + int(opt.length())
|
||||
}
|
||||
// Account for next header and total length fields and add padding.
|
||||
return padIPv6OptionsLength(ipv6HopByHopExtHdrOptionsOffset + total)
|
||||
}
|
||||
|
||||
// serializeInto implements IPv6SerializableExtHdr.
|
||||
func (h IPv6SerializableHopByHopExtHdr) serializeInto(nextHeader uint8, b []byte) int {
|
||||
optBuffer := b[ipv6HopByHopExtHdrOptionsOffset:]
|
||||
totalLength := ipv6HopByHopExtHdrOptionsOffset
|
||||
for _, opt := range h {
|
||||
// Calculate alignment requirements and pad buffer if necessary.
|
||||
align, alignOffset := opt.alignment()
|
||||
padLen := ipv6OptionsAlignmentPadding(totalLength, align, alignOffset)
|
||||
if padLen != 0 {
|
||||
padIPv6Option(optBuffer[:padLen])
|
||||
totalLength += padLen
|
||||
optBuffer = optBuffer[padLen:]
|
||||
}
|
||||
|
||||
l := opt.serializeInto(optBuffer[ipv6ExtHdrOptionPayloadOffset:])
|
||||
optBuffer[ipv6ExtHdrOptionTypeOffset] = uint8(opt.identifier())
|
||||
optBuffer[ipv6ExtHdrOptionLengthOffset] = l
|
||||
l += ipv6ExtHdrOptionPayloadOffset
|
||||
totalLength += int(l)
|
||||
optBuffer = optBuffer[l:]
|
||||
}
|
||||
padded := padIPv6OptionsLength(totalLength)
|
||||
if padded != totalLength {
|
||||
padIPv6Option(optBuffer[:padded-totalLength])
|
||||
totalLength = padded
|
||||
}
|
||||
wordsLen := totalLength/ipv6ExtHdrLenBytesPerUnit - ipv6HopByHopExtHdrUnaccountedLenWords
|
||||
if wordsLen > math.MaxUint8 {
|
||||
panic(fmt.Sprintf("IPv6 hop by hop options too large: %d+1 64-bit words", wordsLen))
|
||||
}
|
||||
b[ipv6HopByHopExtHdrNextHeaderOffset] = nextHeader
|
||||
b[ipv6HopByHopExtHdrLengthOffset] = uint8(wordsLen)
|
||||
return totalLength
|
||||
}
|
||||
|
||||
// IPv6SerializableHopByHopOption provides serialization for hop by hop options.
|
||||
type IPv6SerializableHopByHopOption interface {
|
||||
// identifier returns the option identifier of this Hop by Hop option.
|
||||
identifier() IPv6ExtHdrOptionIdentifier
|
||||
|
||||
// length returns the *payload* size of the option (not considering the type
|
||||
// and length fields).
|
||||
length() uint8
|
||||
|
||||
// alignment returns the alignment requirements from this option.
|
||||
//
|
||||
// Alignment requirements take the form [align]n + offset as specified in
|
||||
// RFC 8200 section 4.2. The alignment requirement is on the offset between
|
||||
// the option type byte and the start of the hop by hop header.
|
||||
//
|
||||
// align must be a power of 2.
|
||||
alignment() (align int, offset int)
|
||||
|
||||
// serializeInto serializes the receiver into the provided byte
|
||||
// buffer.
|
||||
//
|
||||
// Note, the caller MUST provide a byte buffer with size of at least
|
||||
// length. Implementers of this function may assume that the byte buffer
|
||||
// is of sufficient size. serializeInto MAY panic if the provided byte
|
||||
// buffer is not of sufficient size.
|
||||
//
|
||||
// serializeInto will return the number of bytes that was used to
|
||||
// serialize the receiver. Implementers must only use the number of
|
||||
// bytes required to serialize the receiver. Callers MAY provide a
|
||||
// larger buffer than required to serialize into.
|
||||
serializeInto([]byte) uint8
|
||||
}
|
||||
|
||||
var _ IPv6SerializableHopByHopOption = (*IPv6RouterAlertOption)(nil)
|
||||
|
||||
// IPv6RouterAlertOption is the IPv6 Router alert Hop by Hop option defined in
|
||||
// RFC 2711 section 2.1.
|
||||
type IPv6RouterAlertOption struct {
|
||||
Value IPv6RouterAlertValue
|
||||
}
|
||||
|
||||
// IPv6RouterAlertValue is the payload of an IPv6 Router Alert option.
|
||||
type IPv6RouterAlertValue uint16
|
||||
|
||||
const (
|
||||
// IPv6RouterAlertMLD indicates a datagram containing a Multicast Listener
|
||||
// Discovery message as defined in RFC 2711 section 2.1.
|
||||
IPv6RouterAlertMLD IPv6RouterAlertValue = 0
|
||||
// IPv6RouterAlertRSVP indicates a datagram containing an RSVP message as
|
||||
// defined in RFC 2711 section 2.1.
|
||||
IPv6RouterAlertRSVP IPv6RouterAlertValue = 1
|
||||
// IPv6RouterAlertActiveNetworks indicates a datagram containing an Active
|
||||
// Networks message as defined in RFC 2711 section 2.1.
|
||||
IPv6RouterAlertActiveNetworks IPv6RouterAlertValue = 2
|
||||
|
||||
// ipv6RouterAlertPayloadLength is the length of the Router Alert payload
|
||||
// as defined in RFC 2711.
|
||||
ipv6RouterAlertPayloadLength = 2
|
||||
|
||||
// ipv6RouterAlertAlignmentRequirement is the alignment requirement for the
|
||||
// Router Alert option defined as 2n+0 in RFC 2711.
|
||||
ipv6RouterAlertAlignmentRequirement = 2
|
||||
|
||||
// ipv6RouterAlertAlignmentOffsetRequirement is the alignment offset
|
||||
// requirement for the Router Alert option defined as 2n+0 in RFC 2711 section
|
||||
// 2.1.
|
||||
ipv6RouterAlertAlignmentOffsetRequirement = 0
|
||||
)
|
||||
|
||||
// UnknownAction implements IPv6ExtHdrOption.
|
||||
func (*IPv6RouterAlertOption) UnknownAction() IPv6OptionUnknownAction {
|
||||
return ipv6UnknownActionFromIdentifier(ipv6RouterAlertHopByHopOptionIdentifier)
|
||||
}
|
||||
|
||||
// isIPv6ExtHdrOption implements IPv6ExtHdrOption.
|
||||
func (*IPv6RouterAlertOption) isIPv6ExtHdrOption() {}
|
||||
|
||||
// identifier implements IPv6SerializableHopByHopOption.
|
||||
func (*IPv6RouterAlertOption) identifier() IPv6ExtHdrOptionIdentifier {
|
||||
return ipv6RouterAlertHopByHopOptionIdentifier
|
||||
}
|
||||
|
||||
// length implements IPv6SerializableHopByHopOption.
|
||||
func (*IPv6RouterAlertOption) length() uint8 {
|
||||
return ipv6RouterAlertPayloadLength
|
||||
}
|
||||
|
||||
// alignment implements IPv6SerializableHopByHopOption.
|
||||
func (*IPv6RouterAlertOption) alignment() (int, int) {
|
||||
// From RFC 2711 section 2.1:
|
||||
// Alignment requirement: 2n+0.
|
||||
return ipv6RouterAlertAlignmentRequirement, ipv6RouterAlertAlignmentOffsetRequirement
|
||||
}
|
||||
|
||||
// serializeInto implements IPv6SerializableHopByHopOption.
|
||||
func (o *IPv6RouterAlertOption) serializeInto(b []byte) uint8 {
|
||||
binary.BigEndian.PutUint16(b, uint16(o.Value))
|
||||
return ipv6RouterAlertPayloadLength
|
||||
}
|
||||
|
||||
// IPv6ExtHdrSerializer provides serialization of IPv6 extension headers.
|
||||
type IPv6ExtHdrSerializer []IPv6SerializableExtHdr
|
||||
|
||||
// Serialize serializes the provided list of IPv6 extension headers into b.
|
||||
//
|
||||
// Note, b must be of sufficient size to hold all the headers in s. See
|
||||
// IPv6ExtHdrSerializer.Length for details on the getting the total size of a
|
||||
// serialized IPv6ExtHdrSerializer.
|
||||
//
|
||||
// Serialize may panic if b is not of sufficient size to hold all the options
|
||||
// in s.
|
||||
//
|
||||
// Serialize takes the transportProtocol value to be used as the last extension
|
||||
// header's Next Header value and returns the header identifier of the first
|
||||
// serialized extension header and the total serialized length.
|
||||
func (s IPv6ExtHdrSerializer) Serialize(transportProtocol tcpip.TransportProtocolNumber, b []byte) (uint8, int) {
|
||||
nextHeader := uint8(transportProtocol)
|
||||
if len(s) == 0 {
|
||||
return nextHeader, 0
|
||||
}
|
||||
var totalLength int
|
||||
for i, h := range s[:len(s)-1] {
|
||||
length := h.serializeInto(uint8(s[i+1].identifier()), b)
|
||||
b = b[length:]
|
||||
totalLength += length
|
||||
}
|
||||
totalLength += s[len(s)-1].serializeInto(nextHeader, b)
|
||||
return uint8(s[0].identifier()), totalLength
|
||||
}
|
||||
|
||||
// Length returns the total number of bytes required to serialize the extension
|
||||
// headers.
|
||||
func (s IPv6ExtHdrSerializer) Length() int {
|
||||
var totalLength int
|
||||
for _, h := range s {
|
||||
totalLength += h.length()
|
||||
}
|
||||
return totalLength
|
||||
}
|
||||
@@ -1,158 +0,0 @@
|
||||
// Copyright 2018 The gVisor Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package header
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip"
|
||||
)
|
||||
|
||||
const (
|
||||
nextHdrFrag = 0
|
||||
fragOff = 2
|
||||
more = 3
|
||||
idV6 = 4
|
||||
)
|
||||
|
||||
var _ IPv6SerializableExtHdr = (*IPv6SerializableFragmentExtHdr)(nil)
|
||||
|
||||
// IPv6SerializableFragmentExtHdr is used to serialize an IPv6 fragment
|
||||
// extension header as defined in RFC 8200 section 4.5.
|
||||
type IPv6SerializableFragmentExtHdr struct {
|
||||
// FragmentOffset is the "fragment offset" field of an IPv6 fragment.
|
||||
FragmentOffset uint16
|
||||
|
||||
// M is the "more" field of an IPv6 fragment.
|
||||
M bool
|
||||
|
||||
// Identification is the "identification" field of an IPv6 fragment.
|
||||
Identification uint32
|
||||
}
|
||||
|
||||
// identifier implements IPv6SerializableFragmentExtHdr.
|
||||
func (h *IPv6SerializableFragmentExtHdr) identifier() IPv6ExtensionHeaderIdentifier {
|
||||
return IPv6FragmentHeader
|
||||
}
|
||||
|
||||
// length implements IPv6SerializableFragmentExtHdr.
|
||||
func (h *IPv6SerializableFragmentExtHdr) length() int {
|
||||
return IPv6FragmentHeaderSize
|
||||
}
|
||||
|
||||
// serializeInto implements IPv6SerializableFragmentExtHdr.
|
||||
func (h *IPv6SerializableFragmentExtHdr) serializeInto(nextHeader uint8, b []byte) int {
|
||||
// Prevent too many bounds checks.
|
||||
_ = b[IPv6FragmentHeaderSize:]
|
||||
binary.BigEndian.PutUint32(b[idV6:], h.Identification)
|
||||
binary.BigEndian.PutUint16(b[fragOff:], h.FragmentOffset<<ipv6FragmentExtHdrFragmentOffsetShift)
|
||||
b[nextHdrFrag] = nextHeader
|
||||
if h.M {
|
||||
b[more] |= ipv6FragmentExtHdrMFlagMask
|
||||
}
|
||||
return IPv6FragmentHeaderSize
|
||||
}
|
||||
|
||||
// IPv6Fragment represents an ipv6 fragment header stored in a byte array.
|
||||
// Most of the methods of IPv6Fragment access to the underlying slice without
|
||||
// checking the boundaries and could panic because of 'index out of range'.
|
||||
// Always call IsValid() to validate an instance of IPv6Fragment before using other methods.
|
||||
type IPv6Fragment []byte
|
||||
|
||||
const (
|
||||
// IPv6FragmentHeader header is the number used to specify that the next
|
||||
// header is a fragment header, per RFC 2460.
|
||||
IPv6FragmentHeader = 44
|
||||
|
||||
// IPv6FragmentHeaderSize is the size of the fragment header.
|
||||
IPv6FragmentHeaderSize = 8
|
||||
)
|
||||
|
||||
// IsValid performs basic validation on the fragment header.
|
||||
func (b IPv6Fragment) IsValid() bool {
|
||||
return len(b) >= IPv6FragmentHeaderSize
|
||||
}
|
||||
|
||||
// NextHeader returns the value of the "next header" field of the ipv6 fragment.
|
||||
func (b IPv6Fragment) NextHeader() uint8 {
|
||||
return b[nextHdrFrag]
|
||||
}
|
||||
|
||||
// FragmentOffset returns the "fragment offset" field of the ipv6 fragment.
|
||||
func (b IPv6Fragment) FragmentOffset() uint16 {
|
||||
return binary.BigEndian.Uint16(b[fragOff:]) >> 3
|
||||
}
|
||||
|
||||
// More returns the "more" field of the ipv6 fragment.
|
||||
func (b IPv6Fragment) More() bool {
|
||||
return b[more]&1 > 0
|
||||
}
|
||||
|
||||
// Payload implements Network.Payload.
|
||||
func (b IPv6Fragment) Payload() []byte {
|
||||
return b[IPv6FragmentHeaderSize:]
|
||||
}
|
||||
|
||||
// ID returns the value of the identifier field of the ipv6 fragment.
|
||||
func (b IPv6Fragment) ID() uint32 {
|
||||
return binary.BigEndian.Uint32(b[idV6:])
|
||||
}
|
||||
|
||||
// TransportProtocol implements Network.TransportProtocol.
|
||||
func (b IPv6Fragment) TransportProtocol() tcpip.TransportProtocolNumber {
|
||||
return tcpip.TransportProtocolNumber(b.NextHeader())
|
||||
}
|
||||
|
||||
// The functions below have been added only to satisfy the Network interface.
|
||||
|
||||
// Checksum is not supported by IPv6Fragment.
|
||||
func (b IPv6Fragment) Checksum() uint16 {
|
||||
panic("not supported")
|
||||
}
|
||||
|
||||
// SourceAddress is not supported by IPv6Fragment.
|
||||
func (b IPv6Fragment) SourceAddress() tcpip.Address {
|
||||
panic("not supported")
|
||||
}
|
||||
|
||||
// DestinationAddress is not supported by IPv6Fragment.
|
||||
func (b IPv6Fragment) DestinationAddress() tcpip.Address {
|
||||
panic("not supported")
|
||||
}
|
||||
|
||||
// SetSourceAddress is not supported by IPv6Fragment.
|
||||
func (b IPv6Fragment) SetSourceAddress(tcpip.Address) {
|
||||
panic("not supported")
|
||||
}
|
||||
|
||||
// SetDestinationAddress is not supported by IPv6Fragment.
|
||||
func (b IPv6Fragment) SetDestinationAddress(tcpip.Address) {
|
||||
panic("not supported")
|
||||
}
|
||||
|
||||
// SetChecksum is not supported by IPv6Fragment.
|
||||
func (b IPv6Fragment) SetChecksum(uint16) {
|
||||
panic("not supported")
|
||||
}
|
||||
|
||||
// TOS is not supported by IPv6Fragment.
|
||||
func (b IPv6Fragment) TOS() (uint8, uint32) {
|
||||
panic("not supported")
|
||||
}
|
||||
|
||||
// SetTOS is not supported by IPv6Fragment.
|
||||
func (b IPv6Fragment) SetTOS(t uint8, l uint32) {
|
||||
panic("not supported")
|
||||
}
|
||||
@@ -1,110 +0,0 @@
|
||||
// Copyright 2019 The gVisor Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package header
|
||||
|
||||
import "github.com/sagernet/sing-tun/internal/gtcpip"
|
||||
|
||||
// NDPNeighborAdvert is an NDP Neighbor Advertisement message. It will
|
||||
// only contain the body of an ICMPv6 packet.
|
||||
//
|
||||
// See RFC 4861 section 4.4 for more details.
|
||||
type NDPNeighborAdvert []byte
|
||||
|
||||
const (
|
||||
// NDPNAMinimumSize is the minimum size of a valid NDP Neighbor
|
||||
// Advertisement message (body of an ICMPv6 packet).
|
||||
NDPNAMinimumSize = 20
|
||||
|
||||
// ndpNATargetAddressOffset is the start of the Target Address
|
||||
// field within an NDPNeighborAdvert.
|
||||
ndpNATargetAddressOffset = 4
|
||||
|
||||
// ndpNAOptionsOffset is the start of the NDP options in an
|
||||
// NDPNeighborAdvert.
|
||||
ndpNAOptionsOffset = ndpNATargetAddressOffset + IPv6AddressSize
|
||||
|
||||
// ndpNAFlagsOffset is the offset of the flags within an
|
||||
// NDPNeighborAdvert
|
||||
ndpNAFlagsOffset = 0
|
||||
|
||||
// ndpNARouterFlagMask is the mask of the Router Flag field in
|
||||
// the flags byte within in an NDPNeighborAdvert.
|
||||
ndpNARouterFlagMask = (1 << 7)
|
||||
|
||||
// ndpNASolicitedFlagMask is the mask of the Solicited Flag field in
|
||||
// the flags byte within in an NDPNeighborAdvert.
|
||||
ndpNASolicitedFlagMask = (1 << 6)
|
||||
|
||||
// ndpNAOverrideFlagMask is the mask of the Override Flag field in
|
||||
// the flags byte within in an NDPNeighborAdvert.
|
||||
ndpNAOverrideFlagMask = (1 << 5)
|
||||
)
|
||||
|
||||
// TargetAddress returns the value within the Target Address field.
|
||||
func (b NDPNeighborAdvert) TargetAddress() tcpip.Address {
|
||||
return tcpip.AddrFrom16Slice(b[ndpNATargetAddressOffset:][:IPv6AddressSize])
|
||||
}
|
||||
|
||||
// SetTargetAddress sets the value within the Target Address field.
|
||||
func (b NDPNeighborAdvert) SetTargetAddress(addr tcpip.Address) {
|
||||
copy(b[ndpNATargetAddressOffset:][:IPv6AddressSize], addr.AsSlice())
|
||||
}
|
||||
|
||||
// RouterFlag returns the value of the Router Flag field.
|
||||
func (b NDPNeighborAdvert) RouterFlag() bool {
|
||||
return b[ndpNAFlagsOffset]&ndpNARouterFlagMask != 0
|
||||
}
|
||||
|
||||
// SetRouterFlag sets the value in the Router Flag field.
|
||||
func (b NDPNeighborAdvert) SetRouterFlag(f bool) {
|
||||
if f {
|
||||
b[ndpNAFlagsOffset] |= ndpNARouterFlagMask
|
||||
} else {
|
||||
b[ndpNAFlagsOffset] &^= ndpNARouterFlagMask
|
||||
}
|
||||
}
|
||||
|
||||
// SolicitedFlag returns the value of the Solicited Flag field.
|
||||
func (b NDPNeighborAdvert) SolicitedFlag() bool {
|
||||
return b[ndpNAFlagsOffset]&ndpNASolicitedFlagMask != 0
|
||||
}
|
||||
|
||||
// SetSolicitedFlag sets the value in the Solicited Flag field.
|
||||
func (b NDPNeighborAdvert) SetSolicitedFlag(f bool) {
|
||||
if f {
|
||||
b[ndpNAFlagsOffset] |= ndpNASolicitedFlagMask
|
||||
} else {
|
||||
b[ndpNAFlagsOffset] &^= ndpNASolicitedFlagMask
|
||||
}
|
||||
}
|
||||
|
||||
// OverrideFlag returns the value of the Override Flag field.
|
||||
func (b NDPNeighborAdvert) OverrideFlag() bool {
|
||||
return b[ndpNAFlagsOffset]&ndpNAOverrideFlagMask != 0
|
||||
}
|
||||
|
||||
// SetOverrideFlag sets the value in the Override Flag field.
|
||||
func (b NDPNeighborAdvert) SetOverrideFlag(f bool) {
|
||||
if f {
|
||||
b[ndpNAFlagsOffset] |= ndpNAOverrideFlagMask
|
||||
} else {
|
||||
b[ndpNAFlagsOffset] &^= ndpNAOverrideFlagMask
|
||||
}
|
||||
}
|
||||
|
||||
// Options returns an NDPOptions of the options body.
|
||||
func (b NDPNeighborAdvert) Options() NDPOptions {
|
||||
return NDPOptions(b[ndpNAOptionsOffset:])
|
||||
}
|
||||
@@ -1,52 +0,0 @@
|
||||
// Copyright 2019 The gVisor Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package header
|
||||
|
||||
import "github.com/sagernet/sing-tun/internal/gtcpip"
|
||||
|
||||
// NDPNeighborSolicit is an NDP Neighbor Solicitation message. It will only
|
||||
// contain the body of an ICMPv6 packet.
|
||||
//
|
||||
// See RFC 4861 section 4.3 for more details.
|
||||
type NDPNeighborSolicit []byte
|
||||
|
||||
const (
|
||||
// NDPNSMinimumSize is the minimum size of a valid NDP Neighbor
|
||||
// Solicitation message (body of an ICMPv6 packet).
|
||||
NDPNSMinimumSize = 20
|
||||
|
||||
// ndpNSTargetAddessOffset is the start of the Target Address
|
||||
// field within an NDPNeighborSolicit.
|
||||
ndpNSTargetAddessOffset = 4
|
||||
|
||||
// ndpNSOptionsOffset is the start of the NDP options in an
|
||||
// NDPNeighborSolicit.
|
||||
ndpNSOptionsOffset = ndpNSTargetAddessOffset + IPv6AddressSize
|
||||
)
|
||||
|
||||
// TargetAddress returns the value within the Target Address field.
|
||||
func (b NDPNeighborSolicit) TargetAddress() tcpip.Address {
|
||||
return tcpip.AddrFrom16Slice(b[ndpNSTargetAddessOffset:][:IPv6AddressSize])
|
||||
}
|
||||
|
||||
// SetTargetAddress sets the value within the Target Address field.
|
||||
func (b NDPNeighborSolicit) SetTargetAddress(addr tcpip.Address) {
|
||||
copy(b[ndpNSTargetAddessOffset:][:IPv6AddressSize], addr.AsSlice())
|
||||
}
|
||||
|
||||
// Options returns an NDPOptions of the options body.
|
||||
func (b NDPNeighborSolicit) Options() NDPOptions {
|
||||
return NDPOptions(b[ndpNSOptionsOffset:])
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,204 +0,0 @@
|
||||
// Copyright 2019 The gVisor Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package header
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
var _ fmt.Stringer = NDPRoutePreference(0)
|
||||
|
||||
// NDPRoutePreference is the preference values for default routers or
|
||||
// more-specific routes.
|
||||
//
|
||||
// As per RFC 4191 section 2.1,
|
||||
//
|
||||
// Default router preferences and preferences for more-specific routes
|
||||
// are encoded the same way.
|
||||
//
|
||||
// Preference values are encoded as a two-bit signed integer, as
|
||||
// follows:
|
||||
//
|
||||
// 01 High
|
||||
// 00 Medium (default)
|
||||
// 11 Low
|
||||
// 10 Reserved - MUST NOT be sent
|
||||
//
|
||||
// Note that implementations can treat the value as a two-bit signed
|
||||
// integer.
|
||||
//
|
||||
// Having just three values reinforces that they are not metrics and
|
||||
// more values do not appear to be necessary for reasonable scenarios.
|
||||
type NDPRoutePreference uint8
|
||||
|
||||
const (
|
||||
// HighRoutePreference indicates a high preference, as per
|
||||
// RFC 4191 section 2.1.
|
||||
HighRoutePreference NDPRoutePreference = 0b01
|
||||
|
||||
// MediumRoutePreference indicates a medium preference, as per
|
||||
// RFC 4191 section 2.1.
|
||||
//
|
||||
// This is the default preference value.
|
||||
MediumRoutePreference = 0b00
|
||||
|
||||
// LowRoutePreference indicates a low preference, as per
|
||||
// RFC 4191 section 2.1.
|
||||
LowRoutePreference = 0b11
|
||||
|
||||
// ReservedRoutePreference is a reserved preference value, as per
|
||||
// RFC 4191 section 2.1.
|
||||
//
|
||||
// It MUST NOT be sent.
|
||||
ReservedRoutePreference = 0b10
|
||||
)
|
||||
|
||||
// String implements fmt.Stringer.
|
||||
func (p NDPRoutePreference) String() string {
|
||||
switch p {
|
||||
case HighRoutePreference:
|
||||
return "HighRoutePreference"
|
||||
case MediumRoutePreference:
|
||||
return "MediumRoutePreference"
|
||||
case LowRoutePreference:
|
||||
return "LowRoutePreference"
|
||||
case ReservedRoutePreference:
|
||||
return "ReservedRoutePreference"
|
||||
default:
|
||||
return fmt.Sprintf("NDPRoutePreference(%d)", p)
|
||||
}
|
||||
}
|
||||
|
||||
// NDPRouterAdvert is an NDP Router Advertisement message. It will only contain
|
||||
// the body of an ICMPv6 packet.
|
||||
//
|
||||
// See RFC 4861 section 4.2 and RFC 4191 section 2.2 for more details.
|
||||
type NDPRouterAdvert []byte
|
||||
|
||||
// As per RFC 4191 section 2.2,
|
||||
//
|
||||
// 0 1 2 3
|
||||
// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
|
||||
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
// | Type | Code | Checksum |
|
||||
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
// | Cur Hop Limit |M|O|H|Prf|Resvd| Router Lifetime |
|
||||
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
// | Reachable Time |
|
||||
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
// | Retrans Timer |
|
||||
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|
||||
// | Options ...
|
||||
// +-+-+-+-+-+-+-+-+-+-+-+-
|
||||
const (
|
||||
// NDPRAMinimumSize is the minimum size of a valid NDP Router
|
||||
// Advertisement message (body of an ICMPv6 packet).
|
||||
NDPRAMinimumSize = 12
|
||||
|
||||
// ndpRACurrHopLimitOffset is the byte of the Curr Hop Limit field
|
||||
// within an NDPRouterAdvert.
|
||||
ndpRACurrHopLimitOffset = 0
|
||||
|
||||
// ndpRAFlagsOffset is the byte with the NDP RA bit-fields/flags
|
||||
// within an NDPRouterAdvert.
|
||||
ndpRAFlagsOffset = 1
|
||||
|
||||
// ndpRAManagedAddrConfFlagMask is the mask of the Managed Address
|
||||
// Configuration flag within the bit-field/flags byte of an
|
||||
// NDPRouterAdvert.
|
||||
ndpRAManagedAddrConfFlagMask = (1 << 7)
|
||||
|
||||
// ndpRAOtherConfFlagMask is the mask of the Other Configuration flag
|
||||
// within the bit-field/flags byte of an NDPRouterAdvert.
|
||||
ndpRAOtherConfFlagMask = (1 << 6)
|
||||
|
||||
// ndpDefaultRouterPreferenceShift is the shift of the Prf (Default Router
|
||||
// Preference) field within the flags byte of an NDPRouterAdvert.
|
||||
ndpDefaultRouterPreferenceShift = 3
|
||||
|
||||
// ndpDefaultRouterPreferenceMask is the mask of the Prf (Default Router
|
||||
// Preference) field within the flags byte of an NDPRouterAdvert.
|
||||
ndpDefaultRouterPreferenceMask = (0b11 << ndpDefaultRouterPreferenceShift)
|
||||
|
||||
// ndpRARouterLifetimeOffset is the start of the 2-byte Router Lifetime
|
||||
// field within an NDPRouterAdvert.
|
||||
ndpRARouterLifetimeOffset = 2
|
||||
|
||||
// ndpRAReachableTimeOffset is the start of the 4-byte Reachable Time
|
||||
// field within an NDPRouterAdvert.
|
||||
ndpRAReachableTimeOffset = 4
|
||||
|
||||
// ndpRARetransTimerOffset is the start of the 4-byte Retrans Timer
|
||||
// field within an NDPRouterAdvert.
|
||||
ndpRARetransTimerOffset = 8
|
||||
|
||||
// ndpRAOptionsOffset is the start of the NDP options in an
|
||||
// NDPRouterAdvert.
|
||||
ndpRAOptionsOffset = 12
|
||||
)
|
||||
|
||||
// CurrHopLimit returns the value of the Curr Hop Limit field.
|
||||
func (b NDPRouterAdvert) CurrHopLimit() uint8 {
|
||||
return b[ndpRACurrHopLimitOffset]
|
||||
}
|
||||
|
||||
// ManagedAddrConfFlag returns the value of the Managed Address Configuration
|
||||
// flag.
|
||||
func (b NDPRouterAdvert) ManagedAddrConfFlag() bool {
|
||||
return b[ndpRAFlagsOffset]&ndpRAManagedAddrConfFlagMask != 0
|
||||
}
|
||||
|
||||
// OtherConfFlag returns the value of the Other Configuration flag.
|
||||
func (b NDPRouterAdvert) OtherConfFlag() bool {
|
||||
return b[ndpRAFlagsOffset]&ndpRAOtherConfFlagMask != 0
|
||||
}
|
||||
|
||||
// DefaultRouterPreference returns the Default Router Preference field.
|
||||
func (b NDPRouterAdvert) DefaultRouterPreference() NDPRoutePreference {
|
||||
return NDPRoutePreference((b[ndpRAFlagsOffset] & ndpDefaultRouterPreferenceMask) >> ndpDefaultRouterPreferenceShift)
|
||||
}
|
||||
|
||||
// RouterLifetime returns the lifetime associated with the default router. A
|
||||
// value of 0 means the source of the Router Advertisement is not a default
|
||||
// router and SHOULD NOT appear on the default router list. Note, a value of 0
|
||||
// only means that the router should not be used as a default router, it does
|
||||
// not apply to other information contained in the Router Advertisement.
|
||||
func (b NDPRouterAdvert) RouterLifetime() time.Duration {
|
||||
// The field is the time in seconds, as per RFC 4861 section 4.2.
|
||||
return time.Second * time.Duration(binary.BigEndian.Uint16(b[ndpRARouterLifetimeOffset:]))
|
||||
}
|
||||
|
||||
// ReachableTime returns the time that a node assumes a neighbor is reachable
|
||||
// after having received a reachability confirmation. A value of 0 means
|
||||
// that it is unspecified by the source of the Router Advertisement message.
|
||||
func (b NDPRouterAdvert) ReachableTime() time.Duration {
|
||||
// The field is the time in milliseconds, as per RFC 4861 section 4.2.
|
||||
return time.Millisecond * time.Duration(binary.BigEndian.Uint32(b[ndpRAReachableTimeOffset:]))
|
||||
}
|
||||
|
||||
// RetransTimer returns the time between retransmitted Neighbor Solicitation
|
||||
// messages. A value of 0 means that it is unspecified by the source of the
|
||||
// Router Advertisement message.
|
||||
func (b NDPRouterAdvert) RetransTimer() time.Duration {
|
||||
// The field is the time in milliseconds, as per RFC 4861 section 4.2.
|
||||
return time.Millisecond * time.Duration(binary.BigEndian.Uint32(b[ndpRARetransTimerOffset:]))
|
||||
}
|
||||
|
||||
// Options returns an NDPOptions of the options body.
|
||||
func (b NDPRouterAdvert) Options() NDPOptions {
|
||||
return NDPOptions(b[ndpRAOptionsOffset:])
|
||||
}
|
||||
@@ -1,36 +0,0 @@
|
||||
// Copyright 2019 The gVisor Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package header
|
||||
|
||||
// NDPRouterSolicit is an NDP Router Solicitation message. It will only contain
|
||||
// the body of an ICMPv6 packet.
|
||||
//
|
||||
// See RFC 4861 section 4.1 for more details.
|
||||
type NDPRouterSolicit []byte
|
||||
|
||||
const (
|
||||
// NDPRSMinimumSize is the minimum size of a valid NDP Router
|
||||
// Solicitation message (body of an ICMPv6 packet).
|
||||
NDPRSMinimumSize = 4
|
||||
|
||||
// ndpRSOptionsOffset is the start of the NDP options in an
|
||||
// NDPRouterSolicit.
|
||||
ndpRSOptionsOffset = 4
|
||||
)
|
||||
|
||||
// Options returns an NDPOptions of the options body.
|
||||
func (b NDPRouterSolicit) Options() NDPOptions {
|
||||
return NDPOptions(b[ndpRSOptionsOffset:])
|
||||
}
|
||||
@@ -1,58 +0,0 @@
|
||||
// Copyright 2020 The gVisor Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Code generated by "stringer -type ndpOptionIdentifier"; DO NOT EDIT.
|
||||
|
||||
package header
|
||||
|
||||
import "strconv"
|
||||
|
||||
func _() {
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
var x [1]struct{}
|
||||
_ = x[ndpSourceLinkLayerAddressOptionType-1]
|
||||
_ = x[ndpTargetLinkLayerAddressOptionType-2]
|
||||
_ = x[ndpPrefixInformationType-3]
|
||||
_ = x[ndpNonceOptionType-14]
|
||||
_ = x[ndpRecursiveDNSServerOptionType-25]
|
||||
_ = x[ndpDNSSearchListOptionType-31]
|
||||
}
|
||||
|
||||
const (
|
||||
_ndpOptionIdentifier_name_0 = "ndpSourceLinkLayerAddressOptionTypendpTargetLinkLayerAddressOptionTypendpPrefixInformationType"
|
||||
_ndpOptionIdentifier_name_1 = "ndpNonceOptionType"
|
||||
_ndpOptionIdentifier_name_2 = "ndpRecursiveDNSServerOptionType"
|
||||
_ndpOptionIdentifier_name_3 = "ndpDNSSearchListOptionType"
|
||||
)
|
||||
|
||||
var (
|
||||
_ndpOptionIdentifier_index_0 = [...]uint8{0, 35, 70, 94}
|
||||
)
|
||||
|
||||
func (i ndpOptionIdentifier) String() string {
|
||||
switch {
|
||||
case 1 <= i && i <= 3:
|
||||
i -= 1
|
||||
return _ndpOptionIdentifier_name_0[_ndpOptionIdentifier_index_0[i]:_ndpOptionIdentifier_index_0[i+1]]
|
||||
case i == 14:
|
||||
return _ndpOptionIdentifier_name_1
|
||||
case i == 25:
|
||||
return _ndpOptionIdentifier_name_2
|
||||
case i == 31:
|
||||
return _ndpOptionIdentifier_name_3
|
||||
default:
|
||||
return "ndpOptionIdentifier(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
package header
|
||||
|
||||
import "net/netip"
|
||||
|
||||
func (b IPv4) SourceAddr() netip.Addr {
|
||||
return netip.AddrFrom4([4]byte(b[srcAddr : srcAddr+IPv4AddressSize]))
|
||||
}
|
||||
|
||||
func (b IPv4) DestinationAddr() netip.Addr {
|
||||
return netip.AddrFrom4([4]byte(b[dstAddr : dstAddr+IPv4AddressSize]))
|
||||
}
|
||||
|
||||
func (b IPv4) SetSourceAddr(addr netip.Addr) {
|
||||
copy(b[srcAddr:srcAddr+IPv4AddressSize], addr.AsSlice())
|
||||
}
|
||||
|
||||
func (b IPv4) SetDestinationAddr(addr netip.Addr) {
|
||||
copy(b[dstAddr:dstAddr+IPv4AddressSize], addr.AsSlice())
|
||||
}
|
||||
|
||||
func (b IPv6) SourceAddr() netip.Addr {
|
||||
return netip.AddrFrom16([16]byte(b[v6SrcAddr:][:IPv6AddressSize]))
|
||||
}
|
||||
|
||||
func (b IPv6) DestinationAddr() netip.Addr {
|
||||
return netip.AddrFrom16([16]byte(b[v6DstAddr:][:IPv6AddressSize]))
|
||||
}
|
||||
|
||||
func (b IPv6) SetSourceAddr(addr netip.Addr) {
|
||||
copy(b[v6SrcAddr:][:IPv6AddressSize], addr.AsSlice())
|
||||
}
|
||||
|
||||
func (b IPv6) SetDestinationAddr(addr netip.Addr) {
|
||||
copy(b[v6DstAddr:][:IPv6AddressSize], addr.AsSlice())
|
||||
}
|
||||
@@ -1,727 +0,0 @@
|
||||
// Copyright 2018 The gVisor Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package header
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip"
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip/seqnum"
|
||||
|
||||
"github.com/google/btree"
|
||||
)
|
||||
|
||||
// These constants are the offsets of the respective fields in the TCP header.
|
||||
const (
|
||||
TCPSrcPortOffset = 0
|
||||
TCPDstPortOffset = 2
|
||||
TCPSeqNumOffset = 4
|
||||
TCPAckNumOffset = 8
|
||||
TCPDataOffset = 12
|
||||
TCPFlagsOffset = 13
|
||||
TCPWinSizeOffset = 14
|
||||
TCPChecksumOffset = 16
|
||||
TCPUrgentPtrOffset = 18
|
||||
)
|
||||
|
||||
const (
|
||||
// MaxWndScale is maximum allowed window scaling, as described in
|
||||
// RFC 1323, section 2.3, page 11.
|
||||
MaxWndScale = 14
|
||||
|
||||
// TCPMaxSACKBlocks is the maximum number of SACK blocks that can
|
||||
// be encoded in a TCP option field.
|
||||
TCPMaxSACKBlocks = 4
|
||||
)
|
||||
|
||||
// TCPFlags is the dedicated type for TCP flags.
|
||||
type TCPFlags uint8
|
||||
|
||||
// Intersects returns true iff there are flags common to both f and o.
|
||||
func (f TCPFlags) Intersects(o TCPFlags) bool {
|
||||
return f&o != 0
|
||||
}
|
||||
|
||||
// Contains returns true iff all the flags in o are contained within f.
|
||||
func (f TCPFlags) Contains(o TCPFlags) bool {
|
||||
return f&o == o
|
||||
}
|
||||
|
||||
// String implements Stringer.String.
|
||||
func (f TCPFlags) String() string {
|
||||
flagsStr := []byte("FSRPAUEC")
|
||||
for i := range flagsStr {
|
||||
if f&(1<<uint(i)) == 0 {
|
||||
flagsStr[i] = ' '
|
||||
}
|
||||
}
|
||||
return string(flagsStr)
|
||||
}
|
||||
|
||||
// Flags that may be set in a TCP segment.
|
||||
const (
|
||||
TCPFlagFin TCPFlags = 1 << iota
|
||||
TCPFlagSyn
|
||||
TCPFlagRst
|
||||
TCPFlagPsh
|
||||
TCPFlagAck
|
||||
TCPFlagUrg
|
||||
TCPFlagEce
|
||||
TCPFlagCwr
|
||||
)
|
||||
|
||||
// Options that may be present in a TCP segment.
|
||||
const (
|
||||
TCPOptionEOL = 0
|
||||
TCPOptionNOP = 1
|
||||
TCPOptionMSS = 2
|
||||
TCPOptionWS = 3
|
||||
TCPOptionTS = 8
|
||||
TCPOptionSACKPermitted = 4
|
||||
TCPOptionSACK = 5
|
||||
)
|
||||
|
||||
// Option Lengths.
|
||||
const (
|
||||
TCPOptionMSSLength = 4
|
||||
TCPOptionTSLength = 10
|
||||
TCPOptionWSLength = 3
|
||||
TCPOptionSackPermittedLength = 2
|
||||
)
|
||||
|
||||
// TCPFields contains the fields of a TCP packet. It is used to describe the
|
||||
// fields of a packet that needs to be encoded.
|
||||
type TCPFields struct {
|
||||
// SrcPort is the "source port" field of a TCP packet.
|
||||
SrcPort uint16
|
||||
|
||||
// DstPort is the "destination port" field of a TCP packet.
|
||||
DstPort uint16
|
||||
|
||||
// SeqNum is the "sequence number" field of a TCP packet.
|
||||
SeqNum uint32
|
||||
|
||||
// AckNum is the "acknowledgement number" field of a TCP packet.
|
||||
AckNum uint32
|
||||
|
||||
// DataOffset is the "data offset" field of a TCP packet. It is the length of
|
||||
// the TCP header in bytes.
|
||||
DataOffset uint8
|
||||
|
||||
// Flags is the "flags" field of a TCP packet.
|
||||
Flags TCPFlags
|
||||
|
||||
// WindowSize is the "window size" field of a TCP packet.
|
||||
WindowSize uint16
|
||||
|
||||
// Checksum is the "checksum" field of a TCP packet.
|
||||
Checksum uint16
|
||||
|
||||
// UrgentPointer is the "urgent pointer" field of a TCP packet.
|
||||
UrgentPointer uint16
|
||||
}
|
||||
|
||||
// TCPSynOptions is used to return the parsed TCP Options in a syn
|
||||
// segment.
|
||||
//
|
||||
// +stateify savable
|
||||
type TCPSynOptions struct {
|
||||
// MSS is the maximum segment size provided by the peer in the SYN.
|
||||
MSS uint16
|
||||
|
||||
// WS is the window scale option provided by the peer in the SYN.
|
||||
//
|
||||
// Set to -1 if no window scale option was provided.
|
||||
WS int
|
||||
|
||||
// TS is true if the timestamp option was provided in the syn/syn-ack.
|
||||
TS bool
|
||||
|
||||
// TSVal is the value of the TSVal field in the timestamp option.
|
||||
TSVal uint32
|
||||
|
||||
// TSEcr is the value of the TSEcr field in the timestamp option.
|
||||
TSEcr uint32
|
||||
|
||||
// SACKPermitted is true if the SACK option was provided in the SYN/SYN-ACK.
|
||||
SACKPermitted bool
|
||||
|
||||
// Flags if specified are set on the outgoing SYN. The SYN flag is
|
||||
// always set.
|
||||
Flags TCPFlags
|
||||
}
|
||||
|
||||
// SACKBlock represents a single contiguous SACK block.
|
||||
//
|
||||
// +stateify savable
|
||||
type SACKBlock struct {
|
||||
// Start indicates the lowest sequence number in the block.
|
||||
Start seqnum.Value
|
||||
|
||||
// End indicates the sequence number immediately following the last
|
||||
// sequence number of this block.
|
||||
End seqnum.Value
|
||||
}
|
||||
|
||||
// Less returns true if r.Start < b.Start.
|
||||
func (r SACKBlock) Less(b btree.Item) bool {
|
||||
return r.Start.LessThan(b.(SACKBlock).Start)
|
||||
}
|
||||
|
||||
// Contains returns true if b is completely contained in r.
|
||||
func (r SACKBlock) Contains(b SACKBlock) bool {
|
||||
return r.Start.LessThanEq(b.Start) && b.End.LessThanEq(r.End)
|
||||
}
|
||||
|
||||
// TCPOptions are used to parse and cache the TCP segment options for a non
|
||||
// syn/syn-ack segment.
|
||||
//
|
||||
// +stateify savable
|
||||
type TCPOptions struct {
|
||||
// TS is true if the TimeStamp option is enabled.
|
||||
TS bool
|
||||
|
||||
// TSVal is the value in the TSVal field of the segment.
|
||||
TSVal uint32
|
||||
|
||||
// TSEcr is the value in the TSEcr field of the segment.
|
||||
TSEcr uint32
|
||||
|
||||
// SACKBlocks are the SACK blocks specified in the segment.
|
||||
SACKBlocks []SACKBlock
|
||||
}
|
||||
|
||||
// TCP represents a TCP header stored in a byte array.
|
||||
type TCP []byte
|
||||
|
||||
const (
|
||||
// TCPMinimumSize is the minimum size of a valid TCP packet.
|
||||
TCPMinimumSize = 20
|
||||
|
||||
// TCPOptionsMaximumSize is the maximum size of TCP options.
|
||||
TCPOptionsMaximumSize = 40
|
||||
|
||||
// TCPHeaderMaximumSize is the maximum header size of a TCP packet.
|
||||
TCPHeaderMaximumSize = TCPMinimumSize + TCPOptionsMaximumSize
|
||||
|
||||
// TCPTotalHeaderMaximumSize is the maximum size of headers from all layers in
|
||||
// a TCP packet. It analogous to MAX_TCP_HEADER in Linux.
|
||||
//
|
||||
// TODO(b/319936470): Investigate why this needs to be at least 140 bytes. In
|
||||
// Linux this value is at least 160, but in theory we should be able to use
|
||||
// 138. In practice anything less than 140 starts to break GSO on gVNIC
|
||||
// hardware.
|
||||
TCPTotalHeaderMaximumSize = 160
|
||||
|
||||
// TCPProtocolNumber is TCP's transport protocol number.
|
||||
TCPProtocolNumber tcpip.TransportProtocolNumber = 6
|
||||
|
||||
// TCPMinimumMSS is the minimum acceptable value for MSS. This is the
|
||||
// same as the value TCP_MIN_MSS defined net/tcp.h.
|
||||
TCPMinimumMSS = IPv4MaximumHeaderSize + TCPHeaderMaximumSize + MinIPFragmentPayloadSize - IPv4MinimumSize - TCPMinimumSize
|
||||
|
||||
// TCPMinimumSendMSS is the minimum value for MSS in a sender. This is the
|
||||
// same as the value TCP_MIN_SND_MSS in net/tcp.h.
|
||||
TCPMinimumSendMSS = TCPOptionsMaximumSize + MinIPFragmentPayloadSize
|
||||
|
||||
// TCPMaximumMSS is the maximum acceptable value for MSS.
|
||||
TCPMaximumMSS = 0xffff
|
||||
|
||||
// TCPDefaultMSS is the MSS value that should be used if an MSS option
|
||||
// is not received from the peer. It's also the value returned by
|
||||
// TCP_MAXSEG option for a socket in an unconnected state.
|
||||
//
|
||||
// Per RFC 1122, page 85: "If an MSS option is not received at
|
||||
// connection setup, TCP MUST assume a default send MSS of 536."
|
||||
TCPDefaultMSS = 536
|
||||
)
|
||||
|
||||
// SourcePort returns the "source port" field of the TCP header.
|
||||
func (b TCP) SourcePort() uint16 {
|
||||
return binary.BigEndian.Uint16(b[TCPSrcPortOffset:])
|
||||
}
|
||||
|
||||
// DestinationPort returns the "destination port" field of the TCP header.
|
||||
func (b TCP) DestinationPort() uint16 {
|
||||
return binary.BigEndian.Uint16(b[TCPDstPortOffset:])
|
||||
}
|
||||
|
||||
// SequenceNumber returns the "sequence number" field of the TCP header.
|
||||
func (b TCP) SequenceNumber() uint32 {
|
||||
return binary.BigEndian.Uint32(b[TCPSeqNumOffset:])
|
||||
}
|
||||
|
||||
// AckNumber returns the "ack number" field of the TCP header.
|
||||
func (b TCP) AckNumber() uint32 {
|
||||
return binary.BigEndian.Uint32(b[TCPAckNumOffset:])
|
||||
}
|
||||
|
||||
// DataOffset returns the "data offset" field of the TCP header. The return
|
||||
// value is the length of the TCP header in bytes.
|
||||
func (b TCP) DataOffset() uint8 {
|
||||
return (b[TCPDataOffset] >> 4) * 4
|
||||
}
|
||||
|
||||
// Payload returns the data in the TCP packet.
|
||||
func (b TCP) Payload() []byte {
|
||||
return b[b.DataOffset():]
|
||||
}
|
||||
|
||||
// Flags returns the flags field of the TCP header.
|
||||
func (b TCP) Flags() TCPFlags {
|
||||
return TCPFlags(b[TCPFlagsOffset])
|
||||
}
|
||||
|
||||
// WindowSize returns the "window size" field of the TCP header.
|
||||
func (b TCP) WindowSize() uint16 {
|
||||
return binary.BigEndian.Uint16(b[TCPWinSizeOffset:])
|
||||
}
|
||||
|
||||
// Checksum returns the "checksum" field of the TCP header.
|
||||
func (b TCP) Checksum() uint16 {
|
||||
return binary.BigEndian.Uint16(b[TCPChecksumOffset:])
|
||||
}
|
||||
|
||||
// UrgentPointer returns the "urgent pointer" field of the TCP header.
|
||||
func (b TCP) UrgentPointer() uint16 {
|
||||
return binary.BigEndian.Uint16(b[TCPUrgentPtrOffset:])
|
||||
}
|
||||
|
||||
// SetSourcePort sets the "source port" field of the TCP header.
|
||||
func (b TCP) SetSourcePort(port uint16) {
|
||||
binary.BigEndian.PutUint16(b[TCPSrcPortOffset:], port)
|
||||
}
|
||||
|
||||
// SetDestinationPort sets the "destination port" field of the TCP header.
|
||||
func (b TCP) SetDestinationPort(port uint16) {
|
||||
binary.BigEndian.PutUint16(b[TCPDstPortOffset:], port)
|
||||
}
|
||||
|
||||
// SetChecksum sets the checksum field of the TCP header.
|
||||
func (b TCP) SetChecksum(xsum uint16) {
|
||||
checksum.Put(b[TCPChecksumOffset:], xsum)
|
||||
}
|
||||
|
||||
// SetDataOffset sets the data offset field of the TCP header. headerLen should
|
||||
// be the length of the TCP header in bytes.
|
||||
func (b TCP) SetDataOffset(headerLen uint8) {
|
||||
b[TCPDataOffset] = (headerLen / 4) << 4
|
||||
}
|
||||
|
||||
// SetSequenceNumber sets the sequence number field of the TCP header.
|
||||
func (b TCP) SetSequenceNumber(seqNum uint32) {
|
||||
binary.BigEndian.PutUint32(b[TCPSeqNumOffset:], seqNum)
|
||||
}
|
||||
|
||||
// SetAckNumber sets the ack number field of the TCP header.
|
||||
func (b TCP) SetAckNumber(ackNum uint32) {
|
||||
binary.BigEndian.PutUint32(b[TCPAckNumOffset:], ackNum)
|
||||
}
|
||||
|
||||
// SetFlags sets the flags field of the TCP header.
|
||||
func (b TCP) SetFlags(flags uint8) {
|
||||
b[TCPFlagsOffset] = flags
|
||||
}
|
||||
|
||||
// SetWindowSize sets the window size field of the TCP header.
|
||||
func (b TCP) SetWindowSize(rcvwnd uint16) {
|
||||
binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd)
|
||||
}
|
||||
|
||||
// SetUrgentPointer sets the window size field of the TCP header.
|
||||
func (b TCP) SetUrgentPointer(urgentPointer uint16) {
|
||||
binary.BigEndian.PutUint16(b[TCPUrgentPtrOffset:], urgentPointer)
|
||||
}
|
||||
|
||||
// CalculateChecksum calculates the checksum of the TCP segment.
|
||||
// partialChecksum is the checksum of the network-layer pseudo-header
|
||||
// and the checksum of the segment data.
|
||||
func (b TCP) CalculateChecksum(partialChecksum uint16) uint16 {
|
||||
// Calculate the rest of the checksum.
|
||||
return checksum.Checksum(b[:b.DataOffset()], partialChecksum)
|
||||
}
|
||||
|
||||
// IsChecksumValid returns true iff the TCP header's checksum is valid.
|
||||
func (b TCP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum, payloadLength uint16) bool {
|
||||
xsum := PseudoHeaderChecksum(TCPProtocolNumber, src.AsSlice(), dst.AsSlice(), uint16(b.DataOffset())+payloadLength)
|
||||
xsum = checksum.Combine(xsum, payloadChecksum)
|
||||
return b.CalculateChecksum(xsum) == 0xffff
|
||||
}
|
||||
|
||||
// Options returns a slice that holds the unparsed TCP options in the segment.
|
||||
func (b TCP) Options() []byte {
|
||||
return b[TCPMinimumSize:b.DataOffset()]
|
||||
}
|
||||
|
||||
// ParsedOptions returns a TCPOptions structure which parses and caches the TCP
|
||||
// option values in the TCP segment. NOTE: Invoking this function repeatedly is
|
||||
// expensive as it reparses the options on each invocation.
|
||||
func (b TCP) ParsedOptions() TCPOptions {
|
||||
return ParseTCPOptions(b.Options())
|
||||
}
|
||||
|
||||
func (b TCP) encodeSubset(seq, ack uint32, flags TCPFlags, rcvwnd uint16) {
|
||||
binary.BigEndian.PutUint32(b[TCPSeqNumOffset:], seq)
|
||||
binary.BigEndian.PutUint32(b[TCPAckNumOffset:], ack)
|
||||
b[TCPFlagsOffset] = uint8(flags)
|
||||
binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd)
|
||||
}
|
||||
|
||||
// Encode encodes all the fields of the TCP header.
|
||||
func (b TCP) Encode(t *TCPFields) {
|
||||
b.encodeSubset(t.SeqNum, t.AckNum, t.Flags, t.WindowSize)
|
||||
b.SetSourcePort(t.SrcPort)
|
||||
b.SetDestinationPort(t.DstPort)
|
||||
b.SetDataOffset(t.DataOffset)
|
||||
b.SetChecksum(t.Checksum)
|
||||
b.SetUrgentPointer(t.UrgentPointer)
|
||||
}
|
||||
|
||||
// EncodePartial updates a subset of the fields of the TCP header. It is useful
|
||||
// in cases when similar segments are produced.
|
||||
func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32, flags TCPFlags, rcvwnd uint16) {
|
||||
// Add the total length and "flags" field contributions to the checksum.
|
||||
// We don't use the flags field directly from the header because it's a
|
||||
// one-byte field with an odd offset, so it would be accounted for
|
||||
// incorrectly by the Checksum routine.
|
||||
tmp := make([]byte, 4)
|
||||
binary.BigEndian.PutUint16(tmp, length)
|
||||
binary.BigEndian.PutUint16(tmp[2:], uint16(flags))
|
||||
xsum := checksum.Checksum(tmp, partialChecksum)
|
||||
|
||||
// Encode the passed-in fields.
|
||||
b.encodeSubset(seqnum, acknum, flags, rcvwnd)
|
||||
|
||||
// Add the contributions of the passed-in fields to the checksum.
|
||||
xsum = checksum.Checksum(b[TCPSeqNumOffset:TCPSeqNumOffset+8], xsum)
|
||||
xsum = checksum.Checksum(b[TCPWinSizeOffset:TCPWinSizeOffset+2], xsum)
|
||||
|
||||
// Encode the checksum.
|
||||
b.SetChecksum(^xsum)
|
||||
}
|
||||
|
||||
// SetSourcePortWithChecksumUpdate implements ChecksummableTransport.
|
||||
func (b TCP) SetSourcePortWithChecksumUpdate(new uint16) {
|
||||
old := b.SourcePort()
|
||||
b.SetSourcePort(new)
|
||||
b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
|
||||
}
|
||||
|
||||
// SetDestinationPortWithChecksumUpdate implements ChecksummableTransport.
|
||||
func (b TCP) SetDestinationPortWithChecksumUpdate(new uint16) {
|
||||
old := b.DestinationPort()
|
||||
b.SetDestinationPort(new)
|
||||
b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
|
||||
}
|
||||
|
||||
// UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport.
|
||||
func (b TCP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) {
|
||||
xsum := b.Checksum()
|
||||
if fullChecksum {
|
||||
xsum = ^xsum
|
||||
}
|
||||
|
||||
xsum = checksumUpdate2ByteAlignedAddress(xsum, old, new)
|
||||
if fullChecksum {
|
||||
xsum = ^xsum
|
||||
}
|
||||
|
||||
b.SetChecksum(xsum)
|
||||
}
|
||||
|
||||
// ParseSynOptions parses the options received in a SYN segment and returns the
|
||||
// relevant ones. opts should point to the option part of the TCP header.
|
||||
func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions {
|
||||
limit := len(opts)
|
||||
|
||||
synOpts := TCPSynOptions{
|
||||
// Per RFC 1122, page 85: "If an MSS option is not received at
|
||||
// connection setup, TCP MUST assume a default send MSS of 536."
|
||||
MSS: TCPDefaultMSS,
|
||||
// If no window scale option is specified, WS in options is
|
||||
// returned as -1; this is because the absence of the option
|
||||
// indicates that the we cannot use window scaling on the
|
||||
// receive end either.
|
||||
WS: -1,
|
||||
}
|
||||
|
||||
for i := 0; i < limit; {
|
||||
switch opts[i] {
|
||||
case TCPOptionEOL:
|
||||
i = limit
|
||||
case TCPOptionNOP:
|
||||
i++
|
||||
case TCPOptionMSS:
|
||||
if i+4 > limit || opts[i+1] != 4 {
|
||||
return synOpts
|
||||
}
|
||||
mss := uint16(opts[i+2])<<8 | uint16(opts[i+3])
|
||||
if mss == 0 {
|
||||
return synOpts
|
||||
}
|
||||
synOpts.MSS = mss
|
||||
if mss < TCPMinimumSendMSS {
|
||||
synOpts.MSS = TCPMinimumSendMSS
|
||||
}
|
||||
i += 4
|
||||
|
||||
case TCPOptionWS:
|
||||
if i+3 > limit || opts[i+1] != 3 {
|
||||
return synOpts
|
||||
}
|
||||
ws := int(opts[i+2])
|
||||
if ws > MaxWndScale {
|
||||
ws = MaxWndScale
|
||||
}
|
||||
synOpts.WS = ws
|
||||
i += 3
|
||||
|
||||
case TCPOptionTS:
|
||||
if i+10 > limit || opts[i+1] != 10 {
|
||||
return synOpts
|
||||
}
|
||||
synOpts.TSVal = binary.BigEndian.Uint32(opts[i+2:])
|
||||
if isAck {
|
||||
// If the segment is a SYN-ACK then store the Timestamp Echo Reply
|
||||
// in the segment.
|
||||
synOpts.TSEcr = binary.BigEndian.Uint32(opts[i+6:])
|
||||
}
|
||||
synOpts.TS = true
|
||||
i += 10
|
||||
case TCPOptionSACKPermitted:
|
||||
if i+2 > limit || opts[i+1] != 2 {
|
||||
return synOpts
|
||||
}
|
||||
synOpts.SACKPermitted = true
|
||||
i += 2
|
||||
|
||||
default:
|
||||
// We don't recognize this option, just skip over it.
|
||||
if i+2 > limit {
|
||||
return synOpts
|
||||
}
|
||||
l := int(opts[i+1])
|
||||
// If the length is incorrect or if l+i overflows the
|
||||
// total options length then return false.
|
||||
if l < 2 || i+l > limit {
|
||||
return synOpts
|
||||
}
|
||||
i += l
|
||||
}
|
||||
}
|
||||
|
||||
return synOpts
|
||||
}
|
||||
|
||||
// ParseTCPOptions extracts and stores all known options in the provided byte
|
||||
// slice in a TCPOptions structure.
|
||||
func ParseTCPOptions(b []byte) TCPOptions {
|
||||
opts := TCPOptions{}
|
||||
limit := len(b)
|
||||
for i := 0; i < limit; {
|
||||
switch b[i] {
|
||||
case TCPOptionEOL:
|
||||
i = limit
|
||||
case TCPOptionNOP:
|
||||
i++
|
||||
case TCPOptionTS:
|
||||
if i+10 > limit || (b[i+1] != 10) {
|
||||
return opts
|
||||
}
|
||||
opts.TS = true
|
||||
opts.TSVal = binary.BigEndian.Uint32(b[i+2:])
|
||||
opts.TSEcr = binary.BigEndian.Uint32(b[i+6:])
|
||||
i += 10
|
||||
case TCPOptionSACK:
|
||||
if i+2 > limit {
|
||||
// Malformed SACK block, just return and stop parsing.
|
||||
return opts
|
||||
}
|
||||
sackOptionLen := int(b[i+1])
|
||||
if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 {
|
||||
// Malformed SACK block, just return and stop parsing.
|
||||
return opts
|
||||
}
|
||||
numBlocks := (sackOptionLen - 2) / 8
|
||||
opts.SACKBlocks = []SACKBlock{}
|
||||
for j := 0; j < numBlocks; j++ {
|
||||
start := binary.BigEndian.Uint32(b[i+2+j*8:])
|
||||
end := binary.BigEndian.Uint32(b[i+2+j*8+4:])
|
||||
opts.SACKBlocks = append(opts.SACKBlocks, SACKBlock{
|
||||
Start: seqnum.Value(start),
|
||||
End: seqnum.Value(end),
|
||||
})
|
||||
}
|
||||
i += sackOptionLen
|
||||
default:
|
||||
// We don't recognize this option, just skip over it.
|
||||
if i+2 > limit {
|
||||
return opts
|
||||
}
|
||||
l := int(b[i+1])
|
||||
// If the length is incorrect or if l+i overflows the
|
||||
// total options length then return false.
|
||||
if l < 2 || i+l > limit {
|
||||
return opts
|
||||
}
|
||||
i += l
|
||||
}
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
// EncodeMSSOption encodes the MSS TCP option with the provided MSS values in
|
||||
// the supplied buffer. If the provided buffer is not large enough then it just
|
||||
// returns without encoding anything. It returns the number of bytes written to
|
||||
// the provided buffer.
|
||||
func EncodeMSSOption(mss uint32, b []byte) int {
|
||||
if len(b) < TCPOptionMSSLength {
|
||||
return 0
|
||||
}
|
||||
b[0], b[1], b[2], b[3] = TCPOptionMSS, TCPOptionMSSLength, byte(mss>>8), byte(mss)
|
||||
return TCPOptionMSSLength
|
||||
}
|
||||
|
||||
// EncodeWSOption encodes the WS TCP option with the WS value in the
|
||||
// provided buffer. If the provided buffer is not large enough then it just
|
||||
// returns without encoding anything. It returns the number of bytes written to
|
||||
// the provided buffer.
|
||||
func EncodeWSOption(ws int, b []byte) int {
|
||||
if len(b) < TCPOptionWSLength {
|
||||
return 0
|
||||
}
|
||||
b[0], b[1], b[2] = TCPOptionWS, TCPOptionWSLength, uint8(ws)
|
||||
return int(b[1])
|
||||
}
|
||||
|
||||
// EncodeTSOption encodes the provided tsVal and tsEcr values as a TCP timestamp
|
||||
// option into the provided buffer. If the buffer is smaller than expected it
|
||||
// just returns without encoding anything. It returns the number of bytes
|
||||
// written to the provided buffer.
|
||||
func EncodeTSOption(tsVal, tsEcr uint32, b []byte) int {
|
||||
if len(b) < TCPOptionTSLength {
|
||||
return 0
|
||||
}
|
||||
b[0], b[1] = TCPOptionTS, TCPOptionTSLength
|
||||
binary.BigEndian.PutUint32(b[2:], tsVal)
|
||||
binary.BigEndian.PutUint32(b[6:], tsEcr)
|
||||
return int(b[1])
|
||||
}
|
||||
|
||||
// EncodeSACKPermittedOption encodes a SACKPermitted option into the provided
|
||||
// buffer. If the buffer is smaller than required it just returns without
|
||||
// encoding anything. It returns the number of bytes written to the provided
|
||||
// buffer.
|
||||
func EncodeSACKPermittedOption(b []byte) int {
|
||||
if len(b) < TCPOptionSackPermittedLength {
|
||||
return 0
|
||||
}
|
||||
|
||||
b[0], b[1] = TCPOptionSACKPermitted, TCPOptionSackPermittedLength
|
||||
return int(b[1])
|
||||
}
|
||||
|
||||
// EncodeSACKBlocks encodes the provided SACK blocks as a TCP SACK option block
|
||||
// in the provided slice. It tries to fit in as many blocks as possible based on
|
||||
// number of bytes available in the provided buffer. It returns the number of
|
||||
// bytes written to the provided buffer.
|
||||
func EncodeSACKBlocks(sackBlocks []SACKBlock, b []byte) int {
|
||||
if len(sackBlocks) == 0 {
|
||||
return 0
|
||||
}
|
||||
l := len(sackBlocks)
|
||||
if l > TCPMaxSACKBlocks {
|
||||
l = TCPMaxSACKBlocks
|
||||
}
|
||||
if ll := (len(b) - 2) / 8; ll < l {
|
||||
l = ll
|
||||
}
|
||||
if l == 0 {
|
||||
// There is not enough space in the provided buffer to add
|
||||
// any SACK blocks.
|
||||
return 0
|
||||
}
|
||||
b[0] = TCPOptionSACK
|
||||
b[1] = byte(l*8 + 2)
|
||||
for i := 0; i < l; i++ {
|
||||
binary.BigEndian.PutUint32(b[i*8+2:], uint32(sackBlocks[i].Start))
|
||||
binary.BigEndian.PutUint32(b[i*8+6:], uint32(sackBlocks[i].End))
|
||||
}
|
||||
return int(b[1])
|
||||
}
|
||||
|
||||
// EncodeNOP adds an explicit NOP to the option list.
|
||||
func EncodeNOP(b []byte) int {
|
||||
if len(b) == 0 {
|
||||
return 0
|
||||
}
|
||||
b[0] = TCPOptionNOP
|
||||
return 1
|
||||
}
|
||||
|
||||
// AddTCPOptionPadding adds the required number of TCPOptionNOP to quad align
|
||||
// the option buffer. It adds padding bytes after the offset specified and
|
||||
// returns the number of padding bytes added. The passed in options slice
|
||||
// must have space for the padding bytes.
|
||||
func AddTCPOptionPadding(options []byte, offset int) int {
|
||||
paddingToAdd := -offset & 3
|
||||
// Now add any padding bytes that might be required to quad align the
|
||||
// options.
|
||||
for i := offset; i < offset+paddingToAdd; i++ {
|
||||
options[i] = TCPOptionNOP
|
||||
}
|
||||
return paddingToAdd
|
||||
}
|
||||
|
||||
// Acceptable checks if a segment that starts at segSeq and has length segLen is
|
||||
// "acceptable" for arriving in a receive window that starts at rcvNxt and ends
|
||||
// before rcvAcc, according to the table on page 26 and 69 of RFC 793.
|
||||
func Acceptable(segSeq seqnum.Value, segLen seqnum.Size, rcvNxt, rcvAcc seqnum.Value) bool {
|
||||
if rcvNxt == rcvAcc {
|
||||
return segLen == 0 && segSeq == rcvNxt
|
||||
}
|
||||
if segLen == 0 {
|
||||
// rcvWnd is incremented by 1 because that is Linux's behavior despite the
|
||||
// RFC.
|
||||
return segSeq.InRange(rcvNxt, rcvAcc.Add(1))
|
||||
}
|
||||
// Page 70 of RFC 793 allows packets that can be made "acceptable" by trimming
|
||||
// the payload, so we'll accept any payload that overlaps the receive window.
|
||||
// segSeq < rcvAcc is more correct according to RFC, however, Linux does it
|
||||
// differently, it uses segSeq <= rcvAcc, we'd want to keep the same behavior
|
||||
// as Linux.
|
||||
return rcvNxt.LessThan(segSeq.Add(segLen)) && segSeq.LessThanEq(rcvAcc)
|
||||
}
|
||||
|
||||
// TCPValid returns true if the pkt has a valid TCP header. It checks whether:
|
||||
// - The data offset is too small.
|
||||
// - The data offset is too large.
|
||||
// - The checksum is invalid.
|
||||
//
|
||||
// TCPValid corresponds to net/netfilter/nf_conntrack_proto_tcp.c:tcp_error.
|
||||
func TCPValid(hdr TCP, payloadChecksum func() uint16, payloadSize uint16, srcAddr, dstAddr tcpip.Address, skipChecksumValidation bool) (csum uint16, csumValid, ok bool) {
|
||||
if offset := int(hdr.DataOffset()); offset < TCPMinimumSize || offset > len(hdr) {
|
||||
return
|
||||
}
|
||||
|
||||
if skipChecksumValidation {
|
||||
csumValid = true
|
||||
} else {
|
||||
csum = hdr.Checksum()
|
||||
csumValid = hdr.IsChecksumValid(srcAddr, dstAddr, payloadChecksum(), payloadSize)
|
||||
}
|
||||
return csum, csumValid, true
|
||||
}
|
||||
@@ -1,195 +0,0 @@
|
||||
// Copyright 2018 The gVisor Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package header
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"math"
|
||||
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip"
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
|
||||
)
|
||||
|
||||
const (
|
||||
udpSrcPort = 0
|
||||
udpDstPort = 2
|
||||
udpLength = 4
|
||||
udpChecksum = 6
|
||||
)
|
||||
|
||||
const (
|
||||
// UDPMaximumPacketSize is the largest possible UDP packet.
|
||||
UDPMaximumPacketSize = 0xffff
|
||||
)
|
||||
|
||||
// UDPFields contains the fields of a UDP packet. It is used to describe the
|
||||
// fields of a packet that needs to be encoded.
|
||||
type UDPFields struct {
|
||||
// SrcPort is the "source port" field of a UDP packet.
|
||||
SrcPort uint16
|
||||
|
||||
// DstPort is the "destination port" field of a UDP packet.
|
||||
DstPort uint16
|
||||
|
||||
// Length is the "length" field of a UDP packet.
|
||||
Length uint16
|
||||
|
||||
// Checksum is the "checksum" field of a UDP packet.
|
||||
Checksum uint16
|
||||
}
|
||||
|
||||
// UDP represents a UDP header stored in a byte array.
|
||||
type UDP []byte
|
||||
|
||||
const (
|
||||
// UDPMinimumSize is the minimum size of a valid UDP packet.
|
||||
UDPMinimumSize = 8
|
||||
|
||||
// UDPMaximumSize is the maximum size of a valid UDP packet. The length field
|
||||
// in the UDP header is 16 bits as per RFC 768.
|
||||
UDPMaximumSize = math.MaxUint16
|
||||
|
||||
// UDPProtocolNumber is UDP's transport protocol number.
|
||||
UDPProtocolNumber tcpip.TransportProtocolNumber = 17
|
||||
)
|
||||
|
||||
// SourcePort returns the "source port" field of the UDP header.
|
||||
func (b UDP) SourcePort() uint16 {
|
||||
return binary.BigEndian.Uint16(b[udpSrcPort:])
|
||||
}
|
||||
|
||||
// DestinationPort returns the "destination port" field of the UDP header.
|
||||
func (b UDP) DestinationPort() uint16 {
|
||||
return binary.BigEndian.Uint16(b[udpDstPort:])
|
||||
}
|
||||
|
||||
// Length returns the "length" field of the UDP header.
|
||||
func (b UDP) Length() uint16 {
|
||||
return binary.BigEndian.Uint16(b[udpLength:])
|
||||
}
|
||||
|
||||
// Payload returns the data contained in the UDP datagram.
|
||||
func (b UDP) Payload() []byte {
|
||||
return b[UDPMinimumSize:]
|
||||
}
|
||||
|
||||
// Checksum returns the "checksum" field of the UDP header.
|
||||
func (b UDP) Checksum() uint16 {
|
||||
return binary.BigEndian.Uint16(b[udpChecksum:])
|
||||
}
|
||||
|
||||
// SetSourcePort sets the "source port" field of the UDP header.
|
||||
func (b UDP) SetSourcePort(port uint16) {
|
||||
binary.BigEndian.PutUint16(b[udpSrcPort:], port)
|
||||
}
|
||||
|
||||
// SetDestinationPort sets the "destination port" field of the UDP header.
|
||||
func (b UDP) SetDestinationPort(port uint16) {
|
||||
binary.BigEndian.PutUint16(b[udpDstPort:], port)
|
||||
}
|
||||
|
||||
// SetChecksum sets the "checksum" field of the UDP header.
|
||||
func (b UDP) SetChecksum(xsum uint16) {
|
||||
checksum.Put(b[udpChecksum:], xsum)
|
||||
}
|
||||
|
||||
// SetLength sets the "length" field of the UDP header.
|
||||
func (b UDP) SetLength(length uint16) {
|
||||
binary.BigEndian.PutUint16(b[udpLength:], length)
|
||||
}
|
||||
|
||||
// CalculateChecksum calculates the checksum of the UDP packet, given the
|
||||
// checksum of the network-layer pseudo-header and the checksum of the payload.
|
||||
func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 {
|
||||
// Calculate the rest of the checksum.
|
||||
return checksum.Checksum(b[:UDPMinimumSize], partialChecksum)
|
||||
}
|
||||
|
||||
// IsChecksumValid returns true iff the UDP header's checksum is valid.
|
||||
func (b UDP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum uint16) bool {
|
||||
xsum := PseudoHeaderChecksum(UDPProtocolNumber, dst.AsSlice(), src.AsSlice(), b.Length())
|
||||
xsum = checksum.Combine(xsum, payloadChecksum)
|
||||
return b.CalculateChecksum(xsum) == 0xffff
|
||||
}
|
||||
|
||||
// Encode encodes all the fields of the UDP header.
|
||||
func (b UDP) Encode(u *UDPFields) {
|
||||
b.SetSourcePort(u.SrcPort)
|
||||
b.SetDestinationPort(u.DstPort)
|
||||
b.SetLength(u.Length)
|
||||
b.SetChecksum(u.Checksum)
|
||||
}
|
||||
|
||||
// SetSourcePortWithChecksumUpdate implements ChecksummableTransport.
|
||||
func (b UDP) SetSourcePortWithChecksumUpdate(new uint16) {
|
||||
old := b.SourcePort()
|
||||
b.SetSourcePort(new)
|
||||
b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
|
||||
}
|
||||
|
||||
// SetDestinationPortWithChecksumUpdate implements ChecksummableTransport.
|
||||
func (b UDP) SetDestinationPortWithChecksumUpdate(new uint16) {
|
||||
old := b.DestinationPort()
|
||||
b.SetDestinationPort(new)
|
||||
b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
|
||||
}
|
||||
|
||||
// UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport.
|
||||
func (b UDP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) {
|
||||
xsum := b.Checksum()
|
||||
if fullChecksum {
|
||||
xsum = ^xsum
|
||||
}
|
||||
|
||||
xsum = checksumUpdate2ByteAlignedAddress(xsum, old, new)
|
||||
if fullChecksum {
|
||||
xsum = ^xsum
|
||||
}
|
||||
|
||||
b.SetChecksum(xsum)
|
||||
}
|
||||
|
||||
// UDPValid returns true if the pkt has a valid UDP header. It checks whether:
|
||||
// - The length field is too small.
|
||||
// - The length field is too large.
|
||||
// - The checksum is invalid.
|
||||
//
|
||||
// UDPValid corresponds to net/netfilter/nf_conntrack_proto_udp.c:udp_error.
|
||||
func UDPValid(hdr UDP, payloadChecksum func() uint16, payloadSize uint16, netProto tcpip.NetworkProtocolNumber, srcAddr, dstAddr tcpip.Address, skipChecksumValidation bool) (lengthValid, csumValid bool) {
|
||||
if length := hdr.Length(); length > payloadSize+UDPMinimumSize || length < UDPMinimumSize {
|
||||
return false, false
|
||||
}
|
||||
|
||||
if skipChecksumValidation {
|
||||
return true, true
|
||||
}
|
||||
|
||||
// On IPv4, UDP checksum is optional, and a zero value means the transmitter
|
||||
// omitted the checksum generation, as per RFC 768:
|
||||
//
|
||||
// An all zero transmitted checksum value means that the transmitter
|
||||
// generated no checksum (for debugging or for higher level protocols that
|
||||
// don't care).
|
||||
//
|
||||
// On IPv6, UDP checksum is not optional, as per RFC 2460 Section 8.1:
|
||||
//
|
||||
// Unlike IPv4, when UDP packets are originated by an IPv6 node, the UDP
|
||||
// checksum is not optional.
|
||||
if netProto == IPv4ProtocolNumber && hdr.Checksum() == 0 {
|
||||
return true, true
|
||||
}
|
||||
|
||||
return true, hdr.IsChecksumValid(srcAddr, dstAddr, payloadChecksum())
|
||||
}
|
||||
@@ -1,62 +0,0 @@
|
||||
// Copyright 2018 The gVisor Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package seqnum defines the types and methods for TCP sequence numbers such
|
||||
// that they fit in 32-bit words and work properly when overflows occur.
|
||||
package seqnum
|
||||
|
||||
// Value represents the value of a sequence number.
|
||||
type Value uint32
|
||||
|
||||
// Size represents the size (length) of a sequence number window.
|
||||
type Size uint32
|
||||
|
||||
// LessThan checks if v is before w, i.e., v < w.
|
||||
func (v Value) LessThan(w Value) bool {
|
||||
return int32(v-w) < 0
|
||||
}
|
||||
|
||||
// LessThanEq returns true if v==w or v is before i.e., v < w.
|
||||
func (v Value) LessThanEq(w Value) bool {
|
||||
if v == w {
|
||||
return true
|
||||
}
|
||||
return v.LessThan(w)
|
||||
}
|
||||
|
||||
// InRange checks if v is in the range [a,b), i.e., a <= v < b.
|
||||
func (v Value) InRange(a, b Value) bool {
|
||||
return v-a < b-a
|
||||
}
|
||||
|
||||
// InWindow checks if v is in the window that starts at 'first' and spans 'size'
|
||||
// sequence numbers.
|
||||
func (v Value) InWindow(first Value, size Size) bool {
|
||||
return v.InRange(first, first.Add(size))
|
||||
}
|
||||
|
||||
// Add calculates the sequence number following the [v, v+s) window.
|
||||
func (v Value) Add(s Size) Value {
|
||||
return v + Value(s)
|
||||
}
|
||||
|
||||
// Size calculates the size of the window defined by [v, w).
|
||||
func (v Value) Size(w Value) Size {
|
||||
return Size(w - v)
|
||||
}
|
||||
|
||||
// UpdateForward updates v such that it becomes v + s.
|
||||
func (v *Value) UpdateForward(s Size) {
|
||||
*v += Value(s)
|
||||
}
|
||||
@@ -1,573 +0,0 @@
|
||||
// Copyright 2024 The gVisor Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tcpip
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/bits"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Using the header package here would cause an import cycle.
|
||||
const (
|
||||
ipv4AddressSize = 4
|
||||
ipv4ProtocolNumber = 0x0800
|
||||
ipv6AddressSize = 16
|
||||
ipv6ProtocolNumber = 0x86dd
|
||||
)
|
||||
|
||||
const (
|
||||
// LinkAddressSize is the size of a MAC address.
|
||||
LinkAddressSize = 6
|
||||
)
|
||||
|
||||
// Known IP address.
|
||||
var (
|
||||
IPv4Zero = []byte{0, 0, 0, 0}
|
||||
IPv6Zero = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
|
||||
)
|
||||
|
||||
// Errors related to Subnet
|
||||
var (
|
||||
errSubnetLengthMismatch = errors.New("subnet length of address and mask differ")
|
||||
errSubnetAddressMasked = errors.New("subnet address has bits set outside the mask")
|
||||
)
|
||||
|
||||
// TransportProtocolNumber is the number of a transport protocol.
|
||||
type TransportProtocolNumber uint32
|
||||
|
||||
// NetworkProtocolNumber is the EtherType of a network protocol in an Ethernet
|
||||
// frame.
|
||||
//
|
||||
// See: https://www.iana.org/assignments/ieee-802-numbers/ieee-802-numbers.xhtml
|
||||
type NetworkProtocolNumber uint32
|
||||
|
||||
// MonotonicTime is a monotonic clock reading.
|
||||
//
|
||||
// +stateify savable
|
||||
type MonotonicTime struct {
|
||||
nanoseconds int64
|
||||
}
|
||||
|
||||
// String implements Stringer.
|
||||
func (mt MonotonicTime) String() string {
|
||||
return strconv.FormatInt(mt.nanoseconds, 10)
|
||||
}
|
||||
|
||||
// MonotonicTimeInfinite returns the monotonic timestamp as far away in the
|
||||
// future as possible.
|
||||
func MonotonicTimeInfinite() MonotonicTime {
|
||||
return MonotonicTime{nanoseconds: math.MaxInt64}
|
||||
}
|
||||
|
||||
// Before reports whether the monotonic clock reading mt is before u.
|
||||
func (mt MonotonicTime) Before(u MonotonicTime) bool {
|
||||
return mt.nanoseconds < u.nanoseconds
|
||||
}
|
||||
|
||||
// After reports whether the monotonic clock reading mt is after u.
|
||||
func (mt MonotonicTime) After(u MonotonicTime) bool {
|
||||
return mt.nanoseconds > u.nanoseconds
|
||||
}
|
||||
|
||||
// Add returns the monotonic clock reading mt+d.
|
||||
func (mt MonotonicTime) Add(d time.Duration) MonotonicTime {
|
||||
return MonotonicTime{
|
||||
nanoseconds: time.Unix(0, mt.nanoseconds).Add(d).Sub(time.Unix(0, 0)).Nanoseconds(),
|
||||
}
|
||||
}
|
||||
|
||||
// Sub returns the duration mt-u. If the result exceeds the maximum (or minimum)
|
||||
// value that can be stored in a Duration, the maximum (or minimum) duration
|
||||
// will be returned. To compute t-d for a duration d, use t.Add(-d).
|
||||
func (mt MonotonicTime) Sub(u MonotonicTime) time.Duration {
|
||||
return time.Unix(0, mt.nanoseconds).Sub(time.Unix(0, u.nanoseconds))
|
||||
}
|
||||
|
||||
// Milliseconds returns the time in milliseconds.
|
||||
func (mt MonotonicTime) Milliseconds() int64 {
|
||||
return mt.nanoseconds / 1e6
|
||||
}
|
||||
|
||||
// A Clock provides the current time and schedules work for execution.
|
||||
//
|
||||
// Times returned by a Clock should always be used for application-visible
|
||||
// time. Only monotonic times should be used for netstack internal timekeeping.
|
||||
type Clock interface {
|
||||
// Now returns the current local time.
|
||||
Now() time.Time
|
||||
|
||||
// NowMonotonic returns the current monotonic clock reading.
|
||||
NowMonotonic() MonotonicTime
|
||||
|
||||
// AfterFunc waits for the duration to elapse and then calls f in its own
|
||||
// goroutine. It returns a Timer that can be used to cancel the call using
|
||||
// its Stop method.
|
||||
AfterFunc(d time.Duration, f func()) Timer
|
||||
}
|
||||
|
||||
// Timer represents a single event. A Timer must be created with
|
||||
// Clock.AfterFunc.
|
||||
type Timer interface {
|
||||
// Stop prevents the Timer from firing. It returns true if the call stops the
|
||||
// timer, false if the timer has already expired or been stopped.
|
||||
//
|
||||
// If Stop returns false, then the timer has already expired and the function
|
||||
// f of Clock.AfterFunc(d, f) has been started in its own goroutine; Stop
|
||||
// does not wait for f to complete before returning. If the caller needs to
|
||||
// know whether f is completed, it must coordinate with f explicitly.
|
||||
Stop() bool
|
||||
|
||||
// Reset changes the timer to expire after duration d.
|
||||
//
|
||||
// Reset should be invoked only on stopped or expired timers. If the timer is
|
||||
// known to have expired, Reset can be used directly. Otherwise, the caller
|
||||
// must coordinate with the function f of Clock.AfterFunc(d, f).
|
||||
Reset(d time.Duration)
|
||||
}
|
||||
|
||||
// Address is a byte slice cast as a string that represents the address of a
|
||||
// network node. Or, in the case of unix endpoints, it may represent a path.
|
||||
//
|
||||
// +stateify savable
|
||||
type Address struct {
|
||||
addr [16]byte
|
||||
length int
|
||||
}
|
||||
|
||||
// AddrFrom4 converts addr to an Address.
|
||||
func AddrFrom4(addr [4]byte) Address {
|
||||
ret := Address{
|
||||
length: 4,
|
||||
}
|
||||
// It's guaranteed that copy will return 4.
|
||||
copy(ret.addr[:], addr[:])
|
||||
return ret
|
||||
}
|
||||
|
||||
// AddrFrom4Slice converts addr to an Address. It panics if len(addr) != 4.
|
||||
func AddrFrom4Slice(addr []byte) Address {
|
||||
if len(addr) != 4 {
|
||||
panic(fmt.Sprintf("bad address length for address %v", addr))
|
||||
}
|
||||
ret := Address{
|
||||
length: 4,
|
||||
}
|
||||
// It's guaranteed that copy will return 4.
|
||||
copy(ret.addr[:], addr)
|
||||
return ret
|
||||
}
|
||||
|
||||
// AddrFrom16 converts addr to an Address.
|
||||
func AddrFrom16(addr [16]byte) Address {
|
||||
ret := Address{
|
||||
length: 16,
|
||||
}
|
||||
// It's guaranteed that copy will return 16.
|
||||
copy(ret.addr[:], addr[:])
|
||||
return ret
|
||||
}
|
||||
|
||||
// AddrFrom16Slice converts addr to an Address. It panics if len(addr) != 16.
|
||||
func AddrFrom16Slice(addr []byte) Address {
|
||||
if len(addr) != 16 {
|
||||
panic(fmt.Sprintf("bad address length for address %v", addr))
|
||||
}
|
||||
ret := Address{
|
||||
length: 16,
|
||||
}
|
||||
// It's guaranteed that copy will return 16.
|
||||
copy(ret.addr[:], addr)
|
||||
return ret
|
||||
}
|
||||
|
||||
// AddrFromSlice converts addr to an Address. It returns the Address zero value
|
||||
// if len(addr) != 4 or 16.
|
||||
func AddrFromSlice(addr []byte) Address {
|
||||
switch len(addr) {
|
||||
case ipv4AddressSize:
|
||||
return AddrFrom4Slice(addr)
|
||||
case ipv6AddressSize:
|
||||
return AddrFrom16Slice(addr)
|
||||
}
|
||||
return Address{}
|
||||
}
|
||||
|
||||
// As4 returns a as a 4 byte array. It panics if the address length is not 4.
|
||||
func (a Address) As4() [4]byte {
|
||||
if a.Len() != 4 {
|
||||
panic(fmt.Sprintf("bad address length for address %v", a.addr))
|
||||
}
|
||||
return [4]byte(a.addr[:4])
|
||||
}
|
||||
|
||||
// As16 returns a as a 16 byte array. It panics if the address length is not 16.
|
||||
func (a Address) As16() [16]byte {
|
||||
if a.Len() != 16 {
|
||||
panic(fmt.Sprintf("bad address length for address %v", a.addr))
|
||||
}
|
||||
return [16]byte(a.addr[:16])
|
||||
}
|
||||
|
||||
// AsSlice returns a as a byte slice. Callers should be careful as it can
|
||||
// return a window into existing memory.
|
||||
//
|
||||
// +checkescape
|
||||
func (a *Address) AsSlice() []byte {
|
||||
return a.addr[:a.length]
|
||||
}
|
||||
|
||||
// BitLen returns the length in bits of a.
|
||||
func (a Address) BitLen() int {
|
||||
return a.Len() * 8
|
||||
}
|
||||
|
||||
// Len returns the length in bytes of a.
|
||||
func (a Address) Len() int {
|
||||
return a.length
|
||||
}
|
||||
|
||||
// WithPrefix returns the address with a prefix that represents a point subnet.
|
||||
func (a Address) WithPrefix() AddressWithPrefix {
|
||||
return AddressWithPrefix{
|
||||
Address: a,
|
||||
PrefixLen: a.BitLen(),
|
||||
}
|
||||
}
|
||||
|
||||
// Unspecified returns true if the address is unspecified.
|
||||
func (a Address) Unspecified() bool {
|
||||
for _, b := range a.addr {
|
||||
if b != 0 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Equal returns whether a and other are equal. It exists for use by the cmp
|
||||
// library.
|
||||
func (a Address) Equal(other Address) bool {
|
||||
return a == other
|
||||
}
|
||||
|
||||
// MatchingPrefix returns the matching prefix length in bits.
|
||||
//
|
||||
// Panics if b and a have different lengths.
|
||||
func (a Address) MatchingPrefix(b Address) uint8 {
|
||||
const bitsInAByte = 8
|
||||
|
||||
if a.Len() != b.Len() {
|
||||
panic(fmt.Sprintf("addresses %s and %s do not have the same length", a, b))
|
||||
}
|
||||
|
||||
var prefix uint8
|
||||
for i := 0; i < a.length; i++ {
|
||||
aByte := a.addr[i]
|
||||
bByte := b.addr[i]
|
||||
|
||||
if aByte == bByte {
|
||||
prefix += bitsInAByte
|
||||
continue
|
||||
}
|
||||
|
||||
// Count the remaining matching bits in the byte from MSbit to LSBbit.
|
||||
mask := uint8(1) << (bitsInAByte - 1)
|
||||
for {
|
||||
if aByte&mask == bByte&mask {
|
||||
prefix++
|
||||
mask >>= 1
|
||||
continue
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
return prefix
|
||||
}
|
||||
|
||||
// AddressMask is a bitmask for an address.
|
||||
//
|
||||
// +stateify savable
|
||||
type AddressMask struct {
|
||||
mask [16]byte
|
||||
length int
|
||||
}
|
||||
|
||||
// MaskFrom returns a Mask based on str.
|
||||
//
|
||||
// MaskFrom may allocate, and so should not be in hot paths.
|
||||
func MaskFrom(str string) AddressMask {
|
||||
mask := AddressMask{length: len(str)}
|
||||
copy(mask.mask[:], str)
|
||||
return mask
|
||||
}
|
||||
|
||||
// MaskFromBytes returns a Mask based on bs.
|
||||
func MaskFromBytes(bs []byte) AddressMask {
|
||||
mask := AddressMask{length: len(bs)}
|
||||
copy(mask.mask[:], bs)
|
||||
return mask
|
||||
}
|
||||
|
||||
// String implements Stringer.
|
||||
func (m AddressMask) String() string {
|
||||
return fmt.Sprintf("%x", m.mask)
|
||||
}
|
||||
|
||||
// AsSlice returns a as a byte slice. Callers should be careful as it can
|
||||
// return a window into existing memory.
|
||||
func (m *AddressMask) AsSlice() []byte {
|
||||
return []byte(m.mask[:m.length])
|
||||
}
|
||||
|
||||
// BitLen returns the length of the mask in bits.
|
||||
func (m AddressMask) BitLen() int {
|
||||
return m.length * 8
|
||||
}
|
||||
|
||||
// Len returns the length of the mask in bytes.
|
||||
func (m AddressMask) Len() int {
|
||||
return m.length
|
||||
}
|
||||
|
||||
// Prefix returns the number of bits before the first host bit.
|
||||
func (m AddressMask) Prefix() int {
|
||||
p := 0
|
||||
for _, b := range m.mask[:m.length] {
|
||||
p += bits.LeadingZeros8(^b)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// Equal returns whether m and other are equal. It exists for use by the cmp
|
||||
// library.
|
||||
func (m AddressMask) Equal(other AddressMask) bool {
|
||||
return m == other
|
||||
}
|
||||
|
||||
// Subnet is a subnet defined by its address and mask.
|
||||
//
|
||||
// +stateify savable
|
||||
type Subnet struct {
|
||||
address Address
|
||||
mask AddressMask
|
||||
}
|
||||
|
||||
// NewSubnet creates a new Subnet, checking that the address and mask are the same length.
|
||||
func NewSubnet(a Address, m AddressMask) (Subnet, error) {
|
||||
if a.Len() != m.Len() {
|
||||
return Subnet{}, errSubnetLengthMismatch
|
||||
}
|
||||
for i := 0; i < a.Len(); i++ {
|
||||
if a.addr[i]&^m.mask[i] != 0 {
|
||||
return Subnet{}, errSubnetAddressMasked
|
||||
}
|
||||
}
|
||||
return Subnet{a, m}, nil
|
||||
}
|
||||
|
||||
// String implements Stringer.
|
||||
func (s Subnet) String() string {
|
||||
return fmt.Sprintf("%s/%d", s.ID(), s.Prefix())
|
||||
}
|
||||
|
||||
// Contains returns true iff the address is of the same length and matches the
|
||||
// subnet address and mask.
|
||||
func (s *Subnet) Contains(a Address) bool {
|
||||
if a.Len() != s.address.Len() {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < a.Len(); i++ {
|
||||
if a.addr[i]&s.mask.mask[i] != s.address.addr[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// ID returns the subnet ID.
|
||||
func (s *Subnet) ID() Address {
|
||||
return s.address
|
||||
}
|
||||
|
||||
// Bits returns the number of ones (network bits) and zeros (host bits) in the
|
||||
// subnet mask.
|
||||
func (s *Subnet) Bits() (ones int, zeros int) {
|
||||
ones = s.mask.Prefix()
|
||||
return ones, s.mask.BitLen() - ones
|
||||
}
|
||||
|
||||
// Prefix returns the number of bits before the first host bit.
|
||||
func (s *Subnet) Prefix() int {
|
||||
return s.mask.Prefix()
|
||||
}
|
||||
|
||||
// Mask returns the subnet mask.
|
||||
func (s *Subnet) Mask() AddressMask {
|
||||
return s.mask
|
||||
}
|
||||
|
||||
// Broadcast returns the subnet's broadcast address.
|
||||
func (s *Subnet) Broadcast() Address {
|
||||
addrCopy := s.address
|
||||
for i := 0; i < addrCopy.Len(); i++ {
|
||||
addrCopy.addr[i] |= ^s.mask.mask[i]
|
||||
}
|
||||
return addrCopy
|
||||
}
|
||||
|
||||
// IsBroadcast returns true if the address is considered a broadcast address.
|
||||
func (s *Subnet) IsBroadcast(address Address) bool {
|
||||
// Only IPv4 supports the notion of a broadcast address.
|
||||
if address.Len() != ipv4AddressSize {
|
||||
return false
|
||||
}
|
||||
|
||||
// Normally, we would just compare address with the subnet's broadcast
|
||||
// address but there is an exception where a simple comparison is not
|
||||
// correct. This exception is for /31 and /32 IPv4 subnets where all
|
||||
// addresses are considered valid host addresses.
|
||||
//
|
||||
// For /31 subnets, the case is easy. RFC 3021 Section 2.1 states that
|
||||
// both addresses in a /31 subnet "MUST be interpreted as host addresses."
|
||||
//
|
||||
// For /32, the case is a bit more vague. RFC 3021 makes no mention of /32
|
||||
// subnets. However, the same reasoning applies - if an exception is not
|
||||
// made, then there do not exist any host addresses in a /32 subnet. RFC
|
||||
// 4632 Section 3.1 also vaguely implies this interpretation by referring
|
||||
// to addresses in /32 subnets as "host routes."
|
||||
return s.Prefix() <= 30 && s.Broadcast() == address
|
||||
}
|
||||
|
||||
// Equal returns true if this Subnet is equal to the given Subnet.
|
||||
func (s Subnet) Equal(o Subnet) bool {
|
||||
// If this changes, update Route.Equal accordingly.
|
||||
return s == o
|
||||
}
|
||||
|
||||
// LinkAddress is a byte slice cast as a string that represents a link address.
|
||||
// It is typically a 6-byte MAC address.
|
||||
type LinkAddress string
|
||||
|
||||
// String implements the fmt.Stringer interface.
|
||||
func (a LinkAddress) String() string {
|
||||
switch len(a) {
|
||||
case 6:
|
||||
return fmt.Sprintf("%02x:%02x:%02x:%02x:%02x:%02x", a[0], a[1], a[2], a[3], a[4], a[5])
|
||||
default:
|
||||
return fmt.Sprintf("%x", []byte(a))
|
||||
}
|
||||
}
|
||||
|
||||
// ParseMACAddress parses an IEEE 802 address.
|
||||
//
|
||||
// It must be in the format aa:bb:cc:dd:ee:ff or aa-bb-cc-dd-ee-ff.
|
||||
func ParseMACAddress(s string) (LinkAddress, error) {
|
||||
parts := strings.FieldsFunc(s, func(c rune) bool {
|
||||
return c == ':' || c == '-'
|
||||
})
|
||||
if len(parts) != LinkAddressSize {
|
||||
return "", fmt.Errorf("inconsistent parts: %s", s)
|
||||
}
|
||||
addr := make([]byte, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
u, err := strconv.ParseUint(part, 16, 8)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid hex digits: %s", s)
|
||||
}
|
||||
addr = append(addr, byte(u))
|
||||
}
|
||||
return LinkAddress(addr), nil
|
||||
}
|
||||
|
||||
// GetRandMacAddr returns a mac address that can be used for local virtual devices.
|
||||
func GetRandMacAddr() LinkAddress {
|
||||
mac := make(net.HardwareAddr, LinkAddressSize)
|
||||
rand.Read(mac) // Fill with random data.
|
||||
mac[0] &^= 0x1 // Clear multicast bit.
|
||||
mac[0] |= 0x2 // Set local assignment bit (IEEE802).
|
||||
return LinkAddress(mac)
|
||||
}
|
||||
|
||||
// AddressWithPrefix is an address with its subnet prefix length.
|
||||
//
|
||||
// +stateify savable
|
||||
type AddressWithPrefix struct {
|
||||
// Address is a network address.
|
||||
Address Address
|
||||
|
||||
// PrefixLen is the subnet prefix length.
|
||||
PrefixLen int
|
||||
}
|
||||
|
||||
// String implements the fmt.Stringer interface.
|
||||
func (a AddressWithPrefix) String() string {
|
||||
return fmt.Sprintf("%s/%d", a.Address, a.PrefixLen)
|
||||
}
|
||||
|
||||
// Subnet converts the address and prefix into a Subnet value and returns it.
|
||||
func (a AddressWithPrefix) Subnet() Subnet {
|
||||
addrLen := a.Address.length
|
||||
if a.PrefixLen <= 0 {
|
||||
return Subnet{
|
||||
address: Address{length: addrLen},
|
||||
mask: AddressMask{length: addrLen},
|
||||
}
|
||||
}
|
||||
if a.PrefixLen >= addrLen*8 {
|
||||
sub := Subnet{
|
||||
address: a.Address,
|
||||
mask: AddressMask{length: addrLen},
|
||||
}
|
||||
for i := 0; i < addrLen; i++ {
|
||||
sub.mask.mask[i] = 0xff
|
||||
}
|
||||
return sub
|
||||
}
|
||||
|
||||
sa := Address{length: addrLen}
|
||||
sm := AddressMask{length: addrLen}
|
||||
n := uint(a.PrefixLen)
|
||||
for i := 0; i < addrLen; i++ {
|
||||
if n >= 8 {
|
||||
sa.addr[i] = a.Address.addr[i]
|
||||
sm.mask[i] = 0xff
|
||||
n -= 8
|
||||
continue
|
||||
}
|
||||
sm.mask[i] = ^byte(0xff >> n)
|
||||
sa.addr[i] = a.Address.addr[i] & sm.mask[i]
|
||||
n = 0
|
||||
}
|
||||
|
||||
// For extra caution, call NewSubnet rather than directly creating the Subnet
|
||||
// value. If that fails it indicates a serious bug in this code, so panic is
|
||||
// in order.
|
||||
s, err := NewSubnet(sa, sm)
|
||||
if err != nil {
|
||||
panic("invalid subnet: " + err.Error())
|
||||
}
|
||||
return s
|
||||
}
|
||||
@@ -1,712 +0,0 @@
|
||||
package tschecksum
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"math/bits"
|
||||
"strconv"
|
||||
|
||||
"golang.org/x/sys/cpu"
|
||||
)
|
||||
|
||||
// checksumGeneric64 is a reference implementation of checksum using 64 bit
|
||||
// arithmetic for use in testing or when an architecture-specific implementation
|
||||
// is not available.
|
||||
func checksumGeneric64(b []byte, initial uint16) uint16 {
|
||||
var ac uint64
|
||||
var carry uint64
|
||||
|
||||
if cpu.IsBigEndian {
|
||||
ac = uint64(initial)
|
||||
} else {
|
||||
ac = uint64(bits.ReverseBytes16(initial))
|
||||
}
|
||||
|
||||
for len(b) >= 128 {
|
||||
if cpu.IsBigEndian {
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[32:40]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[40:48]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[48:56]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[56:64]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[64:72]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[72:80]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[80:88]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[88:96]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[96:104]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[104:112]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[112:120]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[120:128]), carry)
|
||||
} else {
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[32:40]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[40:48]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[48:56]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[56:64]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[64:72]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[72:80]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[80:88]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[88:96]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[96:104]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[104:112]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[112:120]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[120:128]), carry)
|
||||
}
|
||||
b = b[128:]
|
||||
}
|
||||
if len(b) >= 64 {
|
||||
if cpu.IsBigEndian {
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[32:40]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[40:48]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[48:56]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[56:64]), carry)
|
||||
} else {
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[32:40]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[40:48]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[48:56]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[56:64]), carry)
|
||||
}
|
||||
b = b[64:]
|
||||
}
|
||||
if len(b) >= 32 {
|
||||
if cpu.IsBigEndian {
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry)
|
||||
} else {
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry)
|
||||
}
|
||||
b = b[32:]
|
||||
}
|
||||
if len(b) >= 16 {
|
||||
if cpu.IsBigEndian {
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
|
||||
} else {
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
|
||||
}
|
||||
b = b[16:]
|
||||
}
|
||||
if len(b) >= 8 {
|
||||
if cpu.IsBigEndian {
|
||||
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b), carry)
|
||||
} else {
|
||||
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b), carry)
|
||||
}
|
||||
b = b[8:]
|
||||
}
|
||||
if len(b) >= 4 {
|
||||
if cpu.IsBigEndian {
|
||||
ac, carry = bits.Add64(ac, uint64(binary.BigEndian.Uint32(b)), carry)
|
||||
} else {
|
||||
ac, carry = bits.Add64(ac, uint64(binary.LittleEndian.Uint32(b)), carry)
|
||||
}
|
||||
b = b[4:]
|
||||
}
|
||||
if len(b) >= 2 {
|
||||
if cpu.IsBigEndian {
|
||||
ac, carry = bits.Add64(ac, uint64(binary.BigEndian.Uint16(b)), carry)
|
||||
} else {
|
||||
ac, carry = bits.Add64(ac, uint64(binary.LittleEndian.Uint16(b)), carry)
|
||||
}
|
||||
b = b[2:]
|
||||
}
|
||||
if len(b) >= 1 {
|
||||
if cpu.IsBigEndian {
|
||||
ac, carry = bits.Add64(ac, uint64(b[0])<<8, carry)
|
||||
} else {
|
||||
ac, carry = bits.Add64(ac, uint64(b[0]), carry)
|
||||
}
|
||||
}
|
||||
|
||||
folded := ipChecksumFold64(ac, carry)
|
||||
if !cpu.IsBigEndian {
|
||||
folded = bits.ReverseBytes16(folded)
|
||||
}
|
||||
return folded
|
||||
}
|
||||
|
||||
// checksumGeneric32 is a reference implementation of checksum using 32 bit
|
||||
// arithmetic for use in testing or when an architecture-specific implementation
|
||||
// is not available.
|
||||
func checksumGeneric32(b []byte, initial uint16) uint16 {
|
||||
var ac uint32
|
||||
var carry uint32
|
||||
|
||||
if cpu.IsBigEndian {
|
||||
ac = uint32(initial)
|
||||
} else {
|
||||
ac = uint32(bits.ReverseBytes16(initial))
|
||||
}
|
||||
|
||||
for len(b) >= 64 {
|
||||
if cpu.IsBigEndian {
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:8]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[16:20]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[20:24]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[24:28]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[28:32]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[32:36]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[36:40]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[40:44]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[44:48]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[48:52]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[52:56]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[56:60]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[60:64]), carry)
|
||||
} else {
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:8]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[16:20]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[20:24]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[24:28]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[28:32]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[32:36]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[36:40]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[40:44]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[44:48]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[48:52]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[52:56]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[56:60]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[60:64]), carry)
|
||||
}
|
||||
b = b[64:]
|
||||
}
|
||||
if len(b) >= 32 {
|
||||
if cpu.IsBigEndian {
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[16:20]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[20:24]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[24:28]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[28:32]), carry)
|
||||
} else {
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[16:20]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[20:24]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[24:28]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[28:32]), carry)
|
||||
}
|
||||
b = b[32:]
|
||||
}
|
||||
if len(b) >= 16 {
|
||||
if cpu.IsBigEndian {
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry)
|
||||
} else {
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry)
|
||||
}
|
||||
b = b[16:]
|
||||
}
|
||||
if len(b) >= 8 {
|
||||
if cpu.IsBigEndian {
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry)
|
||||
} else {
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry)
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry)
|
||||
}
|
||||
b = b[8:]
|
||||
}
|
||||
if len(b) >= 4 {
|
||||
if cpu.IsBigEndian {
|
||||
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b), carry)
|
||||
} else {
|
||||
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b), carry)
|
||||
}
|
||||
b = b[4:]
|
||||
}
|
||||
if len(b) >= 2 {
|
||||
if cpu.IsBigEndian {
|
||||
ac, carry = bits.Add32(ac, uint32(binary.BigEndian.Uint16(b)), carry)
|
||||
} else {
|
||||
ac, carry = bits.Add32(ac, uint32(binary.LittleEndian.Uint16(b)), carry)
|
||||
}
|
||||
b = b[2:]
|
||||
}
|
||||
if len(b) >= 1 {
|
||||
if cpu.IsBigEndian {
|
||||
ac, carry = bits.Add32(ac, uint32(b[0])<<8, carry)
|
||||
} else {
|
||||
ac, carry = bits.Add32(ac, uint32(b[0]), carry)
|
||||
}
|
||||
}
|
||||
|
||||
folded := ipChecksumFold32(ac, carry)
|
||||
if !cpu.IsBigEndian {
|
||||
folded = bits.ReverseBytes16(folded)
|
||||
}
|
||||
return folded
|
||||
}
|
||||
|
||||
// checksumGeneric32Alternate is an alternate reference implementation of
|
||||
// checksum using 32 bit arithmetic for use in testing or when an
|
||||
// architecture-specific implementation is not available.
|
||||
func checksumGeneric32Alternate(b []byte, initial uint16) uint16 {
|
||||
var ac uint32
|
||||
|
||||
if cpu.IsBigEndian {
|
||||
ac = uint32(initial)
|
||||
} else {
|
||||
ac = uint32(bits.ReverseBytes16(initial))
|
||||
}
|
||||
|
||||
for len(b) >= 64 {
|
||||
if cpu.IsBigEndian {
|
||||
ac += uint32(binary.BigEndian.Uint16(b[:2]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[2:4]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[4:6]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[6:8]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[8:10]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[10:12]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[12:14]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[14:16]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[16:18]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[18:20]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[20:22]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[22:24]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[24:26]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[26:28]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[28:30]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[30:32]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[32:34]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[34:36]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[36:38]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[38:40]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[40:42]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[42:44]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[44:46]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[46:48]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[48:50]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[50:52]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[52:54]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[54:56]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[56:58]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[58:60]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[60:62]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[62:64]))
|
||||
} else {
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[:2]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[8:10]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[10:12]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[12:14]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[14:16]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[16:18]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[18:20]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[20:22]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[22:24]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[24:26]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[26:28]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[28:30]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[30:32]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[32:34]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[34:36]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[36:38]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[38:40]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[40:42]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[42:44]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[44:46]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[46:48]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[48:50]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[50:52]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[52:54]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[54:56]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[56:58]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[58:60]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[60:62]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[62:64]))
|
||||
}
|
||||
b = b[64:]
|
||||
}
|
||||
if len(b) >= 32 {
|
||||
if cpu.IsBigEndian {
|
||||
ac += uint32(binary.BigEndian.Uint16(b[:2]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[2:4]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[4:6]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[6:8]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[8:10]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[10:12]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[12:14]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[14:16]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[16:18]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[18:20]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[20:22]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[22:24]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[24:26]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[26:28]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[28:30]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[30:32]))
|
||||
} else {
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[:2]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[8:10]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[10:12]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[12:14]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[14:16]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[16:18]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[18:20]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[20:22]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[22:24]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[24:26]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[26:28]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[28:30]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[30:32]))
|
||||
}
|
||||
b = b[32:]
|
||||
}
|
||||
if len(b) >= 16 {
|
||||
if cpu.IsBigEndian {
|
||||
ac += uint32(binary.BigEndian.Uint16(b[:2]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[2:4]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[4:6]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[6:8]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[8:10]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[10:12]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[12:14]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[14:16]))
|
||||
} else {
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[:2]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[8:10]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[10:12]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[12:14]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[14:16]))
|
||||
}
|
||||
b = b[16:]
|
||||
}
|
||||
if len(b) >= 8 {
|
||||
if cpu.IsBigEndian {
|
||||
ac += uint32(binary.BigEndian.Uint16(b[:2]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[2:4]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[4:6]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[6:8]))
|
||||
} else {
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[:2]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
|
||||
}
|
||||
b = b[8:]
|
||||
}
|
||||
if len(b) >= 4 {
|
||||
if cpu.IsBigEndian {
|
||||
ac += uint32(binary.BigEndian.Uint16(b[:2]))
|
||||
ac += uint32(binary.BigEndian.Uint16(b[2:4]))
|
||||
} else {
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[:2]))
|
||||
ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
|
||||
}
|
||||
b = b[4:]
|
||||
}
|
||||
if len(b) >= 2 {
|
||||
if cpu.IsBigEndian {
|
||||
ac += uint32(binary.BigEndian.Uint16(b))
|
||||
} else {
|
||||
ac += uint32(binary.LittleEndian.Uint16(b))
|
||||
}
|
||||
b = b[2:]
|
||||
}
|
||||
if len(b) >= 1 {
|
||||
if cpu.IsBigEndian {
|
||||
ac += uint32(b[0]) << 8
|
||||
} else {
|
||||
ac += uint32(b[0])
|
||||
}
|
||||
}
|
||||
|
||||
folded := ipChecksumFold32(ac, 0)
|
||||
if !cpu.IsBigEndian {
|
||||
folded = bits.ReverseBytes16(folded)
|
||||
}
|
||||
return folded
|
||||
}
|
||||
|
||||
// checksumGeneric64Alternate is an alternate reference implementation of
|
||||
// checksum using 64 bit arithmetic for use in testing or when an
|
||||
// architecture-specific implementation is not available.
|
||||
func checksumGeneric64Alternate(b []byte, initial uint16) uint16 {
|
||||
var ac uint64
|
||||
|
||||
if cpu.IsBigEndian {
|
||||
ac = uint64(initial)
|
||||
} else {
|
||||
ac = uint64(bits.ReverseBytes16(initial))
|
||||
}
|
||||
|
||||
for len(b) >= 64 {
|
||||
if cpu.IsBigEndian {
|
||||
ac += uint64(binary.BigEndian.Uint32(b[:4]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[32:36]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[36:40]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[40:44]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[44:48]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[48:52]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[52:56]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[56:60]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[60:64]))
|
||||
} else {
|
||||
ac += uint64(binary.LittleEndian.Uint32(b[:4]))
|
||||
ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
|
||||
ac += uint64(binary.LittleEndian.Uint32(b[8:12]))
|
||||
ac += uint64(binary.LittleEndian.Uint32(b[12:16]))
|
||||
ac += uint64(binary.LittleEndian.Uint32(b[16:20]))
|
||||
ac += uint64(binary.LittleEndian.Uint32(b[20:24]))
|
||||
ac += uint64(binary.LittleEndian.Uint32(b[24:28]))
|
||||
ac += uint64(binary.LittleEndian.Uint32(b[28:32]))
|
||||
ac += uint64(binary.LittleEndian.Uint32(b[32:36]))
|
||||
ac += uint64(binary.LittleEndian.Uint32(b[36:40]))
|
||||
ac += uint64(binary.LittleEndian.Uint32(b[40:44]))
|
||||
ac += uint64(binary.LittleEndian.Uint32(b[44:48]))
|
||||
ac += uint64(binary.LittleEndian.Uint32(b[48:52]))
|
||||
ac += uint64(binary.LittleEndian.Uint32(b[52:56]))
|
||||
ac += uint64(binary.LittleEndian.Uint32(b[56:60]))
|
||||
ac += uint64(binary.LittleEndian.Uint32(b[60:64]))
|
||||
}
|
||||
b = b[64:]
|
||||
}
|
||||
if len(b) >= 32 {
|
||||
if cpu.IsBigEndian {
|
||||
ac += uint64(binary.BigEndian.Uint32(b[:4]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
|
||||
} else {
|
||||
ac += uint64(binary.LittleEndian.Uint32(b[:4]))
|
||||
ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
|
||||
ac += uint64(binary.LittleEndian.Uint32(b[8:12]))
|
||||
ac += uint64(binary.LittleEndian.Uint32(b[12:16]))
|
||||
ac += uint64(binary.LittleEndian.Uint32(b[16:20]))
|
||||
ac += uint64(binary.LittleEndian.Uint32(b[20:24]))
|
||||
ac += uint64(binary.LittleEndian.Uint32(b[24:28]))
|
||||
ac += uint64(binary.LittleEndian.Uint32(b[28:32]))
|
||||
}
|
||||
b = b[32:]
|
||||
}
|
||||
if len(b) >= 16 {
|
||||
if cpu.IsBigEndian {
|
||||
ac += uint64(binary.BigEndian.Uint32(b[:4]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
|
||||
} else {
|
||||
ac += uint64(binary.LittleEndian.Uint32(b[:4]))
|
||||
ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
|
||||
ac += uint64(binary.LittleEndian.Uint32(b[8:12]))
|
||||
ac += uint64(binary.LittleEndian.Uint32(b[12:16]))
|
||||
}
|
||||
b = b[16:]
|
||||
}
|
||||
if len(b) >= 8 {
|
||||
if cpu.IsBigEndian {
|
||||
ac += uint64(binary.BigEndian.Uint32(b[:4]))
|
||||
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
|
||||
} else {
|
||||
ac += uint64(binary.LittleEndian.Uint32(b[:4]))
|
||||
ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
|
||||
}
|
||||
b = b[8:]
|
||||
}
|
||||
if len(b) >= 4 {
|
||||
if cpu.IsBigEndian {
|
||||
ac += uint64(binary.BigEndian.Uint32(b))
|
||||
} else {
|
||||
ac += uint64(binary.LittleEndian.Uint32(b))
|
||||
}
|
||||
b = b[4:]
|
||||
}
|
||||
if len(b) >= 2 {
|
||||
if cpu.IsBigEndian {
|
||||
ac += uint64(binary.BigEndian.Uint16(b))
|
||||
} else {
|
||||
ac += uint64(binary.LittleEndian.Uint16(b))
|
||||
}
|
||||
b = b[2:]
|
||||
}
|
||||
if len(b) >= 1 {
|
||||
if cpu.IsBigEndian {
|
||||
ac += uint64(b[0]) << 8
|
||||
} else {
|
||||
ac += uint64(b[0])
|
||||
}
|
||||
}
|
||||
|
||||
folded := ipChecksumFold64(ac, 0)
|
||||
if !cpu.IsBigEndian {
|
||||
folded = bits.ReverseBytes16(folded)
|
||||
}
|
||||
return folded
|
||||
}
|
||||
|
||||
func ipChecksumFold64(unfolded uint64, initialCarry uint64) uint16 {
|
||||
sum, carry := bits.Add32(uint32(unfolded>>32), uint32(unfolded&0xffff_ffff), uint32(initialCarry))
|
||||
// if carry != 0, sum <= 0xffff_fffe, otherwise sum <= 0xffff_ffff
|
||||
// therefore (sum >> 16) + (sum & 0xffff) + carry <= 0x1_fffe; so there is
|
||||
// no need to save the carry flag
|
||||
sum = (sum >> 16) + (sum & 0xffff) + carry
|
||||
// sum <= 0x1_fffe therefore this is the last fold needed:
|
||||
// if (sum >> 16) > 0 then
|
||||
// (sum >> 16) == 1 && (sum & 0xffff) <= 0xfffe and therefore
|
||||
// the addition will not overflow
|
||||
// otherwise (sum >> 16) == 0 and sum will be unchanged
|
||||
sum = (sum >> 16) + (sum & 0xffff)
|
||||
return uint16(sum)
|
||||
}
|
||||
|
||||
func ipChecksumFold32(unfolded uint32, initialCarry uint32) uint16 {
|
||||
sum := (unfolded >> 16) + (unfolded & 0xffff) + initialCarry
|
||||
// sum <= 0x1_ffff:
|
||||
// 0xffff + 0xffff = 0x1_fffe
|
||||
// initialCarry is 0 or 1, for a combined maximum of 0x1_ffff
|
||||
sum = (sum >> 16) + (sum & 0xffff)
|
||||
// sum <= 0x1_0000 therefore this is the last fold needed:
|
||||
// if (sum >> 16) > 0 then
|
||||
// (sum >> 16) == 1 && (sum & 0xffff) == 0 and therefore
|
||||
// the addition will not overflow
|
||||
// otherwise (sum >> 16) == 0 and sum will be unchanged
|
||||
sum = (sum >> 16) + (sum & 0xffff)
|
||||
return uint16(sum)
|
||||
}
|
||||
|
||||
func addrPartialChecksum64(addr []byte, initial, carryIn uint64) (sum, carry uint64) {
|
||||
sum, carry = initial, carryIn
|
||||
switch len(addr) {
|
||||
case 4: // IPv4
|
||||
if cpu.IsBigEndian {
|
||||
sum, carry = bits.Add64(sum, uint64(binary.BigEndian.Uint32(addr)), carry)
|
||||
} else {
|
||||
sum, carry = bits.Add64(sum, uint64(binary.LittleEndian.Uint32(addr)), carry)
|
||||
}
|
||||
case 16: // IPv6
|
||||
if cpu.IsBigEndian {
|
||||
sum, carry = bits.Add64(sum, binary.BigEndian.Uint64(addr), carry)
|
||||
sum, carry = bits.Add64(sum, binary.BigEndian.Uint64(addr[8:]), carry)
|
||||
} else {
|
||||
sum, carry = bits.Add64(sum, binary.LittleEndian.Uint64(addr), carry)
|
||||
sum, carry = bits.Add64(sum, binary.LittleEndian.Uint64(addr[8:]), carry)
|
||||
}
|
||||
default:
|
||||
panic("bad addr length")
|
||||
}
|
||||
return sum, carry
|
||||
}
|
||||
|
||||
func addrPartialChecksum32(addr []byte, initial, carryIn uint32) (sum, carry uint32) {
|
||||
sum, carry = initial, carryIn
|
||||
switch len(addr) {
|
||||
case 4: // IPv4
|
||||
if cpu.IsBigEndian {
|
||||
sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr), carry)
|
||||
} else {
|
||||
sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr), carry)
|
||||
}
|
||||
case 16: // IPv6
|
||||
if cpu.IsBigEndian {
|
||||
sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr), carry)
|
||||
sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[4:8]), carry)
|
||||
sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[8:12]), carry)
|
||||
sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[12:16]), carry)
|
||||
} else {
|
||||
sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr), carry)
|
||||
sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[4:8]), carry)
|
||||
sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[8:12]), carry)
|
||||
sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[12:16]), carry)
|
||||
}
|
||||
default:
|
||||
panic("bad addr length")
|
||||
}
|
||||
return sum, carry
|
||||
}
|
||||
|
||||
func pseudoHeaderChecksum64(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
|
||||
var sum uint64
|
||||
if cpu.IsBigEndian {
|
||||
sum = uint64(totalLen) + uint64(protocol)
|
||||
} else {
|
||||
sum = uint64(bits.ReverseBytes16(totalLen)) + uint64(protocol)<<8
|
||||
}
|
||||
sum, carry := addrPartialChecksum64(srcAddr, sum, 0)
|
||||
sum, carry = addrPartialChecksum64(dstAddr, sum, carry)
|
||||
|
||||
foldedSum := ipChecksumFold64(sum, carry)
|
||||
if !cpu.IsBigEndian {
|
||||
foldedSum = bits.ReverseBytes16(foldedSum)
|
||||
}
|
||||
return foldedSum
|
||||
}
|
||||
|
||||
func pseudoHeaderChecksum32(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
|
||||
var sum uint32
|
||||
if cpu.IsBigEndian {
|
||||
sum = uint32(totalLen) + uint32(protocol)
|
||||
} else {
|
||||
sum = uint32(bits.ReverseBytes16(totalLen)) + uint32(protocol)<<8
|
||||
}
|
||||
sum, carry := addrPartialChecksum32(srcAddr, sum, 0)
|
||||
sum, carry = addrPartialChecksum32(dstAddr, sum, carry)
|
||||
|
||||
foldedSum := ipChecksumFold32(sum, carry)
|
||||
if !cpu.IsBigEndian {
|
||||
foldedSum = bits.ReverseBytes16(foldedSum)
|
||||
}
|
||||
return foldedSum
|
||||
}
|
||||
|
||||
// PseudoHeaderChecksum computes an IP pseudo-header checksum. srcAddr and
|
||||
// dstAddr must be 4 or 16 bytes in length.
|
||||
func PseudoHeaderChecksum(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
|
||||
if strconv.IntSize < 64 {
|
||||
return pseudoHeaderChecksum32(protocol, srcAddr, dstAddr, totalLen)
|
||||
}
|
||||
return pseudoHeaderChecksum64(protocol, srcAddr, dstAddr, totalLen)
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
package tschecksum
|
||||
|
||||
import "golang.org/x/sys/cpu"
|
||||
|
||||
var checksum = checksumAMD64
|
||||
|
||||
// Checksum computes an IP checksum starting with the provided initial value.
|
||||
// The length of data should be at least 128 bytes for best performance. Smaller
|
||||
// buffers will still compute a correct result.
|
||||
func Checksum(data []byte, initial uint16) uint16 {
|
||||
return checksum(data, initial)
|
||||
}
|
||||
|
||||
func init() {
|
||||
if cpu.X86.HasAVX && cpu.X86.HasAVX2 && cpu.X86.HasBMI2 {
|
||||
checksum = checksumAVX2
|
||||
return
|
||||
}
|
||||
if cpu.X86.HasSSE2 {
|
||||
checksum = checksumSSE2
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
// Code generated by command: go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go. DO NOT EDIT.
|
||||
|
||||
package tschecksum
|
||||
|
||||
// checksumAVX2 computes an IP checksum using amd64 v3 instructions (AVX2, BMI2)
|
||||
//
|
||||
//go:noescape
|
||||
func checksumAVX2(b []byte, initial uint16) uint16
|
||||
|
||||
// checksumSSE2 computes an IP checksum using amd64 baseline instructions (SSE2)
|
||||
//
|
||||
//go:noescape
|
||||
func checksumSSE2(b []byte, initial uint16) uint16
|
||||
|
||||
// checksumAMD64 computes an IP checksum using amd64 baseline instructions
|
||||
//
|
||||
//go:noescape
|
||||
func checksumAMD64(b []byte, initial uint16) uint16
|
||||
@@ -1,851 +0,0 @@
|
||||
// Code generated by command: go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go. DO NOT EDIT.
|
||||
|
||||
#include "textflag.h"
|
||||
|
||||
DATA xmmLoadMasks<>+0(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff"
|
||||
DATA xmmLoadMasks<>+16(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff"
|
||||
DATA xmmLoadMasks<>+32(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff"
|
||||
DATA xmmLoadMasks<>+48(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff"
|
||||
DATA xmmLoadMasks<>+64(SB)/16, $"\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
|
||||
DATA xmmLoadMasks<>+80(SB)/16, $"\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
|
||||
DATA xmmLoadMasks<>+96(SB)/16, $"\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
|
||||
GLOBL xmmLoadMasks<>(SB), RODATA|NOPTR, $112
|
||||
|
||||
// func checksumAVX2(b []byte, initial uint16) uint16
|
||||
// Requires: AVX, AVX2, BMI2
|
||||
TEXT ·checksumAVX2(SB), NOSPLIT|NOFRAME, $0-34
|
||||
MOVWQZX initial+24(FP), AX
|
||||
XCHGB AH, AL
|
||||
MOVQ b_base+0(FP), DX
|
||||
MOVQ b_len+8(FP), BX
|
||||
|
||||
// handle odd length buffers; they are difficult to handle in general
|
||||
TESTQ $0x00000001, BX
|
||||
JZ lengthIsEven
|
||||
MOVBQZX -1(DX)(BX*1), CX
|
||||
DECQ BX
|
||||
ADDQ CX, AX
|
||||
|
||||
lengthIsEven:
|
||||
// handle tiny buffers (<=31 bytes) specially
|
||||
CMPQ BX, $0x1f
|
||||
JGT bufferIsNotTiny
|
||||
XORQ CX, CX
|
||||
XORQ SI, SI
|
||||
XORQ DI, DI
|
||||
|
||||
// shift twice to start because length is guaranteed to be even
|
||||
// n = n >> 2; CF = originalN & 2
|
||||
SHRQ $0x02, BX
|
||||
JNC handleTiny4
|
||||
|
||||
// tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:]
|
||||
MOVWQZX (DX), CX
|
||||
ADDQ $0x02, DX
|
||||
|
||||
handleTiny4:
|
||||
// n = n >> 1; CF = originalN & 4
|
||||
SHRQ $0x01, BX
|
||||
JNC handleTiny8
|
||||
|
||||
// tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:]
|
||||
MOVLQZX (DX), SI
|
||||
ADDQ $0x04, DX
|
||||
|
||||
handleTiny8:
|
||||
// n = n >> 1; CF = originalN & 8
|
||||
SHRQ $0x01, BX
|
||||
JNC handleTiny16
|
||||
|
||||
// tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:]
|
||||
MOVQ (DX), DI
|
||||
ADDQ $0x08, DX
|
||||
|
||||
handleTiny16:
|
||||
// n = n >> 1; CF = originalN & 16
|
||||
// n == 0 now, otherwise we would have branched after comparing with tinyBufferSize
|
||||
SHRQ $0x01, BX
|
||||
JNC handleTinyFinish
|
||||
ADDQ (DX), AX
|
||||
ADCQ 8(DX), AX
|
||||
|
||||
handleTinyFinish:
|
||||
// CF should be included from the previous add, so we use ADCQ.
|
||||
// If we arrived via the JNC above, then CF=0 due to the branch condition,
|
||||
// so ADCQ will still produce the correct result.
|
||||
ADCQ CX, AX
|
||||
ADCQ SI, AX
|
||||
ADCQ DI, AX
|
||||
JMP foldAndReturn
|
||||
|
||||
bufferIsNotTiny:
|
||||
// skip all SIMD for small buffers
|
||||
CMPQ BX, $0x00000100
|
||||
JGE startSIMD
|
||||
|
||||
// Accumulate carries in this register. It is never expected to overflow.
|
||||
XORQ SI, SI
|
||||
|
||||
// We will perform an overlapped read for buffers with length not a multiple of 8.
|
||||
// Overlapped in this context means some memory will be read twice, but a shift will
|
||||
// eliminate the duplicated data. This extra read is performed at the end of the buffer to
|
||||
// preserve any alignment that may exist for the start of the buffer.
|
||||
MOVQ BX, CX
|
||||
SHRQ $0x03, BX
|
||||
ANDQ $0x07, CX
|
||||
JZ handleRemaining8
|
||||
LEAQ (DX)(BX*8), DI
|
||||
MOVQ -8(DI)(CX*1), DI
|
||||
|
||||
// Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8)
|
||||
SHLQ $0x03, CX
|
||||
NEGQ CX
|
||||
ADDQ $0x40, CX
|
||||
SHRQ CL, DI
|
||||
ADDQ DI, AX
|
||||
ADCQ $0x00, SI
|
||||
|
||||
handleRemaining8:
|
||||
SHRQ $0x01, BX
|
||||
JNC handleRemaining16
|
||||
ADDQ (DX), AX
|
||||
ADCQ $0x00, SI
|
||||
ADDQ $0x08, DX
|
||||
|
||||
handleRemaining16:
|
||||
SHRQ $0x01, BX
|
||||
JNC handleRemaining32
|
||||
ADDQ (DX), AX
|
||||
ADCQ 8(DX), AX
|
||||
ADCQ $0x00, SI
|
||||
ADDQ $0x10, DX
|
||||
|
||||
handleRemaining32:
|
||||
SHRQ $0x01, BX
|
||||
JNC handleRemaining64
|
||||
ADDQ (DX), AX
|
||||
ADCQ 8(DX), AX
|
||||
ADCQ 16(DX), AX
|
||||
ADCQ 24(DX), AX
|
||||
ADCQ $0x00, SI
|
||||
ADDQ $0x20, DX
|
||||
|
||||
handleRemaining64:
|
||||
SHRQ $0x01, BX
|
||||
JNC handleRemaining128
|
||||
ADDQ (DX), AX
|
||||
ADCQ 8(DX), AX
|
||||
ADCQ 16(DX), AX
|
||||
ADCQ 24(DX), AX
|
||||
ADCQ 32(DX), AX
|
||||
ADCQ 40(DX), AX
|
||||
ADCQ 48(DX), AX
|
||||
ADCQ 56(DX), AX
|
||||
ADCQ $0x00, SI
|
||||
ADDQ $0x40, DX
|
||||
|
||||
handleRemaining128:
|
||||
SHRQ $0x01, BX
|
||||
JNC handleRemainingComplete
|
||||
ADDQ (DX), AX
|
||||
ADCQ 8(DX), AX
|
||||
ADCQ 16(DX), AX
|
||||
ADCQ 24(DX), AX
|
||||
ADCQ 32(DX), AX
|
||||
ADCQ 40(DX), AX
|
||||
ADCQ 48(DX), AX
|
||||
ADCQ 56(DX), AX
|
||||
ADCQ 64(DX), AX
|
||||
ADCQ 72(DX), AX
|
||||
ADCQ 80(DX), AX
|
||||
ADCQ 88(DX), AX
|
||||
ADCQ 96(DX), AX
|
||||
ADCQ 104(DX), AX
|
||||
ADCQ 112(DX), AX
|
||||
ADCQ 120(DX), AX
|
||||
ADCQ $0x00, SI
|
||||
ADDQ $0x80, DX
|
||||
|
||||
handleRemainingComplete:
|
||||
ADDQ SI, AX
|
||||
JMP foldAndReturn
|
||||
|
||||
startSIMD:
|
||||
VPXOR Y0, Y0, Y0
|
||||
VPXOR Y1, Y1, Y1
|
||||
VPXOR Y2, Y2, Y2
|
||||
VPXOR Y3, Y3, Y3
|
||||
MOVQ BX, CX
|
||||
|
||||
// Update number of bytes remaining after the loop completes
|
||||
ANDQ $0xff, BX
|
||||
|
||||
// Number of 256 byte iterations
|
||||
SHRQ $0x08, CX
|
||||
JZ smallLoop
|
||||
|
||||
bigLoop:
|
||||
VPMOVZXWD (DX), Y4
|
||||
VPADDD Y4, Y0, Y0
|
||||
VPMOVZXWD 16(DX), Y4
|
||||
VPADDD Y4, Y1, Y1
|
||||
VPMOVZXWD 32(DX), Y4
|
||||
VPADDD Y4, Y2, Y2
|
||||
VPMOVZXWD 48(DX), Y4
|
||||
VPADDD Y4, Y3, Y3
|
||||
VPMOVZXWD 64(DX), Y4
|
||||
VPADDD Y4, Y0, Y0
|
||||
VPMOVZXWD 80(DX), Y4
|
||||
VPADDD Y4, Y1, Y1
|
||||
VPMOVZXWD 96(DX), Y4
|
||||
VPADDD Y4, Y2, Y2
|
||||
VPMOVZXWD 112(DX), Y4
|
||||
VPADDD Y4, Y3, Y3
|
||||
VPMOVZXWD 128(DX), Y4
|
||||
VPADDD Y4, Y0, Y0
|
||||
VPMOVZXWD 144(DX), Y4
|
||||
VPADDD Y4, Y1, Y1
|
||||
VPMOVZXWD 160(DX), Y4
|
||||
VPADDD Y4, Y2, Y2
|
||||
VPMOVZXWD 176(DX), Y4
|
||||
VPADDD Y4, Y3, Y3
|
||||
VPMOVZXWD 192(DX), Y4
|
||||
VPADDD Y4, Y0, Y0
|
||||
VPMOVZXWD 208(DX), Y4
|
||||
VPADDD Y4, Y1, Y1
|
||||
VPMOVZXWD 224(DX), Y4
|
||||
VPADDD Y4, Y2, Y2
|
||||
VPMOVZXWD 240(DX), Y4
|
||||
VPADDD Y4, Y3, Y3
|
||||
ADDQ $0x00000100, DX
|
||||
DECQ CX
|
||||
JNZ bigLoop
|
||||
CMPQ BX, $0x10
|
||||
JLT doneSmallLoop
|
||||
|
||||
// now read a single 16 byte unit of data at a time
|
||||
smallLoop:
|
||||
VPMOVZXWD (DX), Y4
|
||||
VPADDD Y4, Y0, Y0
|
||||
ADDQ $0x10, DX
|
||||
SUBQ $0x10, BX
|
||||
CMPQ BX, $0x10
|
||||
JGE smallLoop
|
||||
|
||||
doneSmallLoop:
|
||||
CMPQ BX, $0x00
|
||||
JE doneSIMD
|
||||
|
||||
// There are between 1 and 15 bytes remaining. Perform an overlapped read.
|
||||
LEAQ xmmLoadMasks<>+0(SB), CX
|
||||
VMOVDQU -16(DX)(BX*1), X4
|
||||
VPAND -16(CX)(BX*8), X4, X4
|
||||
VPMOVZXWD X4, Y4
|
||||
VPADDD Y4, Y0, Y0
|
||||
|
||||
doneSIMD:
|
||||
// Multi-chain loop is done, combine the accumulators
|
||||
VPADDD Y1, Y0, Y0
|
||||
VPADDD Y2, Y0, Y0
|
||||
VPADDD Y3, Y0, Y0
|
||||
|
||||
// extract the YMM into a pair of XMM and sum them
|
||||
VEXTRACTI128 $0x01, Y0, X1
|
||||
VPADDD X0, X1, X0
|
||||
|
||||
// extract the XMM into GP64
|
||||
VPEXTRQ $0x00, X0, CX
|
||||
VPEXTRQ $0x01, X0, DX
|
||||
|
||||
// no more AVX code, clear upper registers to avoid SSE slowdowns
|
||||
VZEROUPPER
|
||||
ADDQ CX, AX
|
||||
ADCQ DX, AX
|
||||
|
||||
foldAndReturn:
|
||||
// add CF and fold
|
||||
RORXQ $0x20, AX, CX
|
||||
ADCL CX, AX
|
||||
RORXL $0x10, AX, CX
|
||||
ADCW CX, AX
|
||||
ADCW $0x00, AX
|
||||
XCHGB AH, AL
|
||||
MOVW AX, ret+32(FP)
|
||||
RET
|
||||
|
||||
// func checksumSSE2(b []byte, initial uint16) uint16
|
||||
// Requires: SSE2
|
||||
TEXT ·checksumSSE2(SB), NOSPLIT|NOFRAME, $0-34
|
||||
MOVWQZX initial+24(FP), AX
|
||||
XCHGB AH, AL
|
||||
MOVQ b_base+0(FP), DX
|
||||
MOVQ b_len+8(FP), BX
|
||||
|
||||
// handle odd length buffers; they are difficult to handle in general
|
||||
TESTQ $0x00000001, BX
|
||||
JZ lengthIsEven
|
||||
MOVBQZX -1(DX)(BX*1), CX
|
||||
DECQ BX
|
||||
ADDQ CX, AX
|
||||
|
||||
lengthIsEven:
|
||||
// handle tiny buffers (<=31 bytes) specially
|
||||
CMPQ BX, $0x1f
|
||||
JGT bufferIsNotTiny
|
||||
XORQ CX, CX
|
||||
XORQ SI, SI
|
||||
XORQ DI, DI
|
||||
|
||||
// shift twice to start because length is guaranteed to be even
|
||||
// n = n >> 2; CF = originalN & 2
|
||||
SHRQ $0x02, BX
|
||||
JNC handleTiny4
|
||||
|
||||
// tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:]
|
||||
MOVWQZX (DX), CX
|
||||
ADDQ $0x02, DX
|
||||
|
||||
handleTiny4:
|
||||
// n = n >> 1; CF = originalN & 4
|
||||
SHRQ $0x01, BX
|
||||
JNC handleTiny8
|
||||
|
||||
// tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:]
|
||||
MOVLQZX (DX), SI
|
||||
ADDQ $0x04, DX
|
||||
|
||||
handleTiny8:
|
||||
// n = n >> 1; CF = originalN & 8
|
||||
SHRQ $0x01, BX
|
||||
JNC handleTiny16
|
||||
|
||||
// tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:]
|
||||
MOVQ (DX), DI
|
||||
ADDQ $0x08, DX
|
||||
|
||||
handleTiny16:
|
||||
// n = n >> 1; CF = originalN & 16
|
||||
// n == 0 now, otherwise we would have branched after comparing with tinyBufferSize
|
||||
SHRQ $0x01, BX
|
||||
JNC handleTinyFinish
|
||||
ADDQ (DX), AX
|
||||
ADCQ 8(DX), AX
|
||||
|
||||
handleTinyFinish:
|
||||
// CF should be included from the previous add, so we use ADCQ.
|
||||
// If we arrived via the JNC above, then CF=0 due to the branch condition,
|
||||
// so ADCQ will still produce the correct result.
|
||||
ADCQ CX, AX
|
||||
ADCQ SI, AX
|
||||
ADCQ DI, AX
|
||||
JMP foldAndReturn
|
||||
|
||||
bufferIsNotTiny:
|
||||
// skip all SIMD for small buffers
|
||||
CMPQ BX, $0x00000100
|
||||
JGE startSIMD
|
||||
|
||||
// Accumulate carries in this register. It is never expected to overflow.
|
||||
XORQ SI, SI
|
||||
|
||||
// We will perform an overlapped read for buffers with length not a multiple of 8.
|
||||
// Overlapped in this context means some memory will be read twice, but a shift will
|
||||
// eliminate the duplicated data. This extra read is performed at the end of the buffer to
|
||||
// preserve any alignment that may exist for the start of the buffer.
|
||||
MOVQ BX, CX
|
||||
SHRQ $0x03, BX
|
||||
ANDQ $0x07, CX
|
||||
JZ handleRemaining8
|
||||
LEAQ (DX)(BX*8), DI
|
||||
MOVQ -8(DI)(CX*1), DI
|
||||
|
||||
// Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8)
|
||||
SHLQ $0x03, CX
|
||||
NEGQ CX
|
||||
ADDQ $0x40, CX
|
||||
SHRQ CL, DI
|
||||
ADDQ DI, AX
|
||||
ADCQ $0x00, SI
|
||||
|
||||
handleRemaining8:
|
||||
SHRQ $0x01, BX
|
||||
JNC handleRemaining16
|
||||
ADDQ (DX), AX
|
||||
ADCQ $0x00, SI
|
||||
ADDQ $0x08, DX
|
||||
|
||||
handleRemaining16:
|
||||
SHRQ $0x01, BX
|
||||
JNC handleRemaining32
|
||||
ADDQ (DX), AX
|
||||
ADCQ 8(DX), AX
|
||||
ADCQ $0x00, SI
|
||||
ADDQ $0x10, DX
|
||||
|
||||
handleRemaining32:
|
||||
SHRQ $0x01, BX
|
||||
JNC handleRemaining64
|
||||
ADDQ (DX), AX
|
||||
ADCQ 8(DX), AX
|
||||
ADCQ 16(DX), AX
|
||||
ADCQ 24(DX), AX
|
||||
ADCQ $0x00, SI
|
||||
ADDQ $0x20, DX
|
||||
|
||||
handleRemaining64:
|
||||
SHRQ $0x01, BX
|
||||
JNC handleRemaining128
|
||||
ADDQ (DX), AX
|
||||
ADCQ 8(DX), AX
|
||||
ADCQ 16(DX), AX
|
||||
ADCQ 24(DX), AX
|
||||
ADCQ 32(DX), AX
|
||||
ADCQ 40(DX), AX
|
||||
ADCQ 48(DX), AX
|
||||
ADCQ 56(DX), AX
|
||||
ADCQ $0x00, SI
|
||||
ADDQ $0x40, DX
|
||||
|
||||
handleRemaining128:
|
||||
SHRQ $0x01, BX
|
||||
JNC handleRemainingComplete
|
||||
ADDQ (DX), AX
|
||||
ADCQ 8(DX), AX
|
||||
ADCQ 16(DX), AX
|
||||
ADCQ 24(DX), AX
|
||||
ADCQ 32(DX), AX
|
||||
ADCQ 40(DX), AX
|
||||
ADCQ 48(DX), AX
|
||||
ADCQ 56(DX), AX
|
||||
ADCQ 64(DX), AX
|
||||
ADCQ 72(DX), AX
|
||||
ADCQ 80(DX), AX
|
||||
ADCQ 88(DX), AX
|
||||
ADCQ 96(DX), AX
|
||||
ADCQ 104(DX), AX
|
||||
ADCQ 112(DX), AX
|
||||
ADCQ 120(DX), AX
|
||||
ADCQ $0x00, SI
|
||||
ADDQ $0x80, DX
|
||||
|
||||
handleRemainingComplete:
|
||||
ADDQ SI, AX
|
||||
JMP foldAndReturn
|
||||
|
||||
startSIMD:
|
||||
PXOR X0, X0
|
||||
PXOR X1, X1
|
||||
PXOR X2, X2
|
||||
PXOR X3, X3
|
||||
PXOR X4, X4
|
||||
MOVQ BX, CX
|
||||
|
||||
// Update number of bytes remaining after the loop completes
|
||||
ANDQ $0xff, BX
|
||||
|
||||
// Number of 256 byte iterations
|
||||
SHRQ $0x08, CX
|
||||
JZ smallLoop
|
||||
|
||||
bigLoop:
|
||||
MOVOU (DX), X5
|
||||
MOVOA X5, X6
|
||||
PUNPCKHWL X4, X5
|
||||
PUNPCKLWL X4, X6
|
||||
PADDD X5, X0
|
||||
PADDD X6, X2
|
||||
MOVOU 16(DX), X5
|
||||
MOVOA X5, X6
|
||||
PUNPCKHWL X4, X5
|
||||
PUNPCKLWL X4, X6
|
||||
PADDD X5, X1
|
||||
PADDD X6, X3
|
||||
MOVOU 32(DX), X5
|
||||
MOVOA X5, X6
|
||||
PUNPCKHWL X4, X5
|
||||
PUNPCKLWL X4, X6
|
||||
PADDD X5, X2
|
||||
PADDD X6, X0
|
||||
MOVOU 48(DX), X5
|
||||
MOVOA X5, X6
|
||||
PUNPCKHWL X4, X5
|
||||
PUNPCKLWL X4, X6
|
||||
PADDD X5, X3
|
||||
PADDD X6, X1
|
||||
MOVOU 64(DX), X5
|
||||
MOVOA X5, X6
|
||||
PUNPCKHWL X4, X5
|
||||
PUNPCKLWL X4, X6
|
||||
PADDD X5, X0
|
||||
PADDD X6, X2
|
||||
MOVOU 80(DX), X5
|
||||
MOVOA X5, X6
|
||||
PUNPCKHWL X4, X5
|
||||
PUNPCKLWL X4, X6
|
||||
PADDD X5, X1
|
||||
PADDD X6, X3
|
||||
MOVOU 96(DX), X5
|
||||
MOVOA X5, X6
|
||||
PUNPCKHWL X4, X5
|
||||
PUNPCKLWL X4, X6
|
||||
PADDD X5, X2
|
||||
PADDD X6, X0
|
||||
MOVOU 112(DX), X5
|
||||
MOVOA X5, X6
|
||||
PUNPCKHWL X4, X5
|
||||
PUNPCKLWL X4, X6
|
||||
PADDD X5, X3
|
||||
PADDD X6, X1
|
||||
MOVOU 128(DX), X5
|
||||
MOVOA X5, X6
|
||||
PUNPCKHWL X4, X5
|
||||
PUNPCKLWL X4, X6
|
||||
PADDD X5, X0
|
||||
PADDD X6, X2
|
||||
MOVOU 144(DX), X5
|
||||
MOVOA X5, X6
|
||||
PUNPCKHWL X4, X5
|
||||
PUNPCKLWL X4, X6
|
||||
PADDD X5, X1
|
||||
PADDD X6, X3
|
||||
MOVOU 160(DX), X5
|
||||
MOVOA X5, X6
|
||||
PUNPCKHWL X4, X5
|
||||
PUNPCKLWL X4, X6
|
||||
PADDD X5, X2
|
||||
PADDD X6, X0
|
||||
MOVOU 176(DX), X5
|
||||
MOVOA X5, X6
|
||||
PUNPCKHWL X4, X5
|
||||
PUNPCKLWL X4, X6
|
||||
PADDD X5, X3
|
||||
PADDD X6, X1
|
||||
MOVOU 192(DX), X5
|
||||
MOVOA X5, X6
|
||||
PUNPCKHWL X4, X5
|
||||
PUNPCKLWL X4, X6
|
||||
PADDD X5, X0
|
||||
PADDD X6, X2
|
||||
MOVOU 208(DX), X5
|
||||
MOVOA X5, X6
|
||||
PUNPCKHWL X4, X5
|
||||
PUNPCKLWL X4, X6
|
||||
PADDD X5, X1
|
||||
PADDD X6, X3
|
||||
MOVOU 224(DX), X5
|
||||
MOVOA X5, X6
|
||||
PUNPCKHWL X4, X5
|
||||
PUNPCKLWL X4, X6
|
||||
PADDD X5, X2
|
||||
PADDD X6, X0
|
||||
MOVOU 240(DX), X5
|
||||
MOVOA X5, X6
|
||||
PUNPCKHWL X4, X5
|
||||
PUNPCKLWL X4, X6
|
||||
PADDD X5, X3
|
||||
PADDD X6, X1
|
||||
ADDQ $0x00000100, DX
|
||||
DECQ CX
|
||||
JNZ bigLoop
|
||||
CMPQ BX, $0x10
|
||||
JLT doneSmallLoop
|
||||
|
||||
// now read a single 16 byte unit of data at a time
|
||||
smallLoop:
|
||||
MOVOU (DX), X5
|
||||
MOVOA X5, X6
|
||||
PUNPCKHWL X4, X5
|
||||
PUNPCKLWL X4, X6
|
||||
PADDD X5, X0
|
||||
PADDD X6, X1
|
||||
ADDQ $0x10, DX
|
||||
SUBQ $0x10, BX
|
||||
CMPQ BX, $0x10
|
||||
JGE smallLoop
|
||||
|
||||
doneSmallLoop:
|
||||
CMPQ BX, $0x00
|
||||
JE doneSIMD
|
||||
|
||||
// There are between 1 and 15 bytes remaining. Perform an overlapped read.
|
||||
LEAQ xmmLoadMasks<>+0(SB), CX
|
||||
MOVOU -16(DX)(BX*1), X5
|
||||
PAND -16(CX)(BX*8), X5
|
||||
MOVOA X5, X6
|
||||
PUNPCKHWL X4, X5
|
||||
PUNPCKLWL X4, X6
|
||||
PADDD X5, X0
|
||||
PADDD X6, X1
|
||||
|
||||
doneSIMD:
|
||||
// Multi-chain loop is done, combine the accumulators
|
||||
PADDD X1, X0
|
||||
PADDD X2, X0
|
||||
PADDD X3, X0
|
||||
|
||||
// extract the XMM into GP64
|
||||
MOVQ X0, CX
|
||||
PSRLDQ $0x08, X0
|
||||
MOVQ X0, DX
|
||||
ADDQ CX, AX
|
||||
ADCQ DX, AX
|
||||
|
||||
foldAndReturn:
|
||||
// add CF and fold
|
||||
MOVL AX, CX
|
||||
ADCQ $0x00, CX
|
||||
SHRQ $0x20, AX
|
||||
ADDQ CX, AX
|
||||
MOVWQZX AX, CX
|
||||
SHRQ $0x10, AX
|
||||
ADDQ CX, AX
|
||||
MOVW AX, CX
|
||||
SHRQ $0x10, AX
|
||||
ADDW CX, AX
|
||||
ADCW $0x00, AX
|
||||
XCHGB AH, AL
|
||||
MOVW AX, ret+32(FP)
|
||||
RET
|
||||
|
||||
// func checksumAMD64(b []byte, initial uint16) uint16
|
||||
TEXT ·checksumAMD64(SB), NOSPLIT|NOFRAME, $0-34
|
||||
MOVWQZX initial+24(FP), AX
|
||||
XCHGB AH, AL
|
||||
MOVQ b_base+0(FP), DX
|
||||
MOVQ b_len+8(FP), BX
|
||||
|
||||
// handle odd length buffers; they are difficult to handle in general
|
||||
TESTQ $0x00000001, BX
|
||||
JZ lengthIsEven
|
||||
MOVBQZX -1(DX)(BX*1), CX
|
||||
DECQ BX
|
||||
ADDQ CX, AX
|
||||
|
||||
lengthIsEven:
|
||||
// handle tiny buffers (<=31 bytes) specially
|
||||
CMPQ BX, $0x1f
|
||||
JGT bufferIsNotTiny
|
||||
XORQ CX, CX
|
||||
XORQ SI, SI
|
||||
XORQ DI, DI
|
||||
|
||||
// shift twice to start because length is guaranteed to be even
|
||||
// n = n >> 2; CF = originalN & 2
|
||||
SHRQ $0x02, BX
|
||||
JNC handleTiny4
|
||||
|
||||
// tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:]
|
||||
MOVWQZX (DX), CX
|
||||
ADDQ $0x02, DX
|
||||
|
||||
handleTiny4:
|
||||
// n = n >> 1; CF = originalN & 4
|
||||
SHRQ $0x01, BX
|
||||
JNC handleTiny8
|
||||
|
||||
// tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:]
|
||||
MOVLQZX (DX), SI
|
||||
ADDQ $0x04, DX
|
||||
|
||||
handleTiny8:
|
||||
// n = n >> 1; CF = originalN & 8
|
||||
SHRQ $0x01, BX
|
||||
JNC handleTiny16
|
||||
|
||||
// tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:]
|
||||
MOVQ (DX), DI
|
||||
ADDQ $0x08, DX
|
||||
|
||||
handleTiny16:
|
||||
// n = n >> 1; CF = originalN & 16
|
||||
// n == 0 now, otherwise we would have branched after comparing with tinyBufferSize
|
||||
SHRQ $0x01, BX
|
||||
JNC handleTinyFinish
|
||||
ADDQ (DX), AX
|
||||
ADCQ 8(DX), AX
|
||||
|
||||
handleTinyFinish:
|
||||
// CF should be included from the previous add, so we use ADCQ.
|
||||
// If we arrived via the JNC above, then CF=0 due to the branch condition,
|
||||
// so ADCQ will still produce the correct result.
|
||||
ADCQ CX, AX
|
||||
ADCQ SI, AX
|
||||
ADCQ DI, AX
|
||||
JMP foldAndReturn
|
||||
|
||||
bufferIsNotTiny:
|
||||
// Number of 256 byte iterations into loop counter
|
||||
MOVQ BX, CX
|
||||
|
||||
// Update number of bytes remaining after the loop completes
|
||||
ANDQ $0xff, BX
|
||||
SHRQ $0x08, CX
|
||||
JZ startCleanup
|
||||
CLC
|
||||
XORQ SI, SI
|
||||
XORQ DI, DI
|
||||
XORQ R8, R8
|
||||
XORQ R9, R9
|
||||
XORQ R10, R10
|
||||
XORQ R11, R11
|
||||
XORQ R12, R12
|
||||
|
||||
bigLoop:
|
||||
ADDQ (DX), AX
|
||||
ADCQ 8(DX), AX
|
||||
ADCQ 16(DX), AX
|
||||
ADCQ 24(DX), AX
|
||||
ADCQ $0x00, SI
|
||||
ADDQ 32(DX), DI
|
||||
ADCQ 40(DX), DI
|
||||
ADCQ 48(DX), DI
|
||||
ADCQ 56(DX), DI
|
||||
ADCQ $0x00, R8
|
||||
ADDQ 64(DX), R9
|
||||
ADCQ 72(DX), R9
|
||||
ADCQ 80(DX), R9
|
||||
ADCQ 88(DX), R9
|
||||
ADCQ $0x00, R10
|
||||
ADDQ 96(DX), R11
|
||||
ADCQ 104(DX), R11
|
||||
ADCQ 112(DX), R11
|
||||
ADCQ 120(DX), R11
|
||||
ADCQ $0x00, R12
|
||||
ADDQ 128(DX), AX
|
||||
ADCQ 136(DX), AX
|
||||
ADCQ 144(DX), AX
|
||||
ADCQ 152(DX), AX
|
||||
ADCQ $0x00, SI
|
||||
ADDQ 160(DX), DI
|
||||
ADCQ 168(DX), DI
|
||||
ADCQ 176(DX), DI
|
||||
ADCQ 184(DX), DI
|
||||
ADCQ $0x00, R8
|
||||
ADDQ 192(DX), R9
|
||||
ADCQ 200(DX), R9
|
||||
ADCQ 208(DX), R9
|
||||
ADCQ 216(DX), R9
|
||||
ADCQ $0x00, R10
|
||||
ADDQ 224(DX), R11
|
||||
ADCQ 232(DX), R11
|
||||
ADCQ 240(DX), R11
|
||||
ADCQ 248(DX), R11
|
||||
ADCQ $0x00, R12
|
||||
ADDQ $0x00000100, DX
|
||||
SUBQ $0x01, CX
|
||||
JNZ bigLoop
|
||||
ADDQ SI, AX
|
||||
ADCQ DI, AX
|
||||
ADCQ R8, AX
|
||||
ADCQ R9, AX
|
||||
ADCQ R10, AX
|
||||
ADCQ R11, AX
|
||||
ADCQ R12, AX
|
||||
|
||||
// accumulate CF (twice, in case the first time overflows)
|
||||
ADCQ $0x00, AX
|
||||
ADCQ $0x00, AX
|
||||
|
||||
startCleanup:
|
||||
// Accumulate carries in this register. It is never expected to overflow.
|
||||
XORQ SI, SI
|
||||
|
||||
// We will perform an overlapped read for buffers with length not a multiple of 8.
|
||||
// Overlapped in this context means some memory will be read twice, but a shift will
|
||||
// eliminate the duplicated data. This extra read is performed at the end of the buffer to
|
||||
// preserve any alignment that may exist for the start of the buffer.
|
||||
MOVQ BX, CX
|
||||
SHRQ $0x03, BX
|
||||
ANDQ $0x07, CX
|
||||
JZ handleRemaining8
|
||||
LEAQ (DX)(BX*8), DI
|
||||
MOVQ -8(DI)(CX*1), DI
|
||||
|
||||
// Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8)
|
||||
SHLQ $0x03, CX
|
||||
NEGQ CX
|
||||
ADDQ $0x40, CX
|
||||
SHRQ CL, DI
|
||||
ADDQ DI, AX
|
||||
ADCQ $0x00, SI
|
||||
|
||||
handleRemaining8:
|
||||
SHRQ $0x01, BX
|
||||
JNC handleRemaining16
|
||||
ADDQ (DX), AX
|
||||
ADCQ $0x00, SI
|
||||
ADDQ $0x08, DX
|
||||
|
||||
handleRemaining16:
|
||||
SHRQ $0x01, BX
|
||||
JNC handleRemaining32
|
||||
ADDQ (DX), AX
|
||||
ADCQ 8(DX), AX
|
||||
ADCQ $0x00, SI
|
||||
ADDQ $0x10, DX
|
||||
|
||||
handleRemaining32:
|
||||
SHRQ $0x01, BX
|
||||
JNC handleRemaining64
|
||||
ADDQ (DX), AX
|
||||
ADCQ 8(DX), AX
|
||||
ADCQ 16(DX), AX
|
||||
ADCQ 24(DX), AX
|
||||
ADCQ $0x00, SI
|
||||
ADDQ $0x20, DX
|
||||
|
||||
handleRemaining64:
|
||||
SHRQ $0x01, BX
|
||||
JNC handleRemaining128
|
||||
ADDQ (DX), AX
|
||||
ADCQ 8(DX), AX
|
||||
ADCQ 16(DX), AX
|
||||
ADCQ 24(DX), AX
|
||||
ADCQ 32(DX), AX
|
||||
ADCQ 40(DX), AX
|
||||
ADCQ 48(DX), AX
|
||||
ADCQ 56(DX), AX
|
||||
ADCQ $0x00, SI
|
||||
ADDQ $0x40, DX
|
||||
|
||||
handleRemaining128:
|
||||
SHRQ $0x01, BX
|
||||
JNC handleRemainingComplete
|
||||
ADDQ (DX), AX
|
||||
ADCQ 8(DX), AX
|
||||
ADCQ 16(DX), AX
|
||||
ADCQ 24(DX), AX
|
||||
ADCQ 32(DX), AX
|
||||
ADCQ 40(DX), AX
|
||||
ADCQ 48(DX), AX
|
||||
ADCQ 56(DX), AX
|
||||
ADCQ 64(DX), AX
|
||||
ADCQ 72(DX), AX
|
||||
ADCQ 80(DX), AX
|
||||
ADCQ 88(DX), AX
|
||||
ADCQ 96(DX), AX
|
||||
ADCQ 104(DX), AX
|
||||
ADCQ 112(DX), AX
|
||||
ADCQ 120(DX), AX
|
||||
ADCQ $0x00, SI
|
||||
ADDQ $0x80, DX
|
||||
|
||||
handleRemainingComplete:
|
||||
ADDQ SI, AX
|
||||
|
||||
foldAndReturn:
|
||||
// add CF and fold
|
||||
MOVL AX, CX
|
||||
ADCQ $0x00, CX
|
||||
SHRQ $0x20, AX
|
||||
ADDQ CX, AX
|
||||
MOVWQZX AX, CX
|
||||
SHRQ $0x10, AX
|
||||
ADDQ CX, AX
|
||||
MOVW AX, CX
|
||||
SHRQ $0x10, AX
|
||||
ADDW CX, AX
|
||||
ADCW $0x00, AX
|
||||
XCHGB AH, AL
|
||||
MOVW AX, ret+32(FP)
|
||||
RET
|
||||
@@ -1,15 +0,0 @@
|
||||
// This file contains IP checksum algorithms that are not specific to any
|
||||
// architecture and don't use hardware acceleration.
|
||||
|
||||
//go:build !amd64
|
||||
|
||||
package tschecksum
|
||||
|
||||
import "strconv"
|
||||
|
||||
func Checksum(data []byte, initial uint16) uint16 {
|
||||
if strconv.IntSize < 64 {
|
||||
return checksumGeneric32(data, initial)
|
||||
}
|
||||
return checksumGeneric64(data, initial)
|
||||
}
|
||||
@@ -1,578 +0,0 @@
|
||||
//go:build ignore
|
||||
|
||||
//go:generate go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"math/bits"
|
||||
|
||||
"github.com/mmcloughlin/avo/operand"
|
||||
"github.com/mmcloughlin/avo/reg"
|
||||
)
|
||||
|
||||
const checksumSignature = "func(b []byte, initial uint16) uint16"
|
||||
|
||||
func loadParams() (accum, buf, n reg.GPVirtual) {
|
||||
accum, buf, n = GP64(), GP64(), GP64()
|
||||
Load(Param("initial"), accum)
|
||||
XCHGB(accum.As8H(), accum.As8L())
|
||||
Load(Param("b").Base(), buf)
|
||||
Load(Param("b").Len(), n)
|
||||
return
|
||||
}
|
||||
|
||||
type simdStrategy int
|
||||
|
||||
const (
|
||||
sse2 = iota
|
||||
avx2
|
||||
)
|
||||
|
||||
const tinyBufferSize = 31 // A buffer is tiny if it has at most 31 bytes.
|
||||
|
||||
func generateSIMDChecksum(name, doc string, minSIMDSize, chains int, strategy simdStrategy) {
|
||||
TEXT(name, NOSPLIT|NOFRAME, checksumSignature)
|
||||
Pragma("noescape")
|
||||
Doc(doc)
|
||||
|
||||
accum64, buf, n := loadParams()
|
||||
|
||||
handleOddLength(n, buf, accum64)
|
||||
// no chance of overflow because accum64 was initialized by a uint16 and
|
||||
// handleOddLength adds at most a uint8
|
||||
handleTinyBuffers(n, buf, accum64, operand.LabelRef("foldAndReturn"), operand.LabelRef("bufferIsNotTiny"))
|
||||
Label("bufferIsNotTiny")
|
||||
|
||||
const simdReadSize = 16
|
||||
|
||||
if minSIMDSize > tinyBufferSize {
|
||||
Comment("skip all SIMD for small buffers")
|
||||
if minSIMDSize <= math.MaxUint8 {
|
||||
CMPQ(n, operand.U8(minSIMDSize))
|
||||
} else {
|
||||
CMPQ(n, operand.U32(minSIMDSize))
|
||||
}
|
||||
JGE(operand.LabelRef("startSIMD"))
|
||||
|
||||
handleRemaining(n, buf, accum64, minSIMDSize-1)
|
||||
JMP(operand.LabelRef("foldAndReturn"))
|
||||
}
|
||||
|
||||
Label("startSIMD")
|
||||
|
||||
// chains is the number of accumulators to use. This improves speed via
|
||||
// reduced data dependency. We combine the accumulators once when the big
|
||||
// loop is complete.
|
||||
simdAccumulate := make([]reg.VecVirtual, chains)
|
||||
for i := range simdAccumulate {
|
||||
switch strategy {
|
||||
case sse2:
|
||||
simdAccumulate[i] = XMM()
|
||||
PXOR(simdAccumulate[i], simdAccumulate[i])
|
||||
case avx2:
|
||||
simdAccumulate[i] = YMM()
|
||||
VPXOR(simdAccumulate[i], simdAccumulate[i], simdAccumulate[i])
|
||||
}
|
||||
}
|
||||
var zero reg.VecVirtual
|
||||
if strategy == sse2 {
|
||||
zero = XMM()
|
||||
PXOR(zero, zero)
|
||||
}
|
||||
|
||||
// Number of loads per big loop
|
||||
const unroll = 16
|
||||
// Number of bytes
|
||||
loopSize := uint64(simdReadSize * unroll)
|
||||
if bits.Len64(loopSize) != bits.Len64(loopSize-1)+1 {
|
||||
panic("loopSize is not a power of 2")
|
||||
}
|
||||
loopCount := GP64()
|
||||
|
||||
MOVQ(n, loopCount)
|
||||
Comment("Update number of bytes remaining after the loop completes")
|
||||
ANDQ(operand.Imm(loopSize-1), n)
|
||||
Comment(fmt.Sprintf("Number of %d byte iterations", loopSize))
|
||||
SHRQ(operand.Imm(uint64(bits.Len64(loopSize-1))), loopCount)
|
||||
JZ(operand.LabelRef("smallLoop"))
|
||||
Label("bigLoop")
|
||||
for i := 0; i < unroll; i++ {
|
||||
chain := i % chains
|
||||
switch strategy {
|
||||
case sse2:
|
||||
sse2AccumulateStep(i*simdReadSize, buf, zero, simdAccumulate[chain], simdAccumulate[(chain+chains/2)%chains])
|
||||
case avx2:
|
||||
avx2AccumulateStep(i*simdReadSize, buf, simdAccumulate[chain])
|
||||
}
|
||||
}
|
||||
ADDQ(operand.U32(loopSize), buf)
|
||||
DECQ(loopCount)
|
||||
JNZ(operand.LabelRef("bigLoop"))
|
||||
|
||||
Label("bigCleanup")
|
||||
|
||||
CMPQ(n, operand.Imm(uint64(simdReadSize)))
|
||||
JLT(operand.LabelRef("doneSmallLoop"))
|
||||
|
||||
Commentf("now read a single %d byte unit of data at a time", simdReadSize)
|
||||
Label("smallLoop")
|
||||
|
||||
switch strategy {
|
||||
case sse2:
|
||||
sse2AccumulateStep(0, buf, zero, simdAccumulate[0], simdAccumulate[1])
|
||||
case avx2:
|
||||
avx2AccumulateStep(0, buf, simdAccumulate[0])
|
||||
}
|
||||
ADDQ(operand.Imm(uint64(simdReadSize)), buf)
|
||||
SUBQ(operand.Imm(uint64(simdReadSize)), n)
|
||||
CMPQ(n, operand.Imm(uint64(simdReadSize)))
|
||||
JGE(operand.LabelRef("smallLoop"))
|
||||
|
||||
Label("doneSmallLoop")
|
||||
CMPQ(n, operand.Imm(0))
|
||||
JE(operand.LabelRef("doneSIMD"))
|
||||
|
||||
Commentf("There are between 1 and %d bytes remaining. Perform an overlapped read.", simdReadSize-1)
|
||||
|
||||
maskDataPtr := GP64()
|
||||
LEAQ(operand.NewDataAddr(operand.NewStaticSymbol("xmmLoadMasks"), 0), maskDataPtr)
|
||||
dataAddr := operand.Mem{Index: n, Scale: 1, Base: buf, Disp: -simdReadSize}
|
||||
// scale 8 is only correct here because n is guaranteed to be even and we
|
||||
// do not generate masks for odd lengths
|
||||
maskAddr := operand.Mem{Base: maskDataPtr, Index: n, Scale: 8, Disp: -16}
|
||||
remainder := XMM()
|
||||
|
||||
switch strategy {
|
||||
case sse2:
|
||||
MOVOU(dataAddr, remainder)
|
||||
PAND(maskAddr, remainder)
|
||||
low := XMM()
|
||||
MOVOA(remainder, low)
|
||||
PUNPCKHWL(zero, remainder)
|
||||
PUNPCKLWL(zero, low)
|
||||
PADDD(remainder, simdAccumulate[0])
|
||||
PADDD(low, simdAccumulate[1])
|
||||
case avx2:
|
||||
// Note: this is very similar to the sse2 path but MOVOU has a massive
|
||||
// performance hit if used here, presumably due to switching between SSE
|
||||
// and AVX2 modes.
|
||||
VMOVDQU(dataAddr, remainder)
|
||||
VPAND(maskAddr, remainder, remainder)
|
||||
|
||||
temp := YMM()
|
||||
VPMOVZXWD(remainder, temp)
|
||||
VPADDD(temp, simdAccumulate[0], simdAccumulate[0])
|
||||
}
|
||||
|
||||
Label("doneSIMD")
|
||||
|
||||
Comment("Multi-chain loop is done, combine the accumulators")
|
||||
for i := range simdAccumulate {
|
||||
if i == 0 {
|
||||
continue
|
||||
}
|
||||
switch strategy {
|
||||
case sse2:
|
||||
PADDD(simdAccumulate[i], simdAccumulate[0])
|
||||
case avx2:
|
||||
VPADDD(simdAccumulate[i], simdAccumulate[0], simdAccumulate[0])
|
||||
}
|
||||
}
|
||||
|
||||
if strategy == avx2 {
|
||||
Comment("extract the YMM into a pair of XMM and sum them")
|
||||
tmp := YMM()
|
||||
VEXTRACTI128(operand.Imm(1), simdAccumulate[0], tmp.AsX())
|
||||
|
||||
xAccumulate := XMM()
|
||||
VPADDD(simdAccumulate[0].AsX(), tmp.AsX(), xAccumulate)
|
||||
simdAccumulate = []reg.VecVirtual{xAccumulate}
|
||||
}
|
||||
|
||||
Comment("extract the XMM into GP64")
|
||||
low, high := GP64(), GP64()
|
||||
switch strategy {
|
||||
case sse2:
|
||||
MOVQ(simdAccumulate[0], low)
|
||||
PSRLDQ(operand.Imm(8), simdAccumulate[0])
|
||||
MOVQ(simdAccumulate[0], high)
|
||||
case avx2:
|
||||
VPEXTRQ(operand.Imm(0), simdAccumulate[0], low)
|
||||
VPEXTRQ(operand.Imm(1), simdAccumulate[0], high)
|
||||
|
||||
Comment("no more AVX code, clear upper registers to avoid SSE slowdowns")
|
||||
VZEROUPPER()
|
||||
}
|
||||
ADDQ(low, accum64)
|
||||
ADCQ(high, accum64)
|
||||
Label("foldAndReturn")
|
||||
foldWithCF(accum64, strategy == avx2)
|
||||
XCHGB(accum64.As8H(), accum64.As8L())
|
||||
Store(accum64.As16(), ReturnIndex(0))
|
||||
RET()
|
||||
}
|
||||
|
||||
// handleOddLength generates instructions to incorporate the last byte into
|
||||
// accum64 if the length is odd. CF may be set if accum64 overflows; be sure to
|
||||
// handle that if overflow is possible.
|
||||
func handleOddLength(n, buf, accum64 reg.GPVirtual) {
|
||||
Comment("handle odd length buffers; they are difficult to handle in general")
|
||||
TESTQ(operand.U32(1), n)
|
||||
JZ(operand.LabelRef("lengthIsEven"))
|
||||
|
||||
tmp := GP64()
|
||||
MOVBQZX(operand.Mem{Base: buf, Index: n, Scale: 1, Disp: -1}, tmp)
|
||||
DECQ(n)
|
||||
ADDQ(tmp, accum64)
|
||||
|
||||
Label("lengthIsEven")
|
||||
}
|
||||
|
||||
func sse2AccumulateStep(offset int, buf reg.GPVirtual, zero, accumulate1, accumulate2 reg.VecVirtual) {
|
||||
high, low := XMM(), XMM()
|
||||
MOVOU(operand.Mem{Disp: offset, Base: buf}, high)
|
||||
MOVOA(high, low)
|
||||
PUNPCKHWL(zero, high)
|
||||
PUNPCKLWL(zero, low)
|
||||
PADDD(high, accumulate1)
|
||||
PADDD(low, accumulate2)
|
||||
}
|
||||
|
||||
func avx2AccumulateStep(offset int, buf reg.GPVirtual, accumulate reg.VecVirtual) {
|
||||
tmp := YMM()
|
||||
VPMOVZXWD(operand.Mem{Disp: offset, Base: buf}, tmp)
|
||||
VPADDD(tmp, accumulate, accumulate)
|
||||
}
|
||||
|
||||
func generateAMD64Checksum(name, doc string) {
|
||||
TEXT(name, NOSPLIT|NOFRAME, checksumSignature)
|
||||
Pragma("noescape")
|
||||
Doc(doc)
|
||||
|
||||
accum64, buf, n := loadParams()
|
||||
|
||||
handleOddLength(n, buf, accum64)
|
||||
// no chance of overflow because accum64 was initialized by a uint16 and
|
||||
// handleOddLength adds at most a uint8
|
||||
handleTinyBuffers(n, buf, accum64, operand.LabelRef("foldAndReturn"), operand.LabelRef("bufferIsNotTiny"))
|
||||
Label("bufferIsNotTiny")
|
||||
|
||||
const (
|
||||
// numChains is the number of accumulators and carry counters to use.
|
||||
// This improves speed via reduced data dependency. We combine the
|
||||
// accumulators and carry counters once when the loop is complete.
|
||||
numChains = 4
|
||||
unroll = 32 // The number of 64-bit reads to perform per iteration of the loop.
|
||||
loopSize = 8 * unroll // The number of bytes read per iteration of the loop.
|
||||
)
|
||||
if bits.Len(loopSize) != bits.Len(loopSize-1)+1 {
|
||||
panic("loopSize is not a power of 2")
|
||||
}
|
||||
loopCount := GP64()
|
||||
|
||||
Comment(fmt.Sprintf("Number of %d byte iterations into loop counter", loopSize))
|
||||
MOVQ(n, loopCount)
|
||||
Comment("Update number of bytes remaining after the loop completes")
|
||||
ANDQ(operand.Imm(loopSize-1), n)
|
||||
SHRQ(operand.Imm(uint64(bits.Len(loopSize-1))), loopCount)
|
||||
JZ(operand.LabelRef("startCleanup"))
|
||||
CLC()
|
||||
|
||||
chains := make([]struct {
|
||||
accum reg.GPVirtual
|
||||
carries reg.GPVirtual
|
||||
}, numChains)
|
||||
for i := range chains {
|
||||
if i == 0 {
|
||||
chains[i].accum = accum64
|
||||
} else {
|
||||
chains[i].accum = GP64()
|
||||
XORQ(chains[i].accum, chains[i].accum)
|
||||
}
|
||||
chains[i].carries = GP64()
|
||||
XORQ(chains[i].carries, chains[i].carries)
|
||||
}
|
||||
|
||||
Label("bigLoop")
|
||||
|
||||
var curChain int
|
||||
for i := 0; i < unroll; i++ {
|
||||
// It is significantly faster to use a ADCX/ADOX pair instead of plain
|
||||
// ADC, which results in two dependency chains, however those require
|
||||
// ADX support, which was added after AVX2. If AVX2 is available, that's
|
||||
// even better than ADCX/ADOX.
|
||||
//
|
||||
// However, multiple dependency chains using multiple accumulators and
|
||||
// occasionally storing CF into temporary counters seems to work almost
|
||||
// as well.
|
||||
addr := operand.Mem{Disp: i * 8, Base: buf}
|
||||
|
||||
if i%4 == 0 {
|
||||
if i > 0 {
|
||||
ADCQ(operand.Imm(0), chains[curChain].carries)
|
||||
curChain = (curChain + 1) % len(chains)
|
||||
}
|
||||
ADDQ(addr, chains[curChain].accum)
|
||||
} else {
|
||||
ADCQ(addr, chains[curChain].accum)
|
||||
}
|
||||
}
|
||||
ADCQ(operand.Imm(0), chains[curChain].carries)
|
||||
ADDQ(operand.U32(loopSize), buf)
|
||||
SUBQ(operand.Imm(1), loopCount)
|
||||
JNZ(operand.LabelRef("bigLoop"))
|
||||
for i := range chains {
|
||||
if i == 0 {
|
||||
ADDQ(chains[i].carries, accum64)
|
||||
continue
|
||||
}
|
||||
ADCQ(chains[i].accum, accum64)
|
||||
ADCQ(chains[i].carries, accum64)
|
||||
}
|
||||
|
||||
accumulateCF(accum64)
|
||||
|
||||
Label("startCleanup")
|
||||
handleRemaining(n, buf, accum64, loopSize-1)
|
||||
Label("foldAndReturn")
|
||||
foldWithCF(accum64, false)
|
||||
|
||||
XCHGB(accum64.As8H(), accum64.As8L())
|
||||
Store(accum64.As16(), ReturnIndex(0))
|
||||
RET()
|
||||
}
|
||||
|
||||
// handleTinyBuffers computes checksums if the buffer length (the n parameter)
|
||||
// is less than 32. After computing the checksum, a jump to returnLabel will
|
||||
// be executed. Otherwise, if the buffer length is at least 32, nothing will be
|
||||
// modified; a jump to continueLabel will be executed instead.
|
||||
//
|
||||
// When jumping to returnLabel, CF may be set and must be accommodated e.g.
|
||||
// using foldWithCF or accumulateCF.
|
||||
//
|
||||
// Anecdotally, this appears to be faster than attempting to coordinate an
|
||||
// overlapped read (which would also require special handling for buffers
|
||||
// smaller than 8).
|
||||
func handleTinyBuffers(n, buf, accum reg.GPVirtual, returnLabel, continueLabel operand.LabelRef) {
|
||||
Comment("handle tiny buffers (<=31 bytes) specially")
|
||||
CMPQ(n, operand.Imm(tinyBufferSize))
|
||||
JGT(continueLabel)
|
||||
|
||||
tmp2, tmp4, tmp8 := GP64(), GP64(), GP64()
|
||||
XORQ(tmp2, tmp2)
|
||||
XORQ(tmp4, tmp4)
|
||||
XORQ(tmp8, tmp8)
|
||||
|
||||
Comment("shift twice to start because length is guaranteed to be even",
|
||||
"n = n >> 2; CF = originalN & 2")
|
||||
SHRQ(operand.Imm(2), n)
|
||||
JNC(operand.LabelRef("handleTiny4"))
|
||||
Comment("tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:]")
|
||||
MOVWQZX(operand.Mem{Base: buf}, tmp2)
|
||||
ADDQ(operand.Imm(2), buf)
|
||||
|
||||
Label("handleTiny4")
|
||||
Comment("n = n >> 1; CF = originalN & 4")
|
||||
SHRQ(operand.Imm(1), n)
|
||||
JNC(operand.LabelRef("handleTiny8"))
|
||||
Comment("tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:]")
|
||||
MOVLQZX(operand.Mem{Base: buf}, tmp4)
|
||||
ADDQ(operand.Imm(4), buf)
|
||||
|
||||
Label("handleTiny8")
|
||||
Comment("n = n >> 1; CF = originalN & 8")
|
||||
SHRQ(operand.Imm(1), n)
|
||||
JNC(operand.LabelRef("handleTiny16"))
|
||||
Comment("tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:]")
|
||||
MOVQ(operand.Mem{Base: buf}, tmp8)
|
||||
ADDQ(operand.Imm(8), buf)
|
||||
|
||||
Label("handleTiny16")
|
||||
Comment("n = n >> 1; CF = originalN & 16",
|
||||
"n == 0 now, otherwise we would have branched after comparing with tinyBufferSize")
|
||||
SHRQ(operand.Imm(1), n)
|
||||
JNC(operand.LabelRef("handleTinyFinish"))
|
||||
ADDQ(operand.Mem{Base: buf}, accum)
|
||||
ADCQ(operand.Mem{Base: buf, Disp: 8}, accum)
|
||||
|
||||
Label("handleTinyFinish")
|
||||
Comment("CF should be included from the previous add, so we use ADCQ.",
|
||||
"If we arrived via the JNC above, then CF=0 due to the branch condition,",
|
||||
"so ADCQ will still produce the correct result.")
|
||||
ADCQ(tmp2, accum)
|
||||
ADCQ(tmp4, accum)
|
||||
ADCQ(tmp8, accum)
|
||||
|
||||
JMP(returnLabel)
|
||||
}
|
||||
|
||||
// handleRemaining generates a series of conditional unrolled additions,
|
||||
// starting with 8 bytes long and doubling each time until the length reaches
|
||||
// max. This is the reverse order of what may be intuitive, but makes the branch
|
||||
// conditions convenient to compute: perform one right shift each time and test
|
||||
// against CF.
|
||||
//
|
||||
// When done, CF may be set and must be accommodated e.g., using foldWithCF or
|
||||
// accumulateCF.
|
||||
//
|
||||
// If n is not a multiple of 8, an extra 64 bit read at the end of the buffer
|
||||
// will be performed, overlapping with data that will be read later. The
|
||||
// duplicate data will be shifted off.
|
||||
//
|
||||
// The original buffer length must have been at least 8 bytes long, even if
|
||||
// n < 8, otherwise this will access memory before the start of the buffer,
|
||||
// which may be unsafe.
|
||||
func handleRemaining(n, buf, accum64 reg.GPVirtual, max int) {
|
||||
Comment("Accumulate carries in this register. It is never expected to overflow.")
|
||||
carries := GP64()
|
||||
XORQ(carries, carries)
|
||||
|
||||
Comment("We will perform an overlapped read for buffers with length not a multiple of 8.",
|
||||
"Overlapped in this context means some memory will be read twice, but a shift will",
|
||||
"eliminate the duplicated data. This extra read is performed at the end of the buffer to",
|
||||
"preserve any alignment that may exist for the start of the buffer.")
|
||||
leftover := reg.RCX
|
||||
MOVQ(n, leftover)
|
||||
SHRQ(operand.Imm(3), n) // n is now the number of 64 bit reads remaining
|
||||
ANDQ(operand.Imm(0x7), leftover) // leftover is now the number of bytes to read from the end
|
||||
JZ(operand.LabelRef("handleRemaining8"))
|
||||
endBuf := GP64()
|
||||
// endBuf is the position near the end of the buffer that is just past the
|
||||
// last multiple of 8: (buf + len(buf)) & ^0x7
|
||||
LEAQ(operand.Mem{Base: buf, Index: n, Scale: 8}, endBuf)
|
||||
|
||||
overlapRead := GP64()
|
||||
// equivalent to overlapRead = binary.LittleEndian.Uint64(buf[len(buf)-8:len(buf)])
|
||||
MOVQ(operand.Mem{Base: endBuf, Index: leftover, Scale: 1, Disp: -8}, overlapRead)
|
||||
|
||||
Comment("Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8)")
|
||||
SHLQ(operand.Imm(3), leftover) // leftover = leftover * 8
|
||||
NEGQ(leftover) // leftover = -leftover; this completes the (-leftoverBytes*8) part of the expression
|
||||
ADDQ(operand.Imm(64), leftover) // now we have (64 - leftoverBytes*8)
|
||||
SHRQ(reg.CL, overlapRead) // shift right by (64 - leftoverBytes*8); CL is the low 8 bits of leftover (set to RCX above) and variable shift only accepts CL
|
||||
|
||||
ADDQ(overlapRead, accum64)
|
||||
ADCQ(operand.Imm(0), carries)
|
||||
|
||||
for curBytes := 8; curBytes <= max; curBytes *= 2 {
|
||||
Label(fmt.Sprintf("handleRemaining%d", curBytes))
|
||||
SHRQ(operand.Imm(1), n)
|
||||
if curBytes*2 <= max {
|
||||
JNC(operand.LabelRef(fmt.Sprintf("handleRemaining%d", curBytes*2)))
|
||||
} else {
|
||||
JNC(operand.LabelRef("handleRemainingComplete"))
|
||||
}
|
||||
|
||||
numLoads := curBytes / 8
|
||||
for i := 0; i < numLoads; i++ {
|
||||
addr := operand.Mem{Base: buf, Disp: i * 8}
|
||||
// It is possible to add the multiple dependency chains trick here
|
||||
// that generateAMD64Checksum uses but anecdotally it does not
|
||||
// appear to outweigh the cost.
|
||||
if i == 0 {
|
||||
ADDQ(addr, accum64)
|
||||
continue
|
||||
}
|
||||
ADCQ(addr, accum64)
|
||||
}
|
||||
ADCQ(operand.Imm(0), carries)
|
||||
|
||||
if curBytes > math.MaxUint8 {
|
||||
ADDQ(operand.U32(uint64(curBytes)), buf)
|
||||
} else {
|
||||
ADDQ(operand.U8(uint64(curBytes)), buf)
|
||||
}
|
||||
if curBytes*2 >= max {
|
||||
continue
|
||||
}
|
||||
JMP(operand.LabelRef(fmt.Sprintf("handleRemaining%d", curBytes*2)))
|
||||
}
|
||||
Label("handleRemainingComplete")
|
||||
ADDQ(carries, accum64)
|
||||
}
|
||||
|
||||
func accumulateCF(accum64 reg.GPVirtual) {
|
||||
Comment("accumulate CF (twice, in case the first time overflows)")
|
||||
// accum64 += CF
|
||||
ADCQ(operand.Imm(0), accum64)
|
||||
// accum64 += CF again if the previous add overflowed. The previous add was
|
||||
// 0 or 1. If it overflowed, then accum64 == 0, so adding another 1 can
|
||||
// never overflow.
|
||||
ADCQ(operand.Imm(0), accum64)
|
||||
}
|
||||
|
||||
// foldWithCF generates instructions to fold accum (a GP64) into a 16-bit value
|
||||
// according to ones-complement arithmetic. BMI2 instructions will be used if
|
||||
// allowBMI2 is true (requires fewer instructions).
|
||||
func foldWithCF(accum reg.GPVirtual, allowBMI2 bool) {
|
||||
Comment("add CF and fold")
|
||||
|
||||
// CF|accum max value starts as 0x1_ffff_ffff_ffff_ffff
|
||||
|
||||
tmp := GP64()
|
||||
if allowBMI2 {
|
||||
// effectively, tmp = accum >> 32 (technically, this is a rotate)
|
||||
RORXQ(operand.Imm(32), accum, tmp)
|
||||
// accum as uint32 = uint32(accum) + uint32(tmp64) + CF; max value 0xffff_ffff + CF set
|
||||
ADCL(tmp.As32(), accum.As32())
|
||||
// effectively, tmp64 as uint32 = uint32(accum) >> 16 (also a rotate)
|
||||
RORXL(operand.Imm(16), accum.As32(), tmp.As32())
|
||||
// accum as uint16 = uint16(accum) + uint16(tmp) + CF; max value 0xffff + CF unset or 0xfffe + CF set
|
||||
ADCW(tmp.As16(), accum.As16())
|
||||
} else {
|
||||
// tmp = uint32(accum); max value 0xffff_ffff
|
||||
// MOVL clears the upper 32 bits of a GP64 so this is equivalent to the
|
||||
// non-existent MOVLQZX.
|
||||
MOVL(accum.As32(), tmp.As32())
|
||||
// tmp += CF; max value 0x1_0000_0000, CF unset
|
||||
ADCQ(operand.Imm(0), tmp)
|
||||
// accum = accum >> 32; max value 0xffff_ffff
|
||||
SHRQ(operand.Imm(32), accum)
|
||||
// accum = accum + tmp; max value 0x1_ffff_ffff + CF unset
|
||||
ADDQ(tmp, accum)
|
||||
// tmp = uint16(accum); max value 0xffff
|
||||
MOVWQZX(accum.As16(), tmp)
|
||||
// accum = accum >> 16; max value 0x1_ffff
|
||||
SHRQ(operand.Imm(16), accum)
|
||||
// accum = accum + tmp; max value 0x2_fffe + CF unset
|
||||
ADDQ(tmp, accum)
|
||||
// tmp as uint16 = uint16(accum); max value 0xffff
|
||||
MOVW(accum.As16(), tmp.As16())
|
||||
// accum = accum >> 16; max value 0x2
|
||||
SHRQ(operand.Imm(16), accum)
|
||||
// accum as uint16 = uint16(accum) + uint16(tmp); max value 0xffff + CF unset or 0x2 + CF set
|
||||
ADDW(tmp.As16(), accum.As16())
|
||||
}
|
||||
// accum as uint16 += CF; will not overflow: either CF was 0 or accum <= 0xfffe
|
||||
ADCW(operand.Imm(0), accum.As16())
|
||||
}
|
||||
|
||||
func generateLoadMasks() {
|
||||
var offset int
|
||||
// xmmLoadMasks is a table of masks that can be used with PAND to zero all but the last N bytes in an XMM, N=2,4,6,8,10,12,14
|
||||
GLOBL("xmmLoadMasks", RODATA|NOPTR)
|
||||
|
||||
for n := 2; n < 16; n += 2 {
|
||||
var pattern [16]byte
|
||||
for i := 0; i < len(pattern); i++ {
|
||||
if i < len(pattern)-n {
|
||||
pattern[i] = 0
|
||||
continue
|
||||
}
|
||||
pattern[i] = 0xff
|
||||
}
|
||||
DATA(offset, operand.String(pattern[:]))
|
||||
offset += len(pattern)
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
generateLoadMasks()
|
||||
generateSIMDChecksum("checksumAVX2", "checksumAVX2 computes an IP checksum using amd64 v3 instructions (AVX2, BMI2)", 256, 4, avx2)
|
||||
generateSIMDChecksum("checksumSSE2", "checksumSSE2 computes an IP checksum using amd64 baseline instructions (SSE2)", 256, 4, sse2)
|
||||
generateAMD64Checksum("checksumAMD64", "checksumAMD64 computes an IP checksum using amd64 baseline instructions")
|
||||
Generate()
|
||||
}
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/go-ole/go-ole"
|
||||
"github.com/go-ole/go-ole/oleutil"
|
||||
"github.com/scjalliance/comshim"
|
||||
)
|
||||
|
||||
// Firewall related API constants.
|
||||
@@ -249,10 +250,7 @@ func FirewallRuleExistsByName(rules *ole.IDispatch, name string) (bool, error) {
|
||||
// then:
|
||||
// dispatch firewallAPIRelease(u, fwp)
|
||||
func firewallAPIInit() (*ole.IUnknown, *ole.IDispatch, error) {
|
||||
err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("Failed to initialize COM: %s", err)
|
||||
}
|
||||
comshim.Add(1)
|
||||
|
||||
unknown, err := oleutil.CreateObject("HNetCfg.FwPolicy2")
|
||||
if err != nil {
|
||||
@@ -272,5 +270,5 @@ func firewallAPIInit() (*ole.IUnknown, *ole.IDispatch, error) {
|
||||
func firewallAPIRelease(u *ole.IUnknown, fwp *ole.IDispatch) {
|
||||
fwp.Release()
|
||||
u.Release()
|
||||
ole.CoUninitialize()
|
||||
comshim.Done()
|
||||
}
|
||||
|
||||
@@ -385,25 +385,3 @@ func (luid LUID) SetDNS(family AddressFamily, servers []netip.Addr, domains []st
|
||||
func (luid LUID) FlushDNS(family AddressFamily) error {
|
||||
return luid.SetDNS(family, nil, nil)
|
||||
}
|
||||
|
||||
func (luid LUID) DisableDNSRegistration() error {
|
||||
guid, err := luid.GUID()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dnsInterfaceSettings := &DnsInterfaceSettings{
|
||||
Version: DnsInterfaceSettingsVersion1,
|
||||
Flags: DnsInterfaceSettingsFlagRegistrationEnabled,
|
||||
RegistrationEnabled: 0,
|
||||
}
|
||||
|
||||
// For >= Windows 10 1809
|
||||
err = SetInterfaceDnsSettings(*guid, dnsInterfaceSettings)
|
||||
if err == nil || !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) {
|
||||
return err
|
||||
}
|
||||
|
||||
// For < Windows 10 1809
|
||||
return luid.fallbackDisableDNSRegistration()
|
||||
}
|
||||
|
||||
@@ -51,11 +51,10 @@ func runNetsh(cmds []string) error {
|
||||
}
|
||||
|
||||
const (
|
||||
netshCmdTemplateFlush4 = "interface ipv4 set dnsservers name=%d source=static address=none validate=no"
|
||||
netshCmdTemplateFlush6 = "interface ipv6 set dnsservers name=%d source=static address=none validate=no"
|
||||
netshCmdTemplateAdd4 = "interface ipv4 add dnsservers name=%d address=%s validate=no"
|
||||
netshCmdTemplateAdd6 = "interface ipv6 add dnsservers name=%d address=%s validate=no"
|
||||
netshCmdTemplateDisableRegistration = "interface ipv6 set dnsservers name=%d register=none"
|
||||
netshCmdTemplateFlush4 = "interface ipv4 set dnsservers name=%d source=static address=none validate=no register=both"
|
||||
netshCmdTemplateFlush6 = "interface ipv6 set dnsservers name=%d source=static address=none validate=no register=both"
|
||||
netshCmdTemplateAdd4 = "interface ipv4 add dnsservers name=%d address=%s validate=no"
|
||||
netshCmdTemplateAdd6 = "interface ipv6 add dnsservers name=%d address=%s validate=no"
|
||||
)
|
||||
|
||||
func (luid LUID) fallbackSetDNSForFamily(family AddressFamily, dnses []netip.Addr) error {
|
||||
@@ -107,13 +106,3 @@ func (luid LUID) fallbackSetDNSDomain(domain string) error {
|
||||
key.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
func (luid LUID) fallbackDisableDNSRegistration() error {
|
||||
// the DNS registration setting is shared for both IPv4 and IPv6
|
||||
ipif, err := luid.IPInterface(windows.AF_INET)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd := fmt.Sprintf(netshCmdTemplateDisableRegistration, ipif.InterfaceIndex)
|
||||
return runNetsh([]string{cmd})
|
||||
}
|
||||
|
||||
149
lwip.go
Normal file
149
lwip.go
Normal file
@@ -0,0 +1,149 @@
|
||||
//go:build with_lwip
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
|
||||
lwip "github.com/sagernet/go-tun2socks/core"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/udpnat"
|
||||
)
|
||||
|
||||
type LWIP struct {
|
||||
ctx context.Context
|
||||
tun Tun
|
||||
tunMtu uint32
|
||||
udpTimeout int64
|
||||
handler Handler
|
||||
stack lwip.LWIPStack
|
||||
udpNat *udpnat.Service[netip.AddrPort]
|
||||
}
|
||||
|
||||
func NewLWIP(
|
||||
options StackOptions,
|
||||
) (Stack, error) {
|
||||
return &LWIP{
|
||||
ctx: options.Context,
|
||||
tun: options.Tun,
|
||||
tunMtu: options.MTU,
|
||||
handler: options.Handler,
|
||||
stack: lwip.NewLWIPStack(),
|
||||
udpNat: udpnat.New[netip.AddrPort](options.UDPTimeout, options.Handler),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (l *LWIP) Start() error {
|
||||
lwip.RegisterTCPConnHandler(l)
|
||||
lwip.RegisterUDPConnHandler(l)
|
||||
lwip.RegisterOutputFn(l.tun.Write)
|
||||
go l.loopIn()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *LWIP) loopIn() {
|
||||
if winTun, isWintun := l.tun.(WinTun); isWintun {
|
||||
l.loopInWintun(winTun)
|
||||
return
|
||||
}
|
||||
mtu := int(l.tunMtu) + PacketOffset
|
||||
_buffer := buf.StackNewSize(mtu)
|
||||
defer common.KeepAlive(_buffer)
|
||||
buffer := common.Dup(_buffer)
|
||||
defer buffer.Release()
|
||||
data := buffer.FreeBytes()
|
||||
for {
|
||||
n, err := l.tun.Read(data)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = l.stack.Write(data[PacketOffset:n])
|
||||
if err != nil {
|
||||
if err.Error() == "stack closed" {
|
||||
return
|
||||
}
|
||||
l.handler.NewError(context.Background(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *LWIP) loopInWintun(tun WinTun) {
|
||||
for {
|
||||
packet, release, err := tun.ReadPacket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = l.stack.Write(packet)
|
||||
release()
|
||||
if err != nil {
|
||||
if err.Error() == "stack closed" {
|
||||
return
|
||||
}
|
||||
l.handler.NewError(context.Background(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *LWIP) Close() error {
|
||||
lwip.RegisterTCPConnHandler(nil)
|
||||
lwip.RegisterUDPConnHandler(nil)
|
||||
lwip.RegisterOutputFn(func(bytes []byte) (int, error) {
|
||||
return 0, os.ErrClosed
|
||||
})
|
||||
return l.stack.Close()
|
||||
}
|
||||
|
||||
func (l *LWIP) Handle(conn net.Conn) error {
|
||||
lAddr := conn.LocalAddr()
|
||||
rAddr := conn.RemoteAddr()
|
||||
if lAddr == nil || rAddr == nil {
|
||||
conn.Close()
|
||||
return nil
|
||||
}
|
||||
go func() {
|
||||
var metadata M.Metadata
|
||||
metadata.Source = M.SocksaddrFromNet(lAddr)
|
||||
metadata.Destination = M.SocksaddrFromNet(rAddr)
|
||||
hErr := l.handler.NewConnection(l.ctx, conn, metadata)
|
||||
if hErr != nil {
|
||||
conn.(lwip.TCPConn).Abort()
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *LWIP) ReceiveTo(conn lwip.UDPConn, data []byte, addr M.Socksaddr) error {
|
||||
var upstreamMetadata M.Metadata
|
||||
upstreamMetadata.Source = conn.LocalAddr()
|
||||
upstreamMetadata.Destination = addr
|
||||
|
||||
l.udpNat.NewPacket(
|
||||
l.ctx,
|
||||
upstreamMetadata.Source.AddrPort(),
|
||||
buf.As(data).ToOwned(),
|
||||
upstreamMetadata,
|
||||
func(natConn N.PacketConn) N.PacketWriter {
|
||||
return &LWIPUDPBackWriter{conn}
|
||||
},
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
type LWIPUDPBackWriter struct {
|
||||
conn lwip.UDPConn
|
||||
}
|
||||
|
||||
func (w *LWIPUDPBackWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
defer buffer.Release()
|
||||
return common.Error(w.conn.WriteFrom(buffer.Bytes(), destination))
|
||||
}
|
||||
|
||||
func (w *LWIPUDPBackWriter) Close() error {
|
||||
return w.conn.Close()
|
||||
}
|
||||
11
lwip_stub.go
Normal file
11
lwip_stub.go
Normal file
@@ -0,0 +1,11 @@
|
||||
//go:build !with_lwip
|
||||
|
||||
package tun
|
||||
|
||||
import E "github.com/sagernet/sing/common/exceptions"
|
||||
|
||||
func NewLWIP(
|
||||
options StackOptions,
|
||||
) (Stack, error) {
|
||||
return nil, E.New(`LWIP is not included in this build, rebuild with -tags with_lwip`)
|
||||
}
|
||||
17
monitor.go
17
monitor.go
@@ -1,7 +1,8 @@
|
||||
package tun
|
||||
|
||||
import (
|
||||
"github.com/sagernet/sing/common/control"
|
||||
"net/netip"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
)
|
||||
@@ -9,23 +10,28 @@ import (
|
||||
var ErrNoRoute = E.New("no route to internet")
|
||||
|
||||
type (
|
||||
NetworkUpdateCallback = func()
|
||||
DefaultInterfaceUpdateCallback = func(defaultInterface *control.Interface, flags int)
|
||||
NetworkUpdateCallback = func() error
|
||||
DefaultInterfaceUpdateCallback = func(event int) error
|
||||
)
|
||||
|
||||
const FlagAndroidVPNUpdate = 1 << iota
|
||||
const (
|
||||
EventInterfaceUpdate = 1
|
||||
EventAndroidVPNUpdate = 2
|
||||
)
|
||||
|
||||
type NetworkUpdateMonitor interface {
|
||||
Start() error
|
||||
Close() error
|
||||
RegisterCallback(callback NetworkUpdateCallback) *list.Element[NetworkUpdateCallback]
|
||||
UnregisterCallback(element *list.Element[NetworkUpdateCallback])
|
||||
E.Handler
|
||||
}
|
||||
|
||||
type DefaultInterfaceMonitor interface {
|
||||
Start() error
|
||||
Close() error
|
||||
DefaultInterface() *control.Interface
|
||||
DefaultInterfaceName(destination netip.Addr) string
|
||||
DefaultInterfaceIndex(destination netip.Addr) int
|
||||
OverrideAndroidVPN() bool
|
||||
AndroidVPNEnabled() bool
|
||||
RegisterCallback(callback DefaultInterfaceUpdateCallback) *list.Element[DefaultInterfaceUpdateCallback]
|
||||
@@ -33,7 +39,6 @@ type DefaultInterfaceMonitor interface {
|
||||
}
|
||||
|
||||
type DefaultInterfaceMonitorOptions struct {
|
||||
InterfaceFinder control.InterfaceFinder
|
||||
OverrideAndroidVPN bool
|
||||
UnderNetworkExtension bool
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ func (m *defaultInterfaceMonitor) checkUpdate() error {
|
||||
continue
|
||||
}
|
||||
vpnEnabled = true
|
||||
if m.overrideAndroidVPN {
|
||||
if m.options.OverrideAndroidVPN {
|
||||
defaultTableIndex = rule.Table
|
||||
break
|
||||
}
|
||||
@@ -51,18 +51,22 @@ func (m *defaultInterfaceMonitor) checkUpdate() error {
|
||||
return err
|
||||
}
|
||||
|
||||
newInterface, err := m.interfaceFinder.ByIndex(link.Attrs().Index)
|
||||
if err != nil {
|
||||
return E.Cause(err, "find updated interface: ", link.Attrs().Name)
|
||||
oldInterface := m.defaultInterfaceName
|
||||
oldIndex := m.defaultInterfaceIndex
|
||||
|
||||
m.defaultInterfaceName = link.Attrs().Name
|
||||
m.defaultInterfaceIndex = link.Attrs().Index
|
||||
|
||||
var event int
|
||||
if oldInterface != m.defaultInterfaceName || oldIndex != m.defaultInterfaceIndex {
|
||||
event |= EventInterfaceUpdate
|
||||
}
|
||||
oldInterface := m.defaultInterface.Swap(newInterface)
|
||||
if oldInterface != nil && oldInterface.Equals(*newInterface) && oldVPNEnabled == m.androidVPNEnabled {
|
||||
return nil
|
||||
}
|
||||
var flags int
|
||||
if oldVPNEnabled != m.androidVPNEnabled {
|
||||
flags = FlagAndroidVPNUpdate
|
||||
event |= EventAndroidVPNUpdate
|
||||
}
|
||||
m.emit(newInterface, flags)
|
||||
if event != 0 {
|
||||
m.emit(event)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
package tun
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/control"
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
|
||||
"golang.org/x/net/route"
|
||||
@@ -18,166 +18,132 @@ import (
|
||||
)
|
||||
|
||||
type networkUpdateMonitor struct {
|
||||
access sync.Mutex
|
||||
callbacks list.List[NetworkUpdateCallback]
|
||||
routeSocketFile *os.File
|
||||
closeOnce sync.Once
|
||||
done chan struct{}
|
||||
logger logger.Logger
|
||||
errorHandler E.Handler
|
||||
|
||||
access sync.Mutex
|
||||
callbacks list.List[NetworkUpdateCallback]
|
||||
routeSocket *os.File
|
||||
}
|
||||
|
||||
func NewNetworkUpdateMonitor(logger logger.Logger) (NetworkUpdateMonitor, error) {
|
||||
func NewNetworkUpdateMonitor(errorHandler E.Handler) (NetworkUpdateMonitor, error) {
|
||||
return &networkUpdateMonitor{
|
||||
logger: logger,
|
||||
done: make(chan struct{}),
|
||||
errorHandler: errorHandler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *networkUpdateMonitor) Start() error {
|
||||
go m.loopUpdate()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *networkUpdateMonitor) loopUpdate() {
|
||||
for {
|
||||
select {
|
||||
case <-m.done:
|
||||
return
|
||||
default:
|
||||
}
|
||||
err := m.loopUpdate0()
|
||||
if err != nil {
|
||||
m.logger.Error("listen network update: ", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *networkUpdateMonitor) loopUpdate0() error {
|
||||
routeSocket, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = unix.SetNonblock(routeSocket, true)
|
||||
if err != nil {
|
||||
unix.Close(routeSocket)
|
||||
return err
|
||||
}
|
||||
routeSocketFile := os.NewFile(uintptr(routeSocket), "route")
|
||||
defer routeSocketFile.Close()
|
||||
m.routeSocketFile = routeSocketFile
|
||||
m.loopUpdate1(routeSocketFile)
|
||||
m.routeSocket = os.NewFile(uintptr(routeSocket), "route")
|
||||
go m.loopUpdate()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *networkUpdateMonitor) loopUpdate1(routeSocketFile *os.File) {
|
||||
buffer := buf.NewPacket()
|
||||
defer buffer.Release()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-m.done:
|
||||
routeSocketFile.Close()
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
n, err := routeSocketFile.Read(buffer.FreeBytes())
|
||||
close(done)
|
||||
func (m *networkUpdateMonitor) loopUpdate() {
|
||||
rawConn, err := m.routeSocket.SyscallConn()
|
||||
if err != nil {
|
||||
m.errorHandler.NewError(context.Background(), E.Cause(err, "create raw route connection"))
|
||||
return
|
||||
}
|
||||
buffer.Truncate(n)
|
||||
messages, err := route.ParseRIB(route.RIBTypeRoute, buffer.Bytes())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, message := range messages {
|
||||
if _, isRouteMessage := message.(*route.RouteMessage); isRouteMessage {
|
||||
m.emit()
|
||||
return
|
||||
for {
|
||||
var innerErr error
|
||||
err = rawConn.Read(func(fd uintptr) (done bool) {
|
||||
var msg [2048]byte
|
||||
_, innerErr = unix.Read(int(fd), msg[:])
|
||||
return innerErr != unix.EWOULDBLOCK
|
||||
})
|
||||
if innerErr != nil {
|
||||
err = innerErr
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
m.emit()
|
||||
}
|
||||
if err != syscall.EAGAIN {
|
||||
m.errorHandler.NewError(context.Background(), E.Cause(err, "read route message"))
|
||||
}
|
||||
}
|
||||
|
||||
func (m *networkUpdateMonitor) Close() error {
|
||||
m.closeOnce.Do(func() {
|
||||
close(m.done)
|
||||
})
|
||||
return nil
|
||||
return common.Close(common.PtrOrNil(m.routeSocket))
|
||||
}
|
||||
|
||||
func (m *defaultInterfaceMonitor) checkUpdate() error {
|
||||
var (
|
||||
defaultInterface *control.Interface
|
||||
err error
|
||||
)
|
||||
if m.underNetworkExtension {
|
||||
defaultInterface, err = m.getDefaultInterfaceBySocket()
|
||||
ribMessage, err := route.FetchRIB(unix.AF_UNSPEC, route.RIBTypeRoute, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
routeMessages, err := route.ParseRIB(route.RIBTypeRoute, ribMessage)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var defaultInterface *net.Interface
|
||||
for _, rawRouteMessage := range routeMessages {
|
||||
routeMessage := rawRouteMessage.(*route.RouteMessage)
|
||||
if len(routeMessage.Addrs) <= unix.RTAX_NETMASK {
|
||||
continue
|
||||
}
|
||||
destination, isIPv4Destination := routeMessage.Addrs[unix.RTAX_DST].(*route.Inet4Addr)
|
||||
if !isIPv4Destination {
|
||||
continue
|
||||
}
|
||||
if destination.IP != netip.IPv4Unspecified().As4() {
|
||||
continue
|
||||
}
|
||||
mask, isIPv4Mask := routeMessage.Addrs[unix.RTAX_NETMASK].(*route.Inet4Addr)
|
||||
if !isIPv4Mask {
|
||||
continue
|
||||
}
|
||||
ones, _ := net.IPMask(mask.IP[:]).Size()
|
||||
if ones != 0 {
|
||||
continue
|
||||
}
|
||||
routeInterface, err := net.InterfaceByIndex(routeMessage.Index)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
ribMessage, err := route.FetchRIB(unix.AF_UNSPEC, route.RIBTypeRoute, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
if routeMessage.Flags&unix.RTF_UP == 0 {
|
||||
continue
|
||||
}
|
||||
routeMessages, err := route.ParseRIB(route.RIBTypeRoute, ribMessage)
|
||||
if err != nil {
|
||||
return err
|
||||
if routeMessage.Flags&unix.RTF_GATEWAY == 0 {
|
||||
continue
|
||||
}
|
||||
for _, rawRouteMessage := range routeMessages {
|
||||
routeMessage := rawRouteMessage.(*route.RouteMessage)
|
||||
if len(routeMessage.Addrs) <= unix.RTAX_NETMASK {
|
||||
continue
|
||||
}
|
||||
destination, isIPv4Destination := routeMessage.Addrs[unix.RTAX_DST].(*route.Inet4Addr)
|
||||
if !isIPv4Destination {
|
||||
continue
|
||||
}
|
||||
if destination.IP != netip.IPv4Unspecified().As4() {
|
||||
continue
|
||||
}
|
||||
mask, isIPv4Mask := routeMessage.Addrs[unix.RTAX_NETMASK].(*route.Inet4Addr)
|
||||
if !isIPv4Mask {
|
||||
continue
|
||||
}
|
||||
ones, _ := net.IPMask(mask.IP[:]).Size()
|
||||
if ones != 0 {
|
||||
continue
|
||||
}
|
||||
routeInterface, err := m.interfaceFinder.ByIndex(routeMessage.Index)
|
||||
if routeMessage.Flags&unix.RTF_IFSCOPE != 0 {
|
||||
continue
|
||||
}
|
||||
defaultInterface = routeInterface
|
||||
break
|
||||
}
|
||||
if defaultInterface == nil {
|
||||
if m.options.UnderNetworkExtension {
|
||||
defaultInterface, err = getDefaultInterfaceBySocket()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if routeMessage.Flags&unix.RTF_UP == 0 {
|
||||
continue
|
||||
}
|
||||
if routeMessage.Flags&unix.RTF_GATEWAY == 0 {
|
||||
continue
|
||||
}
|
||||
// if routeMessage.Flags&unix.RTF_IFSCOPE != 0 {
|
||||
//continue
|
||||
//}
|
||||
defaultInterface = routeInterface
|
||||
break
|
||||
}
|
||||
}
|
||||
if defaultInterface == nil {
|
||||
return ErrNoRoute
|
||||
}
|
||||
newInterface, err := m.interfaceFinder.ByIndex(defaultInterface.Index)
|
||||
if err != nil {
|
||||
return E.Cause(err, "find updated interface: ", defaultInterface.Name)
|
||||
}
|
||||
oldInterface := m.defaultInterface.Swap(newInterface)
|
||||
if oldInterface != nil && oldInterface.Equals(*newInterface) {
|
||||
oldInterface := m.defaultInterfaceName
|
||||
oldIndex := m.defaultInterfaceIndex
|
||||
m.defaultInterfaceIndex = defaultInterface.Index
|
||||
m.defaultInterfaceName = defaultInterface.Name
|
||||
if oldInterface == m.defaultInterfaceName && oldIndex == m.defaultInterfaceIndex {
|
||||
return nil
|
||||
}
|
||||
m.emit(newInterface, 0)
|
||||
m.emit(EventInterfaceUpdate)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *defaultInterfaceMonitor) getDefaultInterfaceBySocket() (*control.Interface, error) {
|
||||
func getDefaultInterfaceBySocket() (*net.Interface, error) {
|
||||
socketFd, err := unix.Socket(unix.AF_INET, unix.SOCK_STREAM, 0)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "create file descriptor")
|
||||
@@ -188,8 +154,6 @@ func (m *defaultInterfaceMonitor) getDefaultInterfaceBySocket() (*control.Interf
|
||||
Port: 80,
|
||||
})
|
||||
result := make(chan netip.Addr, 1)
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
go func() {
|
||||
for {
|
||||
sockname, sockErr := unix.Getsockname(socketFd)
|
||||
@@ -202,13 +166,8 @@ func (m *defaultInterfaceMonitor) getDefaultInterfaceBySocket() (*control.Interf
|
||||
}
|
||||
addr := netip.AddrFrom4(sockaddr.Addr)
|
||||
if addr.IsUnspecified() {
|
||||
select {
|
||||
case <-done:
|
||||
break
|
||||
default:
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
time.Sleep(time.Millisecond)
|
||||
continue
|
||||
}
|
||||
result <- addr
|
||||
break
|
||||
@@ -218,7 +177,26 @@ func (m *defaultInterfaceMonitor) getDefaultInterfaceBySocket() (*control.Interf
|
||||
select {
|
||||
case selectedAddr = <-result:
|
||||
case <-time.After(time.Second):
|
||||
return nil, nil
|
||||
return nil, os.ErrDeadlineExceeded
|
||||
}
|
||||
return m.interfaceFinder.ByAddr(selectedAddr)
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "net.Interfaces")
|
||||
}
|
||||
for _, netInterface := range interfaces {
|
||||
interfaceAddrs, err := netInterface.Addrs()
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "net.Interfaces.Addrs")
|
||||
}
|
||||
for _, interfaceAddr := range interfaceAddrs {
|
||||
ipNet, isIPNet := interfaceAddr.(*net.IPNet)
|
||||
if !isIPNet {
|
||||
continue
|
||||
}
|
||||
if ipNet.Contains(selectedAddr.AsSlice()) {
|
||||
return &netInterface, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, E.New("no interface found for address ", selectedAddr)
|
||||
}
|
||||
|
||||
@@ -2,56 +2,30 @@ package tun
|
||||
|
||||
import (
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/netlink"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type networkUpdateMonitor struct {
|
||||
routeUpdate chan netlink.RouteUpdate
|
||||
linkUpdate chan netlink.LinkUpdate
|
||||
close chan struct{}
|
||||
routeUpdate chan netlink.RouteUpdate
|
||||
linkUpdate chan netlink.LinkUpdate
|
||||
close chan struct{}
|
||||
errorHandler E.Handler
|
||||
|
||||
access sync.Mutex
|
||||
callbacks list.List[NetworkUpdateCallback]
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
var ErrNetlinkBanned = E.New(
|
||||
"netlink socket in Android is banned by Google, " +
|
||||
"use the root or system (ADB) user to run sing-box, " +
|
||||
"or switch to the sing-box Android graphical interface client",
|
||||
)
|
||||
|
||||
func NewNetworkUpdateMonitor(logger logger.Logger) (NetworkUpdateMonitor, error) {
|
||||
monitor := &networkUpdateMonitor{
|
||||
routeUpdate: make(chan netlink.RouteUpdate, 2),
|
||||
linkUpdate: make(chan netlink.LinkUpdate, 2),
|
||||
close: make(chan struct{}),
|
||||
logger: logger,
|
||||
}
|
||||
// check is netlink banned by google
|
||||
if runtime.GOOS == "android" {
|
||||
netlinkSocket, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_DGRAM, unix.NETLINK_ROUTE)
|
||||
if err != nil {
|
||||
return nil, ErrNetlinkBanned
|
||||
}
|
||||
err = unix.Bind(netlinkSocket, &unix.SockaddrNetlink{
|
||||
Family: unix.AF_NETLINK,
|
||||
})
|
||||
unix.Close(netlinkSocket)
|
||||
if err != nil {
|
||||
return nil, ErrNetlinkBanned
|
||||
}
|
||||
}
|
||||
return monitor, nil
|
||||
func NewNetworkUpdateMonitor(errorHandler E.Handler) (NetworkUpdateMonitor, error) {
|
||||
return &networkUpdateMonitor{
|
||||
routeUpdate: make(chan netlink.RouteUpdate, 2),
|
||||
linkUpdate: make(chan netlink.LinkUpdate, 2),
|
||||
close: make(chan struct{}),
|
||||
errorHandler: errorHandler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *networkUpdateMonitor) Start() error {
|
||||
@@ -68,9 +42,6 @@ func (m *networkUpdateMonitor) Start() error {
|
||||
}
|
||||
|
||||
func (m *networkUpdateMonitor) loopUpdate() {
|
||||
const minDuration = time.Second
|
||||
timer := time.NewTimer(minDuration)
|
||||
defer timer.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-m.close:
|
||||
@@ -79,12 +50,6 @@ func (m *networkUpdateMonitor) loopUpdate() {
|
||||
case <-m.linkUpdate:
|
||||
}
|
||||
m.emit()
|
||||
select {
|
||||
case <-m.close:
|
||||
return
|
||||
case <-timer.C:
|
||||
timer.Reset(minDuration)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -25,16 +25,17 @@ func (m *defaultInterfaceMonitor) checkUpdate() error {
|
||||
return err
|
||||
}
|
||||
|
||||
newInterface, err := m.interfaceFinder.ByIndex(link.Attrs().Index)
|
||||
if err != nil {
|
||||
return E.Cause(err, "find updated interface: ", link.Attrs().Name)
|
||||
}
|
||||
oldInterface := m.defaultInterface.Swap(newInterface)
|
||||
if oldInterface != nil && oldInterface.Equals(*newInterface) {
|
||||
oldInterface := m.defaultInterfaceName
|
||||
oldIndex := m.defaultInterfaceIndex
|
||||
|
||||
m.defaultInterfaceName = link.Attrs().Name
|
||||
m.defaultInterfaceIndex = link.Attrs().Index
|
||||
|
||||
if oldInterface == m.defaultInterfaceName && oldIndex == m.defaultInterfaceIndex {
|
||||
return nil
|
||||
}
|
||||
m.emit(newInterface, 0)
|
||||
m.emit(EventInterfaceUpdate)
|
||||
return nil
|
||||
}
|
||||
return ErrNoRoute
|
||||
return E.New("no route to internet")
|
||||
}
|
||||
|
||||
@@ -5,13 +5,13 @@ package tun
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
func NewNetworkUpdateMonitor(logger logger.Logger) (NetworkUpdateMonitor, error) {
|
||||
func NewNetworkUpdateMonitor(errorHandler E.Handler) (NetworkUpdateMonitor, error) {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
|
||||
func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor, logger logger.Logger, options DefaultInterfaceMonitorOptions) (DefaultInterfaceMonitor, error) {
|
||||
func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor, options DefaultInterfaceMonitorOptions) (DefaultInterfaceMonitor, error) {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
|
||||
@@ -3,13 +3,15 @@
|
||||
package tun
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
"github.com/sagernet/sing/common/control"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
)
|
||||
|
||||
@@ -30,67 +32,87 @@ func (m *networkUpdateMonitor) emit() {
|
||||
callbacks := m.callbacks.Array()
|
||||
m.access.Unlock()
|
||||
for _, callback := range callbacks {
|
||||
callback()
|
||||
err := callback()
|
||||
if err != nil {
|
||||
m.NewError(context.Background(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *networkUpdateMonitor) NewError(ctx context.Context, err error) {
|
||||
m.errorHandler.NewError(ctx, err)
|
||||
}
|
||||
|
||||
type defaultInterfaceMonitor struct {
|
||||
interfaceFinder control.InterfaceFinder
|
||||
overrideAndroidVPN bool
|
||||
underNetworkExtension bool
|
||||
defaultInterface atomic.Pointer[control.Interface]
|
||||
options DefaultInterfaceMonitorOptions
|
||||
networkAddresses []networkAddress
|
||||
defaultInterfaceName string
|
||||
defaultInterfaceIndex int
|
||||
androidVPNEnabled bool
|
||||
noRoute bool
|
||||
networkMonitor NetworkUpdateMonitor
|
||||
checkUpdateTimer *time.Timer
|
||||
element *list.Element[NetworkUpdateCallback]
|
||||
access sync.Mutex
|
||||
callbacks list.List[DefaultInterfaceUpdateCallback]
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor, logger logger.Logger, options DefaultInterfaceMonitorOptions) (DefaultInterfaceMonitor, error) {
|
||||
type networkAddress struct {
|
||||
interfaceName string
|
||||
interfaceIndex int
|
||||
addresses []netip.Prefix
|
||||
}
|
||||
|
||||
func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor, options DefaultInterfaceMonitorOptions) (DefaultInterfaceMonitor, error) {
|
||||
return &defaultInterfaceMonitor{
|
||||
interfaceFinder: options.InterfaceFinder,
|
||||
overrideAndroidVPN: options.OverrideAndroidVPN,
|
||||
underNetworkExtension: options.UnderNetworkExtension,
|
||||
options: options,
|
||||
networkMonitor: networkMonitor,
|
||||
logger: logger,
|
||||
defaultInterfaceIndex: -1,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *defaultInterfaceMonitor) Start() error {
|
||||
m.postCheckUpdate()
|
||||
err := m.checkUpdate()
|
||||
if err != nil {
|
||||
m.networkMonitor.NewError(context.Background(), err)
|
||||
}
|
||||
m.element = m.networkMonitor.RegisterCallback(m.delayCheckUpdate)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *defaultInterfaceMonitor) delayCheckUpdate() {
|
||||
if m.checkUpdateTimer == nil {
|
||||
m.checkUpdateTimer = time.AfterFunc(time.Second, m.postCheckUpdate)
|
||||
} else {
|
||||
m.checkUpdateTimer.Reset(time.Second)
|
||||
func (m *defaultInterfaceMonitor) delayCheckUpdate() error {
|
||||
time.Sleep(time.Second)
|
||||
err := m.updateInterfaces()
|
||||
if err != nil {
|
||||
m.networkMonitor.NewError(context.Background(), E.Cause(err, "update interfaces"))
|
||||
}
|
||||
return m.checkUpdate()
|
||||
}
|
||||
|
||||
func (m *defaultInterfaceMonitor) postCheckUpdate() {
|
||||
err := m.interfaceFinder.Update()
|
||||
func (m *defaultInterfaceMonitor) updateInterfaces() error {
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
m.logger.Error("update interface: ", err)
|
||||
return
|
||||
return err
|
||||
}
|
||||
err = m.checkUpdate()
|
||||
if errors.Is(err, ErrNoRoute) {
|
||||
if !m.noRoute {
|
||||
m.noRoute = true
|
||||
m.defaultInterface.Store(nil)
|
||||
m.emit(nil, 0)
|
||||
var addresses []networkAddress
|
||||
for _, iif := range interfaces {
|
||||
var netAddresses []net.Addr
|
||||
netAddresses, err = iif.Addrs()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if err != nil {
|
||||
m.logger.Error("check interface: ", err)
|
||||
} else {
|
||||
m.noRoute = false
|
||||
var address networkAddress
|
||||
address.interfaceName = iif.Name
|
||||
address.interfaceIndex = iif.Index
|
||||
address.addresses = common.Map(common.FilterIsInstance(netAddresses, func(it net.Addr) (*net.IPNet, bool) {
|
||||
value, loaded := it.(*net.IPNet)
|
||||
return value, loaded
|
||||
}), func(it *net.IPNet) netip.Prefix {
|
||||
bits, _ := it.Mask.Size()
|
||||
return netip.PrefixFrom(M.AddrFromIP(it.IP), bits)
|
||||
})
|
||||
addresses = append(addresses, address)
|
||||
}
|
||||
m.networkAddresses = addresses
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *defaultInterfaceMonitor) Close() error {
|
||||
@@ -100,12 +122,36 @@ func (m *defaultInterfaceMonitor) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *defaultInterfaceMonitor) DefaultInterface() *control.Interface {
|
||||
return m.defaultInterface.Load()
|
||||
func (m *defaultInterfaceMonitor) DefaultInterfaceName(destination netip.Addr) string {
|
||||
for _, address := range m.networkAddresses {
|
||||
for _, prefix := range address.addresses {
|
||||
if prefix.Contains(destination) {
|
||||
return address.interfaceName
|
||||
}
|
||||
}
|
||||
}
|
||||
if m.defaultInterfaceIndex == -1 {
|
||||
m.checkUpdate()
|
||||
}
|
||||
return m.defaultInterfaceName
|
||||
}
|
||||
|
||||
func (m *defaultInterfaceMonitor) DefaultInterfaceIndex(destination netip.Addr) int {
|
||||
for _, address := range m.networkAddresses {
|
||||
for _, prefix := range address.addresses {
|
||||
if prefix.Contains(destination) {
|
||||
return address.interfaceIndex
|
||||
}
|
||||
}
|
||||
}
|
||||
if m.defaultInterfaceIndex == -1 {
|
||||
m.checkUpdate()
|
||||
}
|
||||
return m.defaultInterfaceIndex
|
||||
}
|
||||
|
||||
func (m *defaultInterfaceMonitor) OverrideAndroidVPN() bool {
|
||||
return m.overrideAndroidVPN
|
||||
return m.options.OverrideAndroidVPN
|
||||
}
|
||||
|
||||
func (m *defaultInterfaceMonitor) AndroidVPNEnabled() bool {
|
||||
@@ -124,11 +170,14 @@ func (m *defaultInterfaceMonitor) UnregisterCallback(element *list.Element[Defau
|
||||
m.callbacks.Remove(element)
|
||||
}
|
||||
|
||||
func (m *defaultInterfaceMonitor) emit(defaultInterface *control.Interface, flags int) {
|
||||
func (m *defaultInterfaceMonitor) emit(event int) {
|
||||
m.access.Lock()
|
||||
callbacks := m.callbacks.Array()
|
||||
m.access.Unlock()
|
||||
for _, callback := range callbacks {
|
||||
callback(defaultInterface, flags)
|
||||
err := callback(event)
|
||||
if err != nil {
|
||||
m.networkMonitor.NewError(context.Background(), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
|
||||
"github.com/sagernet/sing-tun/internal/winipcfg"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
@@ -14,15 +13,15 @@ import (
|
||||
type networkUpdateMonitor struct {
|
||||
routeListener *winipcfg.RouteChangeCallback
|
||||
interfaceListener *winipcfg.InterfaceChangeCallback
|
||||
errorHandler E.Handler
|
||||
|
||||
access sync.Mutex
|
||||
callbacks list.List[NetworkUpdateCallback]
|
||||
logger logger.Logger
|
||||
}
|
||||
|
||||
func NewNetworkUpdateMonitor(logger logger.Logger) (NetworkUpdateMonitor, error) {
|
||||
func NewNetworkUpdateMonitor(errorHandler E.Handler) (NetworkUpdateMonitor, error) {
|
||||
return &networkUpdateMonitor{
|
||||
logger: logger,
|
||||
errorHandler: errorHandler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -77,16 +76,12 @@ func (m *defaultInterfaceMonitor) checkUpdate() error {
|
||||
continue
|
||||
}
|
||||
|
||||
if ifrow.Type == winipcfg.IfTypePropVirtual || ifrow.Type == winipcfg.IfTypeSoftwareLoopback {
|
||||
continue
|
||||
}
|
||||
|
||||
iface, err := row.InterfaceLUID.IPInterface(windows.AF_INET)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if !iface.Connected {
|
||||
if ifrow.Type == winipcfg.IfTypePropVirtual || ifrow.Type == winipcfg.IfTypeSoftwareLoopback {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -102,14 +97,16 @@ func (m *defaultInterfaceMonitor) checkUpdate() error {
|
||||
return ErrNoRoute
|
||||
}
|
||||
|
||||
newInterface, err := m.interfaceFinder.ByIndex(index)
|
||||
if err != nil {
|
||||
return E.Cause(err, "find updated interface: ", alias)
|
||||
}
|
||||
oldInterface := m.defaultInterface.Swap(newInterface)
|
||||
if oldInterface != nil && oldInterface.Equals(*newInterface) {
|
||||
oldInterface := m.defaultInterfaceName
|
||||
oldIndex := m.defaultInterfaceIndex
|
||||
|
||||
m.defaultInterfaceName = alias
|
||||
m.defaultInterfaceIndex = index
|
||||
|
||||
if oldInterface == m.defaultInterfaceName && oldIndex == m.defaultInterfaceIndex {
|
||||
return nil
|
||||
}
|
||||
m.emit(newInterface, 0)
|
||||
|
||||
m.emit(EventInterfaceUpdate)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,21 +3,20 @@ package tun
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip"
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip/header"
|
||||
"github.com/sagernet/sing-tun/internal/clashtcpip"
|
||||
F "github.com/sagernet/sing/common/format"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func NetworkName(network uint8) string {
|
||||
switch tcpip.TransportProtocolNumber(network) {
|
||||
case header.TCPProtocolNumber:
|
||||
switch network {
|
||||
case clashtcpip.TCP:
|
||||
return N.NetworkTCP
|
||||
case header.UDPProtocolNumber:
|
||||
case clashtcpip.UDP:
|
||||
return N.NetworkUDP
|
||||
case header.ICMPv4ProtocolNumber:
|
||||
case clashtcpip.ICMP:
|
||||
return N.NetworkICMPv4
|
||||
case header.ICMPv6ProtocolNumber:
|
||||
case clashtcpip.ICMPv6:
|
||||
return N.NetworkICMPv6
|
||||
}
|
||||
return F.ToString(network)
|
||||
@@ -26,13 +25,13 @@ func NetworkName(network uint8) string {
|
||||
func NetworkFromName(name string) uint8 {
|
||||
switch name {
|
||||
case N.NetworkTCP:
|
||||
return uint8(header.TCPProtocolNumber)
|
||||
return clashtcpip.TCP
|
||||
case N.NetworkUDP:
|
||||
return uint8(header.UDPProtocolNumber)
|
||||
return clashtcpip.UDP
|
||||
case N.NetworkICMPv4:
|
||||
return uint8(header.ICMPv4ProtocolNumber)
|
||||
return clashtcpip.ICMP
|
||||
case N.NetworkICMPv6:
|
||||
return uint8(header.ICMPv6ProtocolNumber)
|
||||
return clashtcpip.ICMPv6
|
||||
}
|
||||
parseNetwork, err := strconv.ParseUint(name, 10, 8)
|
||||
if err != nil {
|
||||
|
||||
11
packages.go
11
packages.go
@@ -1,6 +1,6 @@
|
||||
package tun
|
||||
|
||||
import "github.com/sagernet/sing/common/logger"
|
||||
import E "github.com/sagernet/sing/common/exceptions"
|
||||
|
||||
type PackageManager interface {
|
||||
Start() error
|
||||
@@ -11,14 +11,7 @@ type PackageManager interface {
|
||||
SharedPackageByID(id uint32) (string, bool)
|
||||
}
|
||||
|
||||
type PackageManagerOptions struct {
|
||||
Callback PackageManagerCallback
|
||||
|
||||
// Logger is the logger to log errors
|
||||
// optional
|
||||
Logger logger.Logger
|
||||
}
|
||||
|
||||
type PackageManagerCallback interface {
|
||||
OnPackagesUpdated(packages int, sharedUsers int)
|
||||
E.Handler
|
||||
}
|
||||
|
||||
@@ -2,33 +2,30 @@ package tun
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/xml"
|
||||
"io"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"github.com/sagernet/fswatch"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/abx"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
type packageManager struct {
|
||||
callback PackageManagerCallback
|
||||
logger logger.Logger
|
||||
watcher *fswatch.Watcher
|
||||
watcher *fsnotify.Watcher
|
||||
idByPackage map[string]uint32
|
||||
sharedByPackage map[string]uint32
|
||||
packageById map[uint32]string
|
||||
sharedById map[uint32]string
|
||||
}
|
||||
|
||||
func NewPackageManager(options PackageManagerOptions) (PackageManager, error) {
|
||||
return &packageManager{
|
||||
callback: options.Callback,
|
||||
logger: options.Logger,
|
||||
}, nil
|
||||
func NewPackageManager(callback PackageManagerCallback) (PackageManager, error) {
|
||||
return &packageManager{callback: callback}, nil
|
||||
}
|
||||
|
||||
func (m *packageManager) Start() error {
|
||||
@@ -38,33 +35,42 @@ func (m *packageManager) Start() error {
|
||||
}
|
||||
err = m.startWatcher()
|
||||
if err != nil {
|
||||
m.logger.Error(E.Cause(err, "create watcher for packages list"))
|
||||
m.callback.NewError(context.Background(), E.Cause(err, "create fsnotify watcher"))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *packageManager) startWatcher() error {
|
||||
watcher, err := fswatch.NewWatcher(fswatch.Options{
|
||||
Path: []string{"/data/system/packages.xml"},
|
||||
Direct: true,
|
||||
Callback: m.packagesUpdated,
|
||||
Logger: m.logger,
|
||||
})
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = watcher.Start()
|
||||
err = watcher.Add("/data/system/packages.xml")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.watcher = watcher
|
||||
go m.loopUpdate()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *packageManager) packagesUpdated(path string) {
|
||||
err := m.updatePackages()
|
||||
if err != nil {
|
||||
m.logger.Error(E.Cause(err, "update packages"))
|
||||
func (m *packageManager) loopUpdate() {
|
||||
for {
|
||||
select {
|
||||
case _, ok := <-m.watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
err := m.updatePackages()
|
||||
if err != nil {
|
||||
m.callback.NewError(context.Background(), E.Cause(err, "update packages"))
|
||||
}
|
||||
case err, ok := <-m.watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
m.callback.NewError(context.Background(), E.Cause(err, "fsnotify error"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,6 @@ package tun
|
||||
|
||||
import "os"
|
||||
|
||||
func NewPackageManager(options PackageManagerOptions) (PackageManager, error) {
|
||||
func NewPackageManager(callback PackageManagerCallback) (PackageManager, error) {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
|
||||
36
redirect.go
36
redirect.go
@@ -1,36 +0,0 @@
|
||||
package tun
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sagernet/sing/common/control"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"go4.org/netipx"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultAutoRedirectInputMark = 0x2023
|
||||
DefaultAutoRedirectOutputMark = 0x2024
|
||||
)
|
||||
|
||||
type AutoRedirect interface {
|
||||
Start() error
|
||||
Close() error
|
||||
UpdateRouteAddressSet()
|
||||
}
|
||||
|
||||
type AutoRedirectOptions struct {
|
||||
TunOptions *Options
|
||||
Context context.Context
|
||||
Handler N.TCPConnectionHandlerEx
|
||||
Logger logger.Logger
|
||||
NetworkMonitor NetworkUpdateMonitor
|
||||
InterfaceFinder control.InterfaceFinder
|
||||
TableName string
|
||||
DisableNFTables bool
|
||||
CustomRedirectPort func() int
|
||||
RouteAddressSet *[]*netipx.IPSet
|
||||
RouteExcludeAddressSet *[]*netipx.IPSet
|
||||
}
|
||||
@@ -1,81 +0,0 @@
|
||||
//go:build linux
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
F "github.com/sagernet/sing/common/format"
|
||||
)
|
||||
|
||||
func (r *autoRedirect) setupIPTables() error {
|
||||
if r.enableIPv4 {
|
||||
err := r.setupIPTablesForFamily(r.iptablesPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if r.enableIPv6 {
|
||||
err := r.setupIPTablesForFamily(r.ip6tablesPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *autoRedirect) setupIPTablesForFamily(iptablesPath string) error {
|
||||
tableNameOutput := r.tableName + "-output"
|
||||
redirectPort := r.redirectPort()
|
||||
// OUTPUT
|
||||
err := r.runShell(iptablesPath, "-t nat -N", tableNameOutput)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.runShell(iptablesPath, "-t nat -A", tableNameOutput,
|
||||
"-p tcp -o", r.tunOptions.Name,
|
||||
"-j REDIRECT --to-ports", redirectPort)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.runShell(iptablesPath, "-t nat -I OUTPUT -j", tableNameOutput)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *autoRedirect) cleanupIPTables() {
|
||||
if r.enableIPv4 {
|
||||
r.cleanupIPTablesForFamily(r.iptablesPath)
|
||||
}
|
||||
if r.enableIPv6 {
|
||||
r.cleanupIPTablesForFamily(r.ip6tablesPath)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *autoRedirect) cleanupIPTablesForFamily(iptablesPath string) {
|
||||
tableNameOutput := r.tableName + "-output"
|
||||
|
||||
_ = r.runShell(iptablesPath, "-t nat -D OUTPUT -j", tableNameOutput)
|
||||
_ = r.runShell(iptablesPath, "-t nat -F", tableNameOutput)
|
||||
_ = r.runShell(iptablesPath, "-t nat -X", tableNameOutput)
|
||||
}
|
||||
|
||||
func (r *autoRedirect) runShell(commands ...any) error {
|
||||
commandStr := strings.Join(F.MapToString(commands), " ")
|
||||
var command *exec.Cmd
|
||||
if r.androidSu {
|
||||
command = exec.Command(r.suPath, "-c", commandStr)
|
||||
} else {
|
||||
commandArray := strings.Split(commandStr, " ")
|
||||
command = exec.Command(commandArray[0], commandArray[1:]...)
|
||||
}
|
||||
combinedOutput, err := command.CombinedOutput()
|
||||
if err != nil {
|
||||
return E.Extend(err, F.ToString(commandStr, ": ", string(combinedOutput)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,182 +0,0 @@
|
||||
package tun
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
|
||||
"github.com/sagernet/nftables"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/control"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
|
||||
"go4.org/netipx"
|
||||
)
|
||||
|
||||
type autoRedirect struct {
|
||||
tunOptions *Options
|
||||
ctx context.Context
|
||||
handler N.TCPConnectionHandlerEx
|
||||
logger logger.Logger
|
||||
tableName string
|
||||
networkMonitor NetworkUpdateMonitor
|
||||
networkListener *list.Element[NetworkUpdateCallback]
|
||||
interfaceFinder control.InterfaceFinder
|
||||
localAddresses []netip.Prefix
|
||||
customRedirectPortFunc func() int
|
||||
customRedirectPort int
|
||||
redirectServer *redirectServer
|
||||
enableIPv4 bool
|
||||
enableIPv6 bool
|
||||
iptablesPath string
|
||||
ip6tablesPath string
|
||||
useNFTables bool
|
||||
androidSu bool
|
||||
suPath string
|
||||
routeAddressSet *[]*netipx.IPSet
|
||||
routeExcludeAddressSet *[]*netipx.IPSet
|
||||
}
|
||||
|
||||
func NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) {
|
||||
return &autoRedirect{
|
||||
tunOptions: options.TunOptions,
|
||||
ctx: options.Context,
|
||||
handler: options.Handler,
|
||||
logger: options.Logger,
|
||||
networkMonitor: options.NetworkMonitor,
|
||||
interfaceFinder: options.InterfaceFinder,
|
||||
tableName: options.TableName,
|
||||
useNFTables: runtime.GOOS != "android" && !options.DisableNFTables,
|
||||
customRedirectPortFunc: options.CustomRedirectPort,
|
||||
routeAddressSet: options.RouteAddressSet,
|
||||
routeExcludeAddressSet: options.RouteExcludeAddressSet,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *autoRedirect) Start() error {
|
||||
var err error
|
||||
if runtime.GOOS == "android" {
|
||||
r.enableIPv4 = true
|
||||
r.iptablesPath = "/system/bin/iptables"
|
||||
userId := os.Getuid()
|
||||
if userId != 0 {
|
||||
r.androidSu = true
|
||||
for _, suPath := range []string{
|
||||
"su",
|
||||
"/system/bin/su",
|
||||
} {
|
||||
r.suPath, err = exec.LookPath(suPath)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return E.Extend(E.Cause(err, "root permission is required for auto redirect"), os.Getenv("PATH"))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if r.useNFTables {
|
||||
err = r.initializeNFTables()
|
||||
if err != nil {
|
||||
return E.Cause(err, "missing nftables support")
|
||||
}
|
||||
}
|
||||
if len(r.tunOptions.Inet4Address) > 0 {
|
||||
r.enableIPv4 = true
|
||||
if !r.useNFTables {
|
||||
r.iptablesPath, err = exec.LookPath("iptables")
|
||||
if err != nil {
|
||||
return E.Cause(err, "iptables is required")
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(r.tunOptions.Inet6Address) > 0 {
|
||||
r.enableIPv6 = true
|
||||
if !r.useNFTables {
|
||||
r.ip6tablesPath, err = exec.LookPath("ip6tables")
|
||||
if err != nil {
|
||||
if !r.enableIPv4 {
|
||||
return E.Cause(err, "ip6tables is required")
|
||||
} else {
|
||||
r.enableIPv6 = false
|
||||
r.logger.Error("device has no ip6tables nat support: ", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if r.customRedirectPortFunc != nil {
|
||||
r.customRedirectPort = r.customRedirectPortFunc()
|
||||
}
|
||||
if r.customRedirectPort == 0 {
|
||||
var listenAddr netip.Addr
|
||||
if runtime.GOOS == "android" {
|
||||
listenAddr = netip.AddrFrom4([4]byte{127, 0, 0, 1})
|
||||
} else if r.enableIPv6 {
|
||||
listenAddr = netip.IPv6Unspecified()
|
||||
} else {
|
||||
listenAddr = netip.IPv4Unspecified()
|
||||
}
|
||||
server := newRedirectServer(r.ctx, r.handler, r.logger, listenAddr)
|
||||
err := server.Start()
|
||||
if err != nil {
|
||||
return E.Cause(err, "start redirect server")
|
||||
}
|
||||
r.redirectServer = server
|
||||
}
|
||||
if r.useNFTables {
|
||||
r.cleanupNFTables()
|
||||
err = r.setupNFTables()
|
||||
} else {
|
||||
r.cleanupIPTables()
|
||||
err = r.setupIPTables()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *autoRedirect) Close() error {
|
||||
if r.useNFTables {
|
||||
r.cleanupNFTables()
|
||||
} else {
|
||||
r.cleanupIPTables()
|
||||
}
|
||||
return common.Close(
|
||||
common.PtrOrNil(r.redirectServer),
|
||||
)
|
||||
}
|
||||
|
||||
func (r *autoRedirect) UpdateRouteAddressSet() {
|
||||
if r.useNFTables {
|
||||
err := r.nftablesUpdateRouteAddressSet()
|
||||
if err != nil {
|
||||
r.logger.Error("update route address set: ", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *autoRedirect) initializeNFTables() error {
|
||||
nft, err := nftables.New()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer nft.CloseLasting()
|
||||
_, err = nft.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.useNFTables = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *autoRedirect) redirectPort() uint16 {
|
||||
if r.customRedirectPort > 0 {
|
||||
return uint16(r.customRedirectPort)
|
||||
}
|
||||
return M.AddrPortFromNet(r.redirectServer.listener.Addr()).Port()
|
||||
}
|
||||
@@ -1,278 +0,0 @@
|
||||
//go:build linux
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/sagernet/nftables"
|
||||
"github.com/sagernet/nftables/binaryutil"
|
||||
"github.com/sagernet/nftables/expr"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/control"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func (r *autoRedirect) setupNFTables() error {
|
||||
nft, err := nftables.New()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer nft.CloseLasting()
|
||||
|
||||
table := nft.AddTable(&nftables.Table{
|
||||
Name: r.tableName,
|
||||
Family: nftables.TableFamilyINet,
|
||||
})
|
||||
|
||||
err = r.nftablesCreateAddressSets(nft, table, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.interfaceFinder.Update()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.localAddresses = common.FlatMap(r.interfaceFinder.Interfaces(), func(it control.Interface) []netip.Prefix {
|
||||
return common.Filter(it.Addresses, func(prefix netip.Prefix) bool {
|
||||
return it.Name == "lo" || prefix.Addr().IsGlobalUnicast()
|
||||
})
|
||||
})
|
||||
err = r.nftablesCreateLocalAddressSets(nft, table, r.localAddresses, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
skipOutput := len(r.tunOptions.IncludeInterface) > 0 && !common.Contains(r.tunOptions.IncludeInterface, "lo") || common.Contains(r.tunOptions.ExcludeInterface, "lo")
|
||||
if !skipOutput {
|
||||
chainOutput := nft.AddChain(&nftables.Chain{
|
||||
Name: "output",
|
||||
Table: table,
|
||||
Hooknum: nftables.ChainHookOutput,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
if r.tunOptions.AutoRedirectMarkMode {
|
||||
err = r.nftablesCreateExcludeRules(nft, table, chainOutput)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.nftablesCreateUnreachable(nft, table, chainOutput)
|
||||
r.nftablesCreateRedirect(nft, table, chainOutput)
|
||||
|
||||
chainOutputUDP := nft.AddChain(&nftables.Chain{
|
||||
Name: "output_udp",
|
||||
Table: table,
|
||||
Hooknum: nftables.ChainHookOutput,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
Type: nftables.ChainTypeRoute,
|
||||
})
|
||||
err = r.nftablesCreateExcludeRules(nft, table, chainOutputUDP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.nftablesCreateUnreachable(nft, table, chainOutputUDP)
|
||||
r.nftablesCreateMark(nft, table, chainOutputUDP)
|
||||
} else {
|
||||
r.nftablesCreateRedirect(nft, table, chainOutput, &expr.Meta{
|
||||
Key: expr.MetaKeyOIFNAME,
|
||||
Register: 1,
|
||||
}, &expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: nftablesIfname(r.tunOptions.Name),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
chainPreRouting := nft.AddChain(&nftables.Chain{
|
||||
Name: "prerouting",
|
||||
Table: table,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityRef(*nftables.ChainPriorityNATDest + 1),
|
||||
Type: nftables.ChainTypeNAT,
|
||||
})
|
||||
err = r.nftablesCreateExcludeRules(nft, table, chainPreRouting)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.nftablesCreateUnreachable(nft, table, chainPreRouting)
|
||||
r.nftablesCreateRedirect(nft, table, chainPreRouting)
|
||||
if r.tunOptions.AutoRedirectMarkMode {
|
||||
r.nftablesCreateMark(nft, table, chainPreRouting)
|
||||
}
|
||||
|
||||
if r.tunOptions.AutoRedirectMarkMode {
|
||||
chainPreRoutingUDP := nft.AddChain(&nftables.Chain{
|
||||
Name: "prerouting_udp",
|
||||
Table: table,
|
||||
Hooknum: nftables.ChainHookPrerouting,
|
||||
Priority: nftables.ChainPriorityRef(*nftables.ChainPriorityNATDest + 2),
|
||||
Type: nftables.ChainTypeFilter,
|
||||
})
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chainPreRoutingUDP,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyL4PROTO,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{unix.IPPROTO_UDP},
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictReturn,
|
||||
},
|
||||
},
|
||||
})
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chainPreRoutingUDP,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: nftablesIfname(r.tunOptions.Name),
|
||||
},
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectInputMark),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
&expr.Counter{},
|
||||
},
|
||||
})
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chainPreRoutingUDP,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectInputMark),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectOutputMark),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
&expr.Counter{},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
err = r.configureOpenWRTFirewall4(nft, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = nft.Flush()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
r.networkListener = r.networkMonitor.RegisterCallback(func() {
|
||||
err = r.nftablesUpdateLocalAddressSet()
|
||||
if err != nil {
|
||||
r.logger.Error("update local address set: ", err)
|
||||
}
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO; test is this works
|
||||
func (r *autoRedirect) nftablesUpdateLocalAddressSet() error {
|
||||
newLocalAddresses := common.FlatMap(r.interfaceFinder.Interfaces(), func(it control.Interface) []netip.Prefix {
|
||||
return common.Filter(it.Addresses, func(prefix netip.Prefix) bool {
|
||||
return it.Name == "lo" || prefix.Addr().IsGlobalUnicast()
|
||||
})
|
||||
})
|
||||
if slices.Equal(newLocalAddresses, r.localAddresses) {
|
||||
return nil
|
||||
}
|
||||
nft, err := nftables.New()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer nft.CloseLasting()
|
||||
table, err := nft.ListTableOfFamily(r.tableName, nftables.TableFamilyINet)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.nftablesCreateLocalAddressSets(nft, table, newLocalAddresses, r.localAddresses)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.localAddresses = newLocalAddresses
|
||||
return nft.Flush()
|
||||
}
|
||||
|
||||
func (r *autoRedirect) nftablesUpdateRouteAddressSet() error {
|
||||
nft, err := nftables.New()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer nft.CloseLasting()
|
||||
table, err := nft.ListTableOfFamily(r.tableName, nftables.TableFamilyINet)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = r.nftablesCreateAddressSets(nft, table, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nft.Flush()
|
||||
}
|
||||
|
||||
func (r *autoRedirect) cleanupNFTables() {
|
||||
if r.networkListener != nil {
|
||||
r.networkMonitor.UnregisterCallback(r.networkListener)
|
||||
}
|
||||
nft, err := nftables.New()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
nft.DelTable(&nftables.Table{
|
||||
Name: r.tableName,
|
||||
Family: nftables.TableFamilyINet,
|
||||
})
|
||||
common.Must(r.configureOpenWRTFirewall4(nft, true))
|
||||
_ = nft.Flush()
|
||||
_ = nft.CloseLasting()
|
||||
}
|
||||
@@ -1,172 +0,0 @@
|
||||
//go:build linux
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/sagernet/nftables"
|
||||
"github.com/sagernet/nftables/expr"
|
||||
|
||||
"go4.org/netipx"
|
||||
)
|
||||
|
||||
func nftablesIfname(n string) []byte {
|
||||
b := make([]byte, 16)
|
||||
copy(b, n+"\x00")
|
||||
return b
|
||||
}
|
||||
|
||||
func nftablesCreateExcludeDestinationIPSet(
|
||||
nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain,
|
||||
id uint32, name string, family nftables.TableFamily, invert bool,
|
||||
) {
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyNFPROTO,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{byte(family)},
|
||||
},
|
||||
}
|
||||
if family == nftables.TableFamilyIPv4 {
|
||||
exprs = append(exprs,
|
||||
&expr.Payload{
|
||||
OperationType: expr.PayloadLoad,
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 16,
|
||||
Len: 4,
|
||||
},
|
||||
)
|
||||
} else {
|
||||
exprs = append(exprs,
|
||||
&expr.Payload{
|
||||
OperationType: expr.PayloadLoad,
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 24,
|
||||
Len: 16,
|
||||
},
|
||||
)
|
||||
}
|
||||
exprs = append(exprs,
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetID: id,
|
||||
SetName: name,
|
||||
Invert: invert,
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictReturn,
|
||||
})
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: exprs,
|
||||
})
|
||||
}
|
||||
|
||||
func nftablesCreateIPSet(
|
||||
nft *nftables.Conn, table *nftables.Table,
|
||||
id uint32, name string, family nftables.TableFamily,
|
||||
setList []*netipx.IPSet, prefixList []netip.Prefix, appendDefault bool, update bool,
|
||||
) (*nftables.Set, error) {
|
||||
var builder netipx.IPSetBuilder
|
||||
for _, prefix := range prefixList {
|
||||
builder.AddPrefix(prefix)
|
||||
}
|
||||
for _, set := range setList {
|
||||
builder.AddSet(set)
|
||||
}
|
||||
ipSet, err := builder.IPSet()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ipRanges := ipSet.Ranges()
|
||||
setElements := make([]nftables.SetElement, 0, len(ipRanges))
|
||||
for _, rr := range ipRanges {
|
||||
if (family == nftables.TableFamilyIPv4) != rr.From().Is4() {
|
||||
continue
|
||||
}
|
||||
endAddr := rr.To().Next()
|
||||
if !endAddr.IsValid() {
|
||||
endAddr = rr.From()
|
||||
}
|
||||
setElements = append(setElements, nftables.SetElement{
|
||||
Key: rr.From().AsSlice(),
|
||||
})
|
||||
setElements = append(setElements, nftables.SetElement{
|
||||
Key: endAddr.AsSlice(),
|
||||
IntervalEnd: true,
|
||||
})
|
||||
}
|
||||
if len(prefixList) == 0 && appendDefault {
|
||||
if family == nftables.TableFamilyIPv4 {
|
||||
setElements = append(setElements, nftables.SetElement{
|
||||
Key: netip.IPv4Unspecified().AsSlice(),
|
||||
}, nftables.SetElement{
|
||||
Key: netip.IPv4Unspecified().AsSlice(),
|
||||
IntervalEnd: true,
|
||||
})
|
||||
} else {
|
||||
setElements = append(setElements, nftables.SetElement{
|
||||
Key: netip.IPv6Unspecified().AsSlice(),
|
||||
}, nftables.SetElement{
|
||||
Key: netip.IPv6Unspecified().AsSlice(),
|
||||
IntervalEnd: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
var keyType nftables.SetDatatype
|
||||
if family == nftables.TableFamilyIPv4 {
|
||||
keyType = nftables.TypeIPAddr
|
||||
} else {
|
||||
keyType = nftables.TypeIP6Addr
|
||||
}
|
||||
mySet := &nftables.Set{
|
||||
Table: table,
|
||||
ID: id,
|
||||
Name: name,
|
||||
Interval: true,
|
||||
KeyType: keyType,
|
||||
}
|
||||
if id == 0 {
|
||||
mySet.Anonymous = true
|
||||
mySet.Constant = true
|
||||
}
|
||||
if id == 0 {
|
||||
err := nft.AddSet(mySet, setElements)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return mySet, nil
|
||||
} else if update {
|
||||
nft.FlushSet(mySet)
|
||||
} else {
|
||||
err := nft.AddSet(mySet, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
for len(setElements) > 0 {
|
||||
toAdd := setElements
|
||||
if len(toAdd) > 1000 {
|
||||
toAdd = toAdd[:1000]
|
||||
}
|
||||
setElements = setElements[len(toAdd):]
|
||||
err := nft.SetAddElements(mySet, toAdd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = nft.Flush()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return mySet, nil
|
||||
}
|
||||
@@ -1,735 +0,0 @@
|
||||
//go:build linux
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
_ "unsafe"
|
||||
|
||||
"github.com/sagernet/nftables"
|
||||
"github.com/sagernet/nftables/binaryutil"
|
||||
"github.com/sagernet/nftables/expr"
|
||||
"github.com/sagernet/nftables/userdata"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/ranges"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
//go:linkname allocSetID github.com/sagernet/nftables.allocSetID
|
||||
var allocSetID uint32
|
||||
|
||||
func init() {
|
||||
allocSetID = 6
|
||||
}
|
||||
|
||||
func (r *autoRedirect) nftablesCreateAddressSets(
|
||||
nft *nftables.Conn, table *nftables.Table,
|
||||
update bool,
|
||||
) error {
|
||||
routeAddressSet := *r.routeAddressSet
|
||||
routeExcludeAddressSet := *r.routeExcludeAddressSet
|
||||
if len(routeAddressSet) == 0 && len(routeExcludeAddressSet) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(routeAddressSet) > 0 {
|
||||
if r.enableIPv4 {
|
||||
_, err := nftablesCreateIPSet(nft, table, 1, "inet4_route_address_set", nftables.TableFamilyIPv4, routeAddressSet, nil, true, update)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if r.enableIPv6 {
|
||||
_, err := nftablesCreateIPSet(nft, table, 2, "inet6_route_address_set", nftables.TableFamilyIPv6, routeAddressSet, nil, true, update)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(routeExcludeAddressSet) > 0 {
|
||||
if r.enableIPv4 {
|
||||
_, err := nftablesCreateIPSet(nft, table, 3, "inet4_route_exclude_address_set", nftables.TableFamilyIPv4, routeExcludeAddressSet, nil, false, update)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if r.enableIPv6 {
|
||||
_, err := nftablesCreateIPSet(nft, table, 4, "inet6_route_exclude_address_set", nftables.TableFamilyIPv6, routeExcludeAddressSet, nil, false, update)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *autoRedirect) nftablesCreateLocalAddressSets(
|
||||
nft *nftables.Conn, table *nftables.Table,
|
||||
localAddresses []netip.Prefix, lastAddresses []netip.Prefix,
|
||||
) error {
|
||||
if r.enableIPv4 {
|
||||
localAddresses4 := common.Filter(localAddresses, func(it netip.Prefix) bool {
|
||||
return it.Addr().Is4()
|
||||
})
|
||||
updateAddresses4 := common.Filter(localAddresses, func(it netip.Prefix) bool {
|
||||
return it.Addr().Is4()
|
||||
})
|
||||
var update bool
|
||||
if len(lastAddresses) != 0 {
|
||||
if !slices.Equal(localAddresses4, updateAddresses4) {
|
||||
update = true
|
||||
}
|
||||
}
|
||||
if len(lastAddresses) == 0 || update {
|
||||
_, err := nftablesCreateIPSet(nft, table, 5, "inet4_local_address_set", nftables.TableFamilyIPv4, nil, localAddresses4, false, update)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
if r.enableIPv6 {
|
||||
localAddresses6 := common.Filter(localAddresses, func(it netip.Prefix) bool {
|
||||
return it.Addr().Is6()
|
||||
})
|
||||
updateAddresses6 := common.Filter(localAddresses, func(it netip.Prefix) bool {
|
||||
return it.Addr().Is6()
|
||||
})
|
||||
var update bool
|
||||
if len(lastAddresses) != 0 {
|
||||
if !slices.Equal(localAddresses6, updateAddresses6) {
|
||||
update = true
|
||||
}
|
||||
}
|
||||
localAddresses6 = common.Filter(localAddresses6, func(it netip.Prefix) bool {
|
||||
address := it.Addr()
|
||||
return address.IsLoopback() || address.IsGlobalUnicast() && !address.IsPrivate()
|
||||
})
|
||||
if len(lastAddresses) == 0 || update {
|
||||
_, err := nftablesCreateIPSet(nft, table, 6, "inet6_local_address_set", nftables.TableFamilyIPv6, nil, localAddresses6, false, update)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *autoRedirect) nftablesCreateExcludeRules(nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain) error {
|
||||
if r.tunOptions.AutoRedirectMarkMode && chain.Hooknum == nftables.ChainHookOutput {
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectOutputMark),
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictReturn,
|
||||
},
|
||||
},
|
||||
})
|
||||
if chain.Type == nftables.ChainTypeRoute {
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeyMARK,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectOutputMark),
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictReturn,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
if chain.Hooknum == nftables.ChainHookPrerouting {
|
||||
if len(r.tunOptions.IncludeInterface) > 0 {
|
||||
if len(r.tunOptions.IncludeInterface) > 1 {
|
||||
includeInterface := &nftables.Set{
|
||||
Table: table,
|
||||
Anonymous: true,
|
||||
Constant: true,
|
||||
KeyType: nftables.TypeIFName,
|
||||
}
|
||||
err := nft.AddSet(includeInterface, common.Map(r.tunOptions.IncludeInterface, func(it string) nftables.SetElement {
|
||||
return nftables.SetElement{
|
||||
Key: nftablesIfname(it),
|
||||
}
|
||||
}))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetID: includeInterface.ID,
|
||||
SetName: includeInterface.Name,
|
||||
Invert: true,
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictReturn,
|
||||
},
|
||||
},
|
||||
})
|
||||
} else {
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: nftablesIfname(r.tunOptions.IncludeInterface[0]),
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictReturn,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(r.tunOptions.ExcludeInterface) > 0 {
|
||||
if len(r.tunOptions.ExcludeInterface) > 1 {
|
||||
excludeInterface := &nftables.Set{
|
||||
Table: table,
|
||||
Anonymous: true,
|
||||
Constant: true,
|
||||
KeyType: nftables.TypeIFName,
|
||||
}
|
||||
err := nft.AddSet(excludeInterface, common.Map(r.tunOptions.ExcludeInterface, func(it string) nftables.SetElement {
|
||||
return nftables.SetElement{
|
||||
Key: nftablesIfname(it),
|
||||
}
|
||||
}))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetID: excludeInterface.ID,
|
||||
SetName: excludeInterface.Name,
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictReturn,
|
||||
},
|
||||
},
|
||||
})
|
||||
} else {
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: nftablesIfname(r.tunOptions.ExcludeInterface[0]),
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictReturn,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if len(r.tunOptions.IncludeUID) > 0 {
|
||||
if len(r.tunOptions.IncludeUID) > 1 || r.tunOptions.IncludeUID[0].Start != r.tunOptions.IncludeUID[0].End {
|
||||
includeUID := &nftables.Set{
|
||||
Table: table,
|
||||
Anonymous: true,
|
||||
Constant: true,
|
||||
Interval: true,
|
||||
KeyType: nftables.TypeUID,
|
||||
}
|
||||
err := nft.AddSet(includeUID, common.FlatMap(r.tunOptions.IncludeUID, func(it ranges.Range[uint32]) []nftables.SetElement {
|
||||
return []nftables.SetElement{
|
||||
{
|
||||
Key: binaryutil.NativeEndian.PutUint32(it.Start),
|
||||
},
|
||||
{
|
||||
Key: binaryutil.NativeEndian.PutUint32(it.End + 1),
|
||||
IntervalEnd: true,
|
||||
},
|
||||
}
|
||||
}))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeySKUID, Register: 1},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetID: includeUID.ID,
|
||||
SetName: includeUID.Name,
|
||||
Invert: true,
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictReturn,
|
||||
},
|
||||
},
|
||||
UserData: userdata.AppendString(nil, userdata.TypeComment, "not a bug :("),
|
||||
})
|
||||
} else {
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeySKUID, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint32(r.tunOptions.IncludeUID[0].Start),
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictReturn,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if len(r.tunOptions.ExcludeUID) > 0 {
|
||||
if len(r.tunOptions.ExcludeUID) > 1 || r.tunOptions.ExcludeUID[0].Start != r.tunOptions.ExcludeUID[0].End {
|
||||
excludeUID := &nftables.Set{
|
||||
Table: table,
|
||||
Anonymous: true,
|
||||
Constant: true,
|
||||
Interval: true,
|
||||
KeyType: nftables.TypeUID,
|
||||
}
|
||||
err := nft.AddSet(excludeUID, common.FlatMap(r.tunOptions.ExcludeUID, func(it ranges.Range[uint32]) []nftables.SetElement {
|
||||
return []nftables.SetElement{
|
||||
{
|
||||
Key: binaryutil.NativeEndian.PutUint32(it.Start),
|
||||
},
|
||||
{
|
||||
Key: binaryutil.NativeEndian.PutUint32(it.End + 1),
|
||||
IntervalEnd: true,
|
||||
},
|
||||
}
|
||||
}))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeySKUID, Register: 1},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetID: excludeUID.ID,
|
||||
SetName: excludeUID.Name,
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictReturn,
|
||||
},
|
||||
},
|
||||
UserData: userdata.AppendString(nil, userdata.TypeComment, "not a bug :("),
|
||||
})
|
||||
} else {
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeySKUID, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.ExcludeUID[0].Start),
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictReturn,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(r.tunOptions.Inet4RouteAddress) > 0 {
|
||||
inet4RouteAddress, err := nftablesCreateIPSet(nft, table, 0, "", nftables.TableFamilyIPv4, nil, r.tunOptions.Inet4RouteAddress, false, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nftablesCreateExcludeDestinationIPSet(nft, table, chain, inet4RouteAddress.ID, inet4RouteAddress.Name, nftables.TableFamilyIPv4, true)
|
||||
}
|
||||
|
||||
if len(r.tunOptions.Inet6RouteAddress) > 0 {
|
||||
inet6RouteAddress, err := nftablesCreateIPSet(nft, table, 0, "", nftables.TableFamilyIPv6, nil, r.tunOptions.Inet6RouteAddress, false, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nftablesCreateExcludeDestinationIPSet(nft, table, chain, inet6RouteAddress.ID, inet6RouteAddress.Name, nftables.TableFamilyIPv6, true)
|
||||
}
|
||||
|
||||
if len(r.tunOptions.Inet4RouteExcludeAddress) > 0 {
|
||||
inet4RouteExcludeAddress, err := nftablesCreateIPSet(nft, table, 0, "", nftables.TableFamilyIPv4, nil, r.tunOptions.Inet4RouteExcludeAddress, false, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nftablesCreateExcludeDestinationIPSet(nft, table, chain, inet4RouteExcludeAddress.ID, inet4RouteExcludeAddress.Name, nftables.TableFamilyIPv4, false)
|
||||
}
|
||||
|
||||
if len(r.tunOptions.Inet6RouteExcludeAddress) > 0 {
|
||||
inet6RouteExcludeAddress, err := nftablesCreateIPSet(nft, table, 0, "", nftables.TableFamilyIPv6, nil, r.tunOptions.Inet6RouteExcludeAddress, false, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nftablesCreateExcludeDestinationIPSet(nft, table, chain, inet6RouteExcludeAddress.ID, inet6RouteExcludeAddress.Name, nftables.TableFamilyIPv6, false)
|
||||
}
|
||||
|
||||
if !r.tunOptions.EXP_DisableDNSHijack && ((chain.Hooknum == nftables.ChainHookPrerouting && chain.Type == nftables.ChainTypeNAT) ||
|
||||
(r.tunOptions.AutoRedirectMarkMode && chain.Hooknum == nftables.ChainHookOutput && chain.Type == nftables.ChainTypeNAT)) {
|
||||
if r.enableIPv4 {
|
||||
err := r.nftablesCreateDNSHijackRulesForFamily(nft, table, chain, nftables.TableFamilyIPv4, 5, "inet4_local_address_set")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if r.enableIPv6 {
|
||||
err := r.nftablesCreateDNSHijackRulesForFamily(nft, table, chain, nftables.TableFamilyIPv6, 6, "inet6_local_address_set")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if r.tunOptions.AutoRedirectMarkMode &&
|
||||
((chain.Hooknum == nftables.ChainHookOutput && chain.Type == nftables.ChainTypeRoute) ||
|
||||
(chain.Hooknum == nftables.ChainHookPrerouting && chain.Type == nftables.ChainTypeFilter)) {
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyL4PROTO,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{unix.IPPROTO_UDP},
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictReturn,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
if r.enableIPv4 {
|
||||
nftablesCreateExcludeDestinationIPSet(nft, table, chain, 5, "inet4_local_address_set", nftables.TableFamilyIPv4, false)
|
||||
}
|
||||
if r.enableIPv6 {
|
||||
nftablesCreateExcludeDestinationIPSet(nft, table, chain, 6, "inet6_local_address_set", nftables.TableFamilyIPv6, false)
|
||||
}
|
||||
|
||||
routeAddressSet := *r.routeAddressSet
|
||||
routeExcludeAddressSet := *r.routeExcludeAddressSet
|
||||
|
||||
if r.enableIPv4 && len(routeAddressSet) > 0 {
|
||||
nftablesCreateExcludeDestinationIPSet(nft, table, chain, 1, "inet4_route_address_set", nftables.TableFamilyIPv4, true)
|
||||
}
|
||||
|
||||
if r.enableIPv6 && len(routeAddressSet) > 0 {
|
||||
nftablesCreateExcludeDestinationIPSet(nft, table, chain, 2, "inet6_route_address_set", nftables.TableFamilyIPv6, true)
|
||||
}
|
||||
|
||||
if r.enableIPv4 && len(routeExcludeAddressSet) > 0 {
|
||||
nftablesCreateExcludeDestinationIPSet(nft, table, chain, 3, "inet4_route_exclude_address_set", nftables.TableFamilyIPv4, false)
|
||||
}
|
||||
|
||||
if r.enableIPv6 && len(routeExcludeAddressSet) > 0 {
|
||||
nftablesCreateExcludeDestinationIPSet(nft, table, chain, 4, "inet6_route_exclude_address_set", nftables.TableFamilyIPv6, false)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *autoRedirect) nftablesCreateMark(nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain) {
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectInputMark),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyMARK,
|
||||
Register: 1,
|
||||
}, // output meta mark set myMark ct mark set meta mark
|
||||
&expr.Ct{
|
||||
Key: expr.CtKeyMARK,
|
||||
Register: 1,
|
||||
SourceRegister: true,
|
||||
},
|
||||
&expr.Counter{},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (r *autoRedirect) nftablesCreateRedirect(
|
||||
nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain,
|
||||
exprs ...expr.Any,
|
||||
) {
|
||||
if r.enableIPv4 && !r.enableIPv6 {
|
||||
exprs = append(exprs,
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyNFPROTO,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{uint8(nftables.TableFamilyIPv4)},
|
||||
})
|
||||
} else if !r.enableIPv4 && r.enableIPv6 {
|
||||
exprs = append(exprs,
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyNFPROTO,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{uint8(nftables.TableFamilyIPv6)},
|
||||
})
|
||||
}
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: append(exprs,
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyL4PROTO,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{unix.IPPROTO_TCP},
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(r.redirectPort()),
|
||||
},
|
||||
&expr.Redir{
|
||||
RegisterProtoMin: 1,
|
||||
Flags: unix.NF_NAT_RANGE_PROTO_SPECIFIED,
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictReturn,
|
||||
},
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
func (r *autoRedirect) nftablesCreateDNSHijackRulesForFamily(
|
||||
nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain,
|
||||
family nftables.TableFamily, setID uint32, setName string,
|
||||
) error {
|
||||
ipProto := &nftables.Set{
|
||||
Table: table,
|
||||
Anonymous: true,
|
||||
Constant: true,
|
||||
KeyType: nftables.TypeInetProto,
|
||||
}
|
||||
err := nft.AddSet(ipProto, []nftables.SetElement{
|
||||
{Key: []byte{unix.IPPROTO_TCP}},
|
||||
{Key: []byte{unix.IPPROTO_UDP}},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dnsServer := common.Find(r.tunOptions.DNSServers, func(it netip.Addr) bool {
|
||||
return it.Is4() == (family == nftables.TableFamilyIPv4)
|
||||
})
|
||||
if !dnsServer.IsValid() {
|
||||
if family == nftables.TableFamilyIPv4 {
|
||||
if HasNextAddress(r.tunOptions.Inet4Address[0], 1) {
|
||||
dnsServer = r.tunOptions.Inet4Address[0].Addr().Next()
|
||||
}
|
||||
} else {
|
||||
if HasNextAddress(r.tunOptions.Inet6Address[0], 1) {
|
||||
dnsServer = r.tunOptions.Inet6Address[0].Addr().Next()
|
||||
}
|
||||
}
|
||||
}
|
||||
if !dnsServer.IsValid() {
|
||||
return nil
|
||||
}
|
||||
exprs := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyNFPROTO,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{uint8(family)},
|
||||
},
|
||||
}
|
||||
if chain.Hooknum == nftables.ChainHookOutput {
|
||||
// It looks like we can't hijack DNS requests sent to loopback.
|
||||
// https://serverfault.com/questions/363899/iptables-dnat-from-loopback
|
||||
// and tproxy is not available in output
|
||||
exprs = append(exprs,
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyOIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: nftablesIfname("lo"),
|
||||
},
|
||||
)
|
||||
} else {
|
||||
if family == nftables.TableFamilyIPv4 {
|
||||
exprs = append(exprs,
|
||||
&expr.Payload{
|
||||
OperationType: expr.PayloadLoad,
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
)
|
||||
} else {
|
||||
exprs = append(exprs,
|
||||
&expr.Payload{
|
||||
OperationType: expr.PayloadLoad,
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 8,
|
||||
Len: 16,
|
||||
},
|
||||
)
|
||||
}
|
||||
exprs = append(exprs, &expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetID: setID,
|
||||
SetName: setName,
|
||||
})
|
||||
}
|
||||
exprs = append(exprs,
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyL4PROTO,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Lookup{
|
||||
SourceRegister: 1,
|
||||
SetID: ipProto.ID,
|
||||
SetName: ipProto.Name,
|
||||
},
|
||||
&expr.Payload{
|
||||
OperationType: expr.PayloadLoad,
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(53),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: dnsServer.AsSlice(),
|
||||
},
|
||||
&expr.NAT{
|
||||
Type: expr.NATTypeDestNAT,
|
||||
Family: uint32(family),
|
||||
RegAddrMin: 1,
|
||||
},
|
||||
&expr.Counter{},
|
||||
)
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: exprs,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *autoRedirect) nftablesCreateUnreachable(
|
||||
nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain,
|
||||
) {
|
||||
if (r.enableIPv4 && r.enableIPv6) || !r.tunOptions.StrictRoute {
|
||||
return
|
||||
}
|
||||
var nfProto nftables.TableFamily
|
||||
if r.enableIPv4 {
|
||||
nfProto = nftables.TableFamilyIPv6
|
||||
} else {
|
||||
nfProto = nftables.TableFamilyIPv4
|
||||
}
|
||||
nft.AddRule(&nftables.Rule{
|
||||
Table: table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyNFPROTO,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{uint8(nfProto)},
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictDrop,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -1,103 +0,0 @@
|
||||
//go:build linux
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"github.com/sagernet/nftables"
|
||||
"github.com/sagernet/nftables/expr"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
func (r *autoRedirect) configureOpenWRTFirewall4(nft *nftables.Conn, cleanup bool) error {
|
||||
tableFW4, err := nft.ListTableOfFamily("fw4", nftables.TableFamilyINet)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if !cleanup {
|
||||
ruleIif := &nftables.Rule{
|
||||
Table: tableFW4,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyIIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: nftablesIfname(r.tunOptions.Name),
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
ruleOif := &nftables.Rule{
|
||||
Table: tableFW4,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyOIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: nftablesIfname(r.tunOptions.Name),
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
chainForward := &nftables.Chain{
|
||||
Name: "forward",
|
||||
}
|
||||
ruleIif.Chain = chainForward
|
||||
ruleOif.Chain = chainForward
|
||||
nft.InsertRule(ruleOif)
|
||||
nft.InsertRule(ruleIif)
|
||||
chainInput := &nftables.Chain{
|
||||
Name: "input",
|
||||
}
|
||||
ruleIif.Chain = chainInput
|
||||
ruleOif.Chain = chainInput
|
||||
nft.InsertRule(ruleOif)
|
||||
nft.InsertRule(ruleIif)
|
||||
return nil
|
||||
}
|
||||
for _, chainName := range []string{"input", "forward"} {
|
||||
var rules []*nftables.Rule
|
||||
rules, err = nft.GetRules(tableFW4, &nftables.Chain{
|
||||
Name: chainName,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, rule := range rules {
|
||||
if len(rule.Exprs) != 4 {
|
||||
continue
|
||||
}
|
||||
exprMeta, isMeta := rule.Exprs[0].(*expr.Meta)
|
||||
if !isMeta {
|
||||
continue
|
||||
}
|
||||
if exprMeta.Key != expr.MetaKeyIIFNAME && exprMeta.Key != expr.MetaKeyOIFNAME {
|
||||
continue
|
||||
}
|
||||
exprCmp, isCmp := rule.Exprs[1].(*expr.Cmp)
|
||||
if !isCmp {
|
||||
continue
|
||||
}
|
||||
if !slices.Equal(exprCmp.Data, nftablesIfname(r.tunOptions.Name)) {
|
||||
continue
|
||||
}
|
||||
err = nft.DelRule(rule)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,83 +0,0 @@
|
||||
//go:build linux
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/atomic"
|
||||
"github.com/sagernet/sing/common/control"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type redirectServer struct {
|
||||
ctx context.Context
|
||||
handler N.TCPConnectionHandlerEx
|
||||
logger logger.Logger
|
||||
listenAddr netip.Addr
|
||||
listener *net.TCPListener
|
||||
inShutdown atomic.Bool
|
||||
}
|
||||
|
||||
func newRedirectServer(ctx context.Context, handler N.TCPConnectionHandlerEx, logger logger.Logger, listenAddr netip.Addr) *redirectServer {
|
||||
return &redirectServer{
|
||||
ctx: ctx,
|
||||
handler: handler,
|
||||
logger: logger,
|
||||
listenAddr: listenAddr,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *redirectServer) Start() error {
|
||||
var listenConfig net.ListenConfig
|
||||
// listenConfig.KeepAlive = C.TCPKeepAliveInitial
|
||||
listenConfig.KeepAlive = 10 * time.Minute
|
||||
listener, err := listenConfig.Listen(s.ctx, M.NetworkFromNetAddr("tcp", s.listenAddr), M.SocksaddrFrom(s.listenAddr, 0).String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.listener = listener.(*net.TCPListener)
|
||||
go s.loopIn()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *redirectServer) Close() error {
|
||||
s.inShutdown.Store(true)
|
||||
return s.listener.Close()
|
||||
}
|
||||
|
||||
func (s *redirectServer) loopIn() {
|
||||
for {
|
||||
conn, err := s.listener.AcceptTCP()
|
||||
if err != nil {
|
||||
var netError net.Error
|
||||
//nolint:staticcheck
|
||||
if errors.As(err, &netError) && netError.Temporary() {
|
||||
s.logger.Error(err)
|
||||
continue
|
||||
}
|
||||
if s.inShutdown.Load() && E.IsClosed(err) {
|
||||
return
|
||||
}
|
||||
s.listener.Close()
|
||||
s.logger.Error("serve error: ", err)
|
||||
continue
|
||||
}
|
||||
source := M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap()
|
||||
destination, err := control.GetOriginalDestination(conn)
|
||||
if err != nil {
|
||||
_ = conn.SetLinger(0)
|
||||
_ = conn.Close()
|
||||
s.logger.Error("process redirect connection from ", source, ": invalid connection: ", err)
|
||||
continue
|
||||
}
|
||||
go s.handler.NewConnectionEx(s.ctx, conn, source, M.SocksaddrFromNetIP(destination).Unwrap(), nil)
|
||||
}
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
//go:build !linux
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"os"
|
||||
)
|
||||
|
||||
func NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
50
stack.go
50
stack.go
@@ -2,18 +2,13 @@ package tun
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/control"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
)
|
||||
|
||||
var ErrDrop = E.New("drop connections by rule")
|
||||
|
||||
type Stack interface {
|
||||
Start() error
|
||||
Close() error
|
||||
@@ -22,12 +17,15 @@ type Stack interface {
|
||||
type StackOptions struct {
|
||||
Context context.Context
|
||||
Tun Tun
|
||||
TunOptions Options
|
||||
UDPTimeout time.Duration
|
||||
Name string
|
||||
MTU uint32
|
||||
Inet4Address []netip.Prefix
|
||||
Inet6Address []netip.Prefix
|
||||
EndpointIndependentNat bool
|
||||
UDPTimeout int64
|
||||
Handler Handler
|
||||
Logger logger.Logger
|
||||
ForwarderBindInterface bool
|
||||
IncludeAllNetworks bool
|
||||
InterfaceFinder control.InterfaceFinder
|
||||
}
|
||||
|
||||
@@ -37,44 +35,14 @@ func NewStack(
|
||||
) (Stack, error) {
|
||||
switch stack {
|
||||
case "":
|
||||
if options.IncludeAllNetworks {
|
||||
return NewGVisor(options)
|
||||
} else if WithGVisor && !options.TunOptions.GSO {
|
||||
return NewMixed(options)
|
||||
} else {
|
||||
return NewSystem(options)
|
||||
}
|
||||
return NewSystem(options)
|
||||
case "gvisor":
|
||||
return NewGVisor(options)
|
||||
case "mixed":
|
||||
if options.IncludeAllNetworks {
|
||||
return nil, ErrIncludeAllNetworks
|
||||
}
|
||||
return NewMixed(options)
|
||||
case "system":
|
||||
if options.IncludeAllNetworks {
|
||||
return nil, ErrIncludeAllNetworks
|
||||
}
|
||||
return NewSystem(options)
|
||||
case "lwip":
|
||||
return NewLWIP(options)
|
||||
default:
|
||||
return nil, E.New("unknown stack: ", stack)
|
||||
}
|
||||
}
|
||||
|
||||
func HasNextAddress(prefix netip.Prefix, count int) bool {
|
||||
checkAddr := prefix.Addr()
|
||||
for i := 0; i < count; i++ {
|
||||
checkAddr = checkAddr.Next()
|
||||
}
|
||||
return prefix.Contains(checkAddr)
|
||||
}
|
||||
|
||||
func BroadcastAddr(inet4Address []netip.Prefix) netip.Addr {
|
||||
if len(inet4Address) == 0 {
|
||||
return netip.Addr{}
|
||||
}
|
||||
prefix := inet4Address[0]
|
||||
var broadcastAddr [4]byte
|
||||
binary.BigEndian.PutUint32(broadcastAddr[:], binary.BigEndian.Uint32(prefix.Masked().Addr().AsSlice())|^binary.BigEndian.Uint32(net.CIDRMask(prefix.Bits(), 32)))
|
||||
return netip.AddrFrom4(broadcastAddr)
|
||||
}
|
||||
|
||||
164
stack_gvisor.go
164
stack_gvisor.go
@@ -1,164 +0,0 @@
|
||||
//go:build with_gvisor
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/gvisor/pkg/tcpip"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/header"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
)
|
||||
|
||||
const WithGVisor = true
|
||||
|
||||
const DefaultNIC tcpip.NICID = 1
|
||||
|
||||
type GVisor struct {
|
||||
ctx context.Context
|
||||
tun GVisorTun
|
||||
udpTimeout time.Duration
|
||||
broadcastAddr netip.Addr
|
||||
handler Handler
|
||||
logger logger.Logger
|
||||
stack *stack.Stack
|
||||
endpoint stack.LinkEndpoint
|
||||
}
|
||||
|
||||
type GVisorTun interface {
|
||||
Tun
|
||||
NewEndpoint() (stack.LinkEndpoint, error)
|
||||
}
|
||||
|
||||
func NewGVisor(
|
||||
options StackOptions,
|
||||
) (Stack, error) {
|
||||
gTun, isGTun := options.Tun.(GVisorTun)
|
||||
if !isGTun {
|
||||
return nil, E.New("gVisor stack is unsupported on current platform")
|
||||
}
|
||||
|
||||
gStack := &GVisor{
|
||||
ctx: options.Context,
|
||||
tun: gTun,
|
||||
udpTimeout: options.UDPTimeout,
|
||||
broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address),
|
||||
handler: options.Handler,
|
||||
logger: options.Logger,
|
||||
}
|
||||
return gStack, nil
|
||||
}
|
||||
|
||||
func (t *GVisor) Start() error {
|
||||
linkEndpoint, err := t.tun.NewEndpoint()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, t.tun}
|
||||
ipStack, err := NewGVisorStack(linkEndpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, NewTCPForwarder(t.ctx, ipStack, t.handler).HandlePacket)
|
||||
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket)
|
||||
t.stack = ipStack
|
||||
t.endpoint = linkEndpoint
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *GVisor) Close() error {
|
||||
if t.stack == nil {
|
||||
return nil
|
||||
}
|
||||
t.endpoint.Attach(nil)
|
||||
t.stack.Close()
|
||||
for _, endpoint := range t.stack.CleanupEndpoints() {
|
||||
endpoint.Abort()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func AddressFromAddr(destination netip.Addr) tcpip.Address {
|
||||
if destination.Is6() {
|
||||
return tcpip.AddrFrom16(destination.As16())
|
||||
} else {
|
||||
return tcpip.AddrFrom4(destination.As4())
|
||||
}
|
||||
}
|
||||
|
||||
func AddrFromAddress(address tcpip.Address) netip.Addr {
|
||||
if address.Len() == 16 {
|
||||
return netip.AddrFrom16(address.As16())
|
||||
} else {
|
||||
return netip.AddrFrom4(address.As4())
|
||||
}
|
||||
}
|
||||
|
||||
func NewGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
|
||||
ipStack := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{
|
||||
ipv4.NewProtocol,
|
||||
ipv6.NewProtocol,
|
||||
},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{
|
||||
tcp.NewProtocol,
|
||||
udp.NewProtocol,
|
||||
icmp.NewProtocol4,
|
||||
icmp.NewProtocol6,
|
||||
},
|
||||
})
|
||||
err := ipStack.CreateNIC(DefaultNIC, ep)
|
||||
if err != nil {
|
||||
return nil, gonet.TranslateNetstackError(err)
|
||||
}
|
||||
ipStack.SetRouteTable([]tcpip.Route{
|
||||
{Destination: header.IPv4EmptySubnet, NIC: DefaultNIC},
|
||||
{Destination: header.IPv6EmptySubnet, NIC: DefaultNIC},
|
||||
})
|
||||
err = ipStack.SetSpoofing(DefaultNIC, true)
|
||||
if err != nil {
|
||||
return nil, gonet.TranslateNetstackError(err)
|
||||
}
|
||||
err = ipStack.SetPromiscuousMode(DefaultNIC, true)
|
||||
if err != nil {
|
||||
return nil, gonet.TranslateNetstackError(err)
|
||||
}
|
||||
sOpt := tcpip.TCPSACKEnabled(true)
|
||||
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
|
||||
mOpt := tcpip.TCPModerateReceiveBufferOption(true)
|
||||
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &mOpt)
|
||||
if runtime.GOOS == "windows" {
|
||||
tcpRecoveryOpt := tcpip.TCPRecovery(0)
|
||||
err = ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpRecoveryOpt)
|
||||
}
|
||||
tcpRXBufOpt := tcpip.TCPReceiveBufferSizeRangeOption{
|
||||
Min: tcpRXBufMinSize,
|
||||
Default: tcpRXBufDefSize,
|
||||
Max: tcpRXBufMaxSize,
|
||||
}
|
||||
err = ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpRXBufOpt)
|
||||
if err != nil {
|
||||
return nil, gonet.TranslateNetstackError(err)
|
||||
}
|
||||
tcpTXBufOpt := tcpip.TCPSendBufferSizeRangeOption{
|
||||
Min: tcpTXBufMinSize,
|
||||
Default: tcpTXBufDefSize,
|
||||
Max: tcpTXBufMaxSize,
|
||||
}
|
||||
err = ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpTXBufOpt)
|
||||
if err != nil {
|
||||
return nil, gonet.TranslateNetstackError(err)
|
||||
}
|
||||
return ipStack, nil
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
//go:build with_gvisor
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"github.com/sagernet/gvisor/pkg/tcpip"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/header"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
var _ stack.LinkEndpoint = (*LinkEndpointFilter)(nil)
|
||||
|
||||
type LinkEndpointFilter struct {
|
||||
stack.LinkEndpoint
|
||||
BroadcastAddress netip.Addr
|
||||
Writer N.VectorisedWriter
|
||||
}
|
||||
|
||||
func (w *LinkEndpointFilter) Attach(dispatcher stack.NetworkDispatcher) {
|
||||
w.LinkEndpoint.Attach(&networkDispatcherFilter{dispatcher, w.BroadcastAddress, w.Writer})
|
||||
}
|
||||
|
||||
var _ stack.NetworkDispatcher = (*networkDispatcherFilter)(nil)
|
||||
|
||||
type networkDispatcherFilter struct {
|
||||
stack.NetworkDispatcher
|
||||
broadcastAddress netip.Addr
|
||||
writer N.VectorisedWriter
|
||||
}
|
||||
|
||||
func (w *networkDispatcherFilter) DeliverNetworkPacket(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
|
||||
var network header.Network
|
||||
if protocol == header.IPv4ProtocolNumber {
|
||||
if headerPackets, loaded := pkt.Data().PullUp(header.IPv4MinimumSize); loaded {
|
||||
network = header.IPv4(headerPackets)
|
||||
}
|
||||
} else {
|
||||
if headerPackets, loaded := pkt.Data().PullUp(header.IPv6MinimumSize); loaded {
|
||||
network = header.IPv6(headerPackets)
|
||||
}
|
||||
}
|
||||
if network == nil {
|
||||
w.NetworkDispatcher.DeliverNetworkPacket(protocol, pkt)
|
||||
return
|
||||
}
|
||||
destination := AddrFromAddress(network.DestinationAddress())
|
||||
if destination == w.broadcastAddress || !destination.IsGlobalUnicast() {
|
||||
_, _ = bufio.WriteVectorised(w.writer, pkt.AsSlices())
|
||||
return
|
||||
}
|
||||
w.NetworkDispatcher.DeliverNetworkPacket(protocol, pkt)
|
||||
}
|
||||
@@ -1,191 +0,0 @@
|
||||
//go:build with_gvisor
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/gvisor/pkg/tcpip"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
|
||||
"github.com/sagernet/gvisor/pkg/waiter"
|
||||
"github.com/sagernet/sing/common"
|
||||
)
|
||||
|
||||
type gLazyConn struct {
|
||||
tcpConn *gonet.TCPConn
|
||||
parentCtx context.Context
|
||||
stack *stack.Stack
|
||||
request *tcp.ForwarderRequest
|
||||
localAddr net.Addr
|
||||
remoteAddr net.Addr
|
||||
handshakeDone bool
|
||||
handshakeErr error
|
||||
}
|
||||
|
||||
func (c *gLazyConn) HandshakeContext(ctx context.Context) error {
|
||||
if c.handshakeDone {
|
||||
return nil
|
||||
}
|
||||
defer func() {
|
||||
c.handshakeDone = true
|
||||
}()
|
||||
var (
|
||||
wq waiter.Queue
|
||||
endpoint tcpip.Endpoint
|
||||
)
|
||||
handshakeCtx, cancel := context.WithCancel(ctx)
|
||||
go func() {
|
||||
select {
|
||||
case <-c.parentCtx.Done():
|
||||
wq.Notify(wq.Events())
|
||||
case <-handshakeCtx.Done():
|
||||
}
|
||||
}()
|
||||
endpoint, err := c.request.CreateEndpoint(&wq)
|
||||
cancel()
|
||||
if err != nil {
|
||||
gErr := gonet.TranslateNetstackError(err)
|
||||
c.handshakeErr = gErr
|
||||
c.request.Complete(true)
|
||||
return gErr
|
||||
}
|
||||
c.request.Complete(false)
|
||||
endpoint.SocketOptions().SetKeepAlive(true)
|
||||
endpoint.SetSockOpt(common.Ptr(tcpip.KeepaliveIdleOption(15 * time.Second)))
|
||||
endpoint.SetSockOpt(common.Ptr(tcpip.KeepaliveIntervalOption(15 * time.Second)))
|
||||
tcpConn := gonet.NewTCPConn(&wq, endpoint)
|
||||
c.tcpConn = tcpConn
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *gLazyConn) HandshakeFailure(err error) error {
|
||||
if c.handshakeDone {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
c.request.Complete(err != ErrDrop)
|
||||
c.handshakeDone = true
|
||||
c.handshakeErr = err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *gLazyConn) HandshakeSuccess() error {
|
||||
return c.HandshakeContext(context.Background())
|
||||
}
|
||||
|
||||
func (c *gLazyConn) Read(b []byte) (n int, err error) {
|
||||
if !c.handshakeDone {
|
||||
err = c.HandshakeContext(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else if c.handshakeErr != nil {
|
||||
return 0, c.handshakeErr
|
||||
}
|
||||
return c.tcpConn.Read(b)
|
||||
}
|
||||
|
||||
func (c *gLazyConn) Write(b []byte) (n int, err error) {
|
||||
if !c.handshakeDone {
|
||||
err = c.HandshakeContext(context.Background())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else if c.handshakeErr != nil {
|
||||
return 0, c.handshakeErr
|
||||
}
|
||||
return c.tcpConn.Write(b)
|
||||
}
|
||||
|
||||
func (c *gLazyConn) LocalAddr() net.Addr {
|
||||
return c.localAddr
|
||||
}
|
||||
|
||||
func (c *gLazyConn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func (c *gLazyConn) SetDeadline(t time.Time) error {
|
||||
if !c.handshakeDone {
|
||||
err := c.HandshakeContext(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if c.handshakeErr != nil {
|
||||
return c.handshakeErr
|
||||
}
|
||||
return c.tcpConn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (c *gLazyConn) SetReadDeadline(t time.Time) error {
|
||||
if !c.handshakeDone {
|
||||
err := c.HandshakeContext(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if c.handshakeErr != nil {
|
||||
return c.handshakeErr
|
||||
}
|
||||
return c.tcpConn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *gLazyConn) SetWriteDeadline(t time.Time) error {
|
||||
if !c.handshakeDone {
|
||||
err := c.HandshakeContext(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if c.handshakeErr != nil {
|
||||
return c.handshakeErr
|
||||
}
|
||||
return c.tcpConn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (c *gLazyConn) Close() error {
|
||||
if !c.handshakeDone {
|
||||
c.request.Complete(true)
|
||||
c.handshakeErr = net.ErrClosed
|
||||
return nil
|
||||
} else if c.handshakeErr != nil {
|
||||
return nil
|
||||
}
|
||||
return c.tcpConn.Close()
|
||||
}
|
||||
|
||||
func (c *gLazyConn) CloseRead() error {
|
||||
if !c.handshakeDone {
|
||||
c.request.Complete(true)
|
||||
c.handshakeErr = net.ErrClosed
|
||||
return nil
|
||||
} else if c.handshakeErr != nil {
|
||||
return nil
|
||||
}
|
||||
return c.tcpConn.CloseRead()
|
||||
}
|
||||
|
||||
func (c *gLazyConn) CloseWrite() error {
|
||||
if !c.handshakeDone {
|
||||
c.request.Complete(true)
|
||||
c.handshakeErr = net.ErrClosed
|
||||
return nil
|
||||
} else if c.handshakeErr != nil {
|
||||
return nil
|
||||
}
|
||||
return c.tcpConn.CloseRead()
|
||||
}
|
||||
|
||||
func (c *gLazyConn) ReaderReplaceable() bool {
|
||||
return c.handshakeDone && c.handshakeErr == nil
|
||||
}
|
||||
|
||||
func (c *gLazyConn) WriterReplaceable() bool {
|
||||
return c.handshakeDone && c.handshakeErr == nil
|
||||
}
|
||||
|
||||
func (c *gLazyConn) Upstream() any {
|
||||
return c.tcpConn
|
||||
}
|
||||
@@ -1,51 +0,0 @@
|
||||
//go:build with_gvisor
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type TCPForwarder struct {
|
||||
ctx context.Context
|
||||
stack *stack.Stack
|
||||
handler Handler
|
||||
forwarder *tcp.Forwarder
|
||||
}
|
||||
|
||||
func NewTCPForwarder(ctx context.Context, stack *stack.Stack, handler Handler) *TCPForwarder {
|
||||
forwarder := &TCPForwarder{
|
||||
ctx: ctx,
|
||||
stack: stack,
|
||||
handler: handler,
|
||||
}
|
||||
forwarder.forwarder = tcp.NewForwarder(stack, 0, 1024, forwarder.Forward)
|
||||
return forwarder
|
||||
}
|
||||
|
||||
func (f *TCPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
|
||||
return f.forwarder.HandlePacket(id, pkt)
|
||||
}
|
||||
|
||||
func (f *TCPForwarder) Forward(r *tcp.ForwarderRequest) {
|
||||
source := M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort)
|
||||
destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort)
|
||||
pErr := f.handler.PrepareConnection(N.NetworkTCP, source, destination)
|
||||
if pErr != nil {
|
||||
r.Complete(pErr != ErrDrop)
|
||||
return
|
||||
}
|
||||
conn := &gLazyConn{
|
||||
parentCtx: f.ctx,
|
||||
stack: f.stack,
|
||||
request: r,
|
||||
localAddr: source.TCPAddr(),
|
||||
remoteAddr: destination.TCPAddr(),
|
||||
}
|
||||
go f.handler.NewConnectionEx(f.ctx, conn, source, destination, nil)
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build with_gvisor && !ios
|
||||
|
||||
package tun
|
||||
|
||||
import "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
|
||||
|
||||
const (
|
||||
tcpRXBufMinSize = tcp.MinBufferSize
|
||||
tcpRXBufDefSize = tcp.DefaultSendBufferSize
|
||||
tcpRXBufMaxSize = 8 << 20 // 8MiB
|
||||
|
||||
tcpTXBufMinSize = tcp.MinBufferSize
|
||||
tcpTXBufDefSize = tcp.DefaultReceiveBufferSize
|
||||
tcpTXBufMaxSize = 6 << 20 // 6MiB
|
||||
)
|
||||
@@ -1,21 +0,0 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build with_gvisor
|
||||
|
||||
package tun
|
||||
|
||||
import "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
|
||||
|
||||
const (
|
||||
// tcp{RX,TX}Buf{Min,Def,Max}Size mirror gVisor defaults. We leave these
|
||||
// unchanged on iOS for now as to not increase pressure towards the
|
||||
// NetworkExtension memory limit.
|
||||
tcpRXBufMinSize = tcp.MinBufferSize
|
||||
tcpRXBufDefSize = tcp.DefaultSendBufferSize
|
||||
tcpRXBufMaxSize = tcp.MaxBufferSize
|
||||
|
||||
tcpTXBufMinSize = tcp.MinBufferSize
|
||||
tcpTXBufDefSize = tcp.DefaultReceiveBufferSize
|
||||
tcpTXBufMaxSize = tcp.MaxBufferSize
|
||||
)
|
||||
@@ -1,183 +0,0 @@
|
||||
//go:build with_gvisor
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
_ "unsafe"
|
||||
|
||||
"github.com/sagernet/gvisor/pkg/buffer"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/checksum"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/header"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/udpnat2"
|
||||
)
|
||||
|
||||
type UDPForwarder struct {
|
||||
ctx context.Context
|
||||
stack *stack.Stack
|
||||
handler Handler
|
||||
udpNat *udpnat.Service
|
||||
}
|
||||
|
||||
func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, timeout time.Duration) *UDPForwarder {
|
||||
forwarder := &UDPForwarder{
|
||||
ctx: ctx,
|
||||
stack: stack,
|
||||
handler: handler,
|
||||
}
|
||||
forwarder.udpNat = udpnat.New(handler, forwarder.PreparePacketConnection, timeout, true)
|
||||
return forwarder
|
||||
}
|
||||
|
||||
func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
|
||||
source := M.SocksaddrFrom(AddrFromAddress(id.RemoteAddress), id.RemotePort)
|
||||
destination := M.SocksaddrFrom(AddrFromAddress(id.LocalAddress), id.LocalPort)
|
||||
bufferRange := pkt.Data().AsRange()
|
||||
bufferSlices := make([][]byte, bufferRange.Size())
|
||||
rangeIterate(bufferRange, func(view *buffer.View) {
|
||||
bufferSlices = append(bufferSlices, view.AsSlice())
|
||||
})
|
||||
f.udpNat.NewPacket(bufferSlices, source, destination, pkt)
|
||||
return true
|
||||
}
|
||||
|
||||
//go:linkname rangeIterate github.com/sagernet/gvisor/pkg/tcpip/stack.Range.iterate
|
||||
func rangeIterate(r stack.Range, fn func(*buffer.View))
|
||||
|
||||
func (f *UDPForwarder) PreparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) {
|
||||
pErr := f.handler.PrepareConnection(N.NetworkUDP, source, destination)
|
||||
if pErr != nil {
|
||||
if pErr != ErrDrop {
|
||||
gWriteUnreachable(f.stack, userData.(*stack.PacketBuffer))
|
||||
}
|
||||
return false, nil, nil, nil
|
||||
}
|
||||
var sourceNetwork tcpip.NetworkProtocolNumber
|
||||
if source.Addr.Is4() {
|
||||
sourceNetwork = header.IPv4ProtocolNumber
|
||||
} else {
|
||||
sourceNetwork = header.IPv6ProtocolNumber
|
||||
}
|
||||
writer := &UDPBackWriter{
|
||||
stack: f.stack,
|
||||
packet: userData.(*stack.PacketBuffer).IncRef(),
|
||||
source: AddressFromAddr(source.Addr),
|
||||
sourcePort: source.Port,
|
||||
sourceNetwork: sourceNetwork,
|
||||
}
|
||||
return true, f.ctx, writer, nil
|
||||
}
|
||||
|
||||
type UDPBackWriter struct {
|
||||
access sync.Mutex
|
||||
stack *stack.Stack
|
||||
packet *stack.PacketBuffer
|
||||
source tcpip.Address
|
||||
sourcePort uint16
|
||||
sourceNetwork tcpip.NetworkProtocolNumber
|
||||
}
|
||||
|
||||
func (w *UDPBackWriter) HandshakeSuccess() error {
|
||||
w.access.Lock()
|
||||
defer w.access.Unlock()
|
||||
if w.packet != nil {
|
||||
w.packet.DecRef()
|
||||
w.packet = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *UDPBackWriter) HandshakeFailure(err error) error {
|
||||
w.access.Lock()
|
||||
defer w.access.Unlock()
|
||||
if w.packet == nil {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
wErr := gWriteUnreachable(w.stack, w.packet)
|
||||
w.packet.DecRef()
|
||||
w.packet = nil
|
||||
return wErr
|
||||
}
|
||||
|
||||
func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
if !destination.IsIP() {
|
||||
return E.Cause(os.ErrInvalid, "invalid destination")
|
||||
} else if destination.IsIPv4() && w.sourceNetwork == header.IPv6ProtocolNumber {
|
||||
destination = M.SocksaddrFrom(netip.AddrFrom16(destination.Addr.As16()), destination.Port)
|
||||
} else if destination.IsIPv6() && (w.sourceNetwork == header.IPv4ProtocolNumber) {
|
||||
return E.New("send IPv6 packet to IPv4 connection")
|
||||
}
|
||||
|
||||
defer packetBuffer.Release()
|
||||
|
||||
route, err := w.stack.FindRoute(
|
||||
DefaultNIC,
|
||||
AddressFromAddr(destination.Addr),
|
||||
w.source,
|
||||
w.sourceNetwork,
|
||||
false,
|
||||
)
|
||||
if err != nil {
|
||||
return gonet.TranslateNetstackError(err)
|
||||
}
|
||||
defer route.Release()
|
||||
|
||||
packet := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
ReserveHeaderBytes: header.UDPMinimumSize + int(route.MaxHeaderLength()),
|
||||
Payload: buffer.MakeWithData(packetBuffer.Bytes()),
|
||||
})
|
||||
defer packet.DecRef()
|
||||
|
||||
packet.TransportProtocolNumber = header.UDPProtocolNumber
|
||||
udpHdr := header.UDP(packet.TransportHeader().Push(header.UDPMinimumSize))
|
||||
pLen := uint16(packet.Size())
|
||||
udpHdr.Encode(&header.UDPFields{
|
||||
SrcPort: destination.Port,
|
||||
DstPort: w.sourcePort,
|
||||
Length: pLen,
|
||||
})
|
||||
|
||||
if route.RequiresTXTransportChecksum() && w.sourceNetwork == header.IPv6ProtocolNumber {
|
||||
xsum := udpHdr.CalculateChecksum(checksum.Combine(
|
||||
route.PseudoHeaderChecksum(header.UDPProtocolNumber, pLen),
|
||||
packet.Data().Checksum(),
|
||||
))
|
||||
if xsum != math.MaxUint16 {
|
||||
xsum = ^xsum
|
||||
}
|
||||
udpHdr.SetChecksum(xsum)
|
||||
}
|
||||
|
||||
err = route.WritePacket(stack.NetworkHeaderParams{
|
||||
Protocol: header.UDPProtocolNumber,
|
||||
TTL: route.DefaultTTL(),
|
||||
TOS: 0,
|
||||
}, packet)
|
||||
if err != nil {
|
||||
route.Stats().UDP.PacketSendErrors.Increment()
|
||||
return gonet.TranslateNetstackError(err)
|
||||
}
|
||||
|
||||
route.Stats().UDP.PacketsSent.Increment()
|
||||
return nil
|
||||
}
|
||||
|
||||
func gWriteUnreachable(gStack *stack.Stack, packet *stack.PacketBuffer) error {
|
||||
if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
|
||||
return gonet.TranslateNetstackError(gStack.NetworkProtocolInstance(header.IPv4ProtocolNumber).(stack.RejectIPv4WithHandler).SendRejectionError(packet, stack.RejectIPv4WithICMPPortUnreachable, true))
|
||||
} else {
|
||||
return gonet.TranslateNetstackError(gStack.NetworkProtocolInstance(header.IPv6ProtocolNumber).(stack.RejectIPv6WithHandler).SendRejectionError(packet, stack.RejectIPv6WithICMPPortUnreachable, true))
|
||||
}
|
||||
}
|
||||
236
stack_mixed.go
236
stack_mixed.go
@@ -1,236 +0,0 @@
|
||||
//go:build with_gvisor
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"github.com/sagernet/gvisor/pkg/buffer"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip"
|
||||
gHdr "github.com/sagernet/gvisor/pkg/tcpip/header"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/link/channel"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/stack"
|
||||
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip/header"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
type Mixed struct {
|
||||
*System
|
||||
stack *stack.Stack
|
||||
endpoint *channel.Endpoint
|
||||
}
|
||||
|
||||
func NewMixed(
|
||||
options StackOptions,
|
||||
) (Stack, error) {
|
||||
system, err := NewSystem(options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Mixed{
|
||||
System: system.(*System),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *Mixed) Start() error {
|
||||
err := m.System.start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
endpoint := channel.New(1024, uint32(m.mtu), "")
|
||||
ipStack, err := NewGVisorStack(endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(m.ctx, ipStack, m.handler, m.udpTimeout).HandlePacket)
|
||||
m.stack = ipStack
|
||||
m.endpoint = endpoint
|
||||
go m.tunLoop()
|
||||
go m.packetLoop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Mixed) Close() error {
|
||||
if m.stack == nil {
|
||||
return nil
|
||||
}
|
||||
m.endpoint.Attach(nil)
|
||||
m.stack.Close()
|
||||
for _, endpoint := range m.stack.CleanupEndpoints() {
|
||||
endpoint.Abort()
|
||||
}
|
||||
return m.System.Close()
|
||||
}
|
||||
|
||||
func (m *Mixed) tunLoop() {
|
||||
if winTun, isWinTun := m.tun.(WinTun); isWinTun {
|
||||
m.wintunLoop(winTun)
|
||||
return
|
||||
}
|
||||
if linuxTUN, isLinuxTUN := m.tun.(LinuxTUN); isLinuxTUN {
|
||||
m.frontHeadroom = linuxTUN.FrontHeadroom()
|
||||
m.txChecksumOffload = linuxTUN.TXChecksumOffload()
|
||||
batchSize := linuxTUN.BatchSize()
|
||||
if batchSize > 1 {
|
||||
m.batchLoop(linuxTUN, batchSize)
|
||||
return
|
||||
}
|
||||
}
|
||||
packetBuffer := make([]byte, m.mtu+PacketOffset)
|
||||
for {
|
||||
n, err := m.tun.Read(packetBuffer)
|
||||
if err != nil {
|
||||
if E.IsClosed(err) {
|
||||
return
|
||||
}
|
||||
m.logger.Error(E.Cause(err, "read packet"))
|
||||
}
|
||||
if n < header.IPv4MinimumSize {
|
||||
continue
|
||||
}
|
||||
rawPacket := packetBuffer[:n]
|
||||
packet := packetBuffer[PacketOffset:n]
|
||||
if m.processPacket(packet) {
|
||||
_, err = m.tun.Write(rawPacket)
|
||||
if err != nil {
|
||||
m.logger.Trace(E.Cause(err, "write packet"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Mixed) wintunLoop(winTun WinTun) {
|
||||
for {
|
||||
packet, release, err := winTun.ReadPacket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if len(packet) < header.IPv4MinimumSize {
|
||||
release()
|
||||
continue
|
||||
}
|
||||
if m.processPacket(packet) {
|
||||
_, err = winTun.Write(packet)
|
||||
if err != nil {
|
||||
m.logger.Trace(E.Cause(err, "write packet"))
|
||||
}
|
||||
}
|
||||
release()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Mixed) batchLoop(linuxTUN LinuxTUN, batchSize int) {
|
||||
packetBuffers := make([][]byte, batchSize)
|
||||
writeBuffers := make([][]byte, batchSize)
|
||||
packetSizes := make([]int, batchSize)
|
||||
for i := range packetBuffers {
|
||||
packetBuffers[i] = make([]byte, m.mtu+m.frontHeadroom)
|
||||
}
|
||||
for {
|
||||
n, err := linuxTUN.BatchRead(packetBuffers, m.frontHeadroom, packetSizes)
|
||||
if err != nil {
|
||||
if E.IsClosed(err) {
|
||||
return
|
||||
}
|
||||
m.logger.Error(E.Cause(err, "batch read packet"))
|
||||
}
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
packetSize := packetSizes[i]
|
||||
if packetSize < header.IPv4MinimumSize {
|
||||
continue
|
||||
}
|
||||
packetBuffer := packetBuffers[i]
|
||||
packet := packetBuffer[m.frontHeadroom : m.frontHeadroom+packetSize]
|
||||
if m.processPacket(packet) {
|
||||
writeBuffers = append(writeBuffers, packetBuffer[:m.frontHeadroom+packetSize])
|
||||
}
|
||||
}
|
||||
if len(writeBuffers) > 0 {
|
||||
_, err = linuxTUN.BatchWrite(writeBuffers, m.frontHeadroom)
|
||||
if err != nil {
|
||||
m.logger.Trace(E.Cause(err, "batch write packet"))
|
||||
}
|
||||
writeBuffers = writeBuffers[:0]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Mixed) processPacket(packet []byte) bool {
|
||||
var (
|
||||
writeBack bool
|
||||
err error
|
||||
)
|
||||
switch ipVersion := header.IPVersion(packet); ipVersion {
|
||||
case header.IPv4Version:
|
||||
writeBack, err = m.processIPv4(packet)
|
||||
case header.IPv6Version:
|
||||
writeBack, err = m.processIPv6(packet)
|
||||
default:
|
||||
err = E.New("ip: unknown version: ", ipVersion)
|
||||
}
|
||||
if err != nil {
|
||||
m.logger.Trace(err)
|
||||
return false
|
||||
}
|
||||
return writeBack
|
||||
}
|
||||
|
||||
func (m *Mixed) processIPv4(ipHdr header.IPv4) (writeBack bool, err error) {
|
||||
writeBack = true
|
||||
destination := ipHdr.DestinationAddr()
|
||||
if destination == m.broadcastAddr || !destination.IsGlobalUnicast() {
|
||||
return
|
||||
}
|
||||
switch ipHdr.TransportProtocol() {
|
||||
case header.TCPProtocolNumber:
|
||||
writeBack, err = m.processIPv4TCP(ipHdr, ipHdr.Payload())
|
||||
case header.UDPProtocolNumber:
|
||||
writeBack = false
|
||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(ipHdr),
|
||||
IsForwardedPacket: true,
|
||||
})
|
||||
m.endpoint.InjectInbound(gHdr.IPv4ProtocolNumber, pkt)
|
||||
pkt.DecRef()
|
||||
return
|
||||
case header.ICMPv4ProtocolNumber:
|
||||
err = m.processIPv4ICMP(ipHdr, ipHdr.Payload())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Mixed) processIPv6(ipHdr header.IPv6) (writeBack bool, err error) {
|
||||
writeBack = true
|
||||
if !ipHdr.DestinationAddr().IsGlobalUnicast() {
|
||||
return
|
||||
}
|
||||
switch ipHdr.TransportProtocol() {
|
||||
case header.TCPProtocolNumber:
|
||||
writeBack, err = m.processIPv6TCP(ipHdr, ipHdr.Payload())
|
||||
case header.UDPProtocolNumber:
|
||||
writeBack = false
|
||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(ipHdr),
|
||||
IsForwardedPacket: true,
|
||||
})
|
||||
m.endpoint.InjectInbound(tcpip.NetworkProtocolNumber(header.IPv6ProtocolNumber), pkt)
|
||||
pkt.DecRef()
|
||||
case header.ICMPv6ProtocolNumber:
|
||||
err = m.processIPv6ICMP(ipHdr, ipHdr.Payload())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m *Mixed) packetLoop() {
|
||||
for {
|
||||
packet := m.endpoint.ReadContext(m.ctx)
|
||||
if packet == nil {
|
||||
break
|
||||
}
|
||||
bufio.WriteVectorised(m.tun, packet.AsSlices())
|
||||
packet.DecRef()
|
||||
}
|
||||
}
|
||||
769
stack_system.go
769
stack_system.go
@@ -1,769 +0,0 @@
|
||||
package tun
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip/header"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
"github.com/sagernet/sing/common/control"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/udpnat2"
|
||||
)
|
||||
|
||||
var ErrIncludeAllNetworks = E.New("`system` and `mixed` stack are not available when `includeAllNetworks` is enabled. See https://github.com/SagerNet/sing-tun/issues/25")
|
||||
|
||||
type System struct {
|
||||
ctx context.Context
|
||||
tun Tun
|
||||
tunName string
|
||||
mtu int
|
||||
handler Handler
|
||||
logger logger.Logger
|
||||
inet4Prefixes []netip.Prefix
|
||||
inet6Prefixes []netip.Prefix
|
||||
inet4ServerAddress netip.Addr
|
||||
inet4Address netip.Addr
|
||||
inet6ServerAddress netip.Addr
|
||||
inet6Address netip.Addr
|
||||
broadcastAddr netip.Addr
|
||||
udpTimeout time.Duration
|
||||
tcpListener net.Listener
|
||||
tcpListener6 net.Listener
|
||||
tcpPort uint16
|
||||
tcpPort6 uint16
|
||||
tcpNat *TCPNat
|
||||
udpNat *udpnat.Service
|
||||
bindInterface bool
|
||||
interfaceFinder control.InterfaceFinder
|
||||
frontHeadroom int
|
||||
txChecksumOffload bool
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
SourceAddress netip.Addr
|
||||
DestinationAddress netip.Addr
|
||||
SourcePort uint16
|
||||
DestinationPort uint16
|
||||
}
|
||||
|
||||
func NewSystem(options StackOptions) (Stack, error) {
|
||||
stack := &System{
|
||||
ctx: options.Context,
|
||||
tun: options.Tun,
|
||||
tunName: options.TunOptions.Name,
|
||||
mtu: int(options.TunOptions.MTU),
|
||||
udpTimeout: options.UDPTimeout,
|
||||
handler: options.Handler,
|
||||
logger: options.Logger,
|
||||
inet4Prefixes: options.TunOptions.Inet4Address,
|
||||
inet6Prefixes: options.TunOptions.Inet6Address,
|
||||
broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address),
|
||||
bindInterface: options.ForwarderBindInterface,
|
||||
interfaceFinder: options.InterfaceFinder,
|
||||
}
|
||||
if len(options.TunOptions.Inet4Address) > 0 {
|
||||
if !HasNextAddress(options.TunOptions.Inet4Address[0], 1) {
|
||||
return nil, E.New("need one more IPv4 address in first prefix for system stack")
|
||||
}
|
||||
stack.inet4ServerAddress = options.TunOptions.Inet4Address[0].Addr()
|
||||
stack.inet4Address = stack.inet4ServerAddress.Next()
|
||||
}
|
||||
if len(options.TunOptions.Inet6Address) > 0 {
|
||||
if !HasNextAddress(options.TunOptions.Inet6Address[0], 1) {
|
||||
return nil, E.New("need one more IPv6 address in first prefix for system stack")
|
||||
}
|
||||
stack.inet6ServerAddress = options.TunOptions.Inet6Address[0].Addr()
|
||||
stack.inet6Address = stack.inet6ServerAddress.Next()
|
||||
}
|
||||
if !stack.inet4Address.IsValid() && !stack.inet6Address.IsValid() {
|
||||
return nil, E.New("missing interface address")
|
||||
}
|
||||
return stack, nil
|
||||
}
|
||||
|
||||
func (s *System) Close() error {
|
||||
return common.Close(
|
||||
s.tcpListener,
|
||||
s.tcpListener6,
|
||||
)
|
||||
}
|
||||
|
||||
func (s *System) Start() error {
|
||||
err := s.start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go s.tunLoop()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *System) start() error {
|
||||
_ = fixWindowsFirewall()
|
||||
var listener net.ListenConfig
|
||||
if s.bindInterface {
|
||||
listener.Control = control.Append(listener.Control, func(network, address string, conn syscall.RawConn) error {
|
||||
bindErr := control.BindToInterface0(s.interfaceFinder, conn, network, address, s.tunName, -1, true)
|
||||
if bindErr != nil {
|
||||
s.logger.Warn("bind forwarder to interface: ", bindErr)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
var tcpListener net.Listener
|
||||
var err error
|
||||
if s.inet4Address.IsValid() {
|
||||
for i := 0; i < 3; i++ {
|
||||
tcpListener, err = listener.Listen(s.ctx, "tcp4", net.JoinHostPort(s.inet4ServerAddress.String(), "0"))
|
||||
if !retryableListenError(err) {
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.tcpListener = tcpListener
|
||||
s.tcpPort = M.SocksaddrFromNet(tcpListener.Addr()).Port
|
||||
go s.acceptLoop(tcpListener)
|
||||
}
|
||||
if s.inet6Address.IsValid() {
|
||||
for i := 0; i < 3; i++ {
|
||||
tcpListener, err = listener.Listen(s.ctx, "tcp6", net.JoinHostPort(s.inet6ServerAddress.String(), "0"))
|
||||
if !retryableListenError(err) {
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.tcpListener6 = tcpListener
|
||||
s.tcpPort6 = M.SocksaddrFromNet(tcpListener.Addr()).Port
|
||||
go s.acceptLoop(tcpListener)
|
||||
}
|
||||
s.tcpNat = NewNat(s.ctx, s.udpTimeout)
|
||||
s.udpNat = udpnat.New(s.handler, s.preparePacketConnection, s.udpTimeout, false)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *System) tunLoop() {
|
||||
if winTun, isWinTun := s.tun.(WinTun); isWinTun {
|
||||
s.wintunLoop(winTun)
|
||||
return
|
||||
}
|
||||
if linuxTUN, isLinuxTUN := s.tun.(LinuxTUN); isLinuxTUN {
|
||||
s.frontHeadroom = linuxTUN.FrontHeadroom()
|
||||
s.txChecksumOffload = linuxTUN.TXChecksumOffload()
|
||||
batchSize := linuxTUN.BatchSize()
|
||||
if batchSize > 1 {
|
||||
s.batchLoop(linuxTUN, batchSize)
|
||||
return
|
||||
}
|
||||
}
|
||||
packetBuffer := make([]byte, s.mtu+PacketOffset)
|
||||
for {
|
||||
n, err := s.tun.Read(packetBuffer)
|
||||
if err != nil {
|
||||
if E.IsClosed(err) {
|
||||
return
|
||||
}
|
||||
s.logger.Error(E.Cause(err, "read packet"))
|
||||
}
|
||||
if n < header.IPv4MinimumSize {
|
||||
continue
|
||||
}
|
||||
rawPacket := packetBuffer[:n]
|
||||
packet := packetBuffer[PacketOffset:n]
|
||||
if s.processPacket(packet) {
|
||||
_, err = s.tun.Write(rawPacket)
|
||||
if err != nil {
|
||||
s.logger.Trace(E.Cause(err, "write packet"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *System) wintunLoop(winTun WinTun) {
|
||||
for {
|
||||
packet, release, err := winTun.ReadPacket()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if len(packet) < header.IPv4MinimumSize {
|
||||
release()
|
||||
continue
|
||||
}
|
||||
if s.processPacket(packet) {
|
||||
_, err = winTun.Write(packet)
|
||||
if err != nil {
|
||||
s.logger.Trace(E.Cause(err, "write packet"))
|
||||
}
|
||||
}
|
||||
release()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *System) batchLoop(linuxTUN LinuxTUN, batchSize int) {
|
||||
packetBuffers := make([][]byte, batchSize)
|
||||
writeBuffers := make([][]byte, batchSize)
|
||||
packetSizes := make([]int, batchSize)
|
||||
for i := range packetBuffers {
|
||||
packetBuffers[i] = make([]byte, s.mtu+s.frontHeadroom)
|
||||
}
|
||||
for {
|
||||
n, err := linuxTUN.BatchRead(packetBuffers, s.frontHeadroom, packetSizes)
|
||||
if err != nil {
|
||||
if E.IsClosed(err) {
|
||||
return
|
||||
}
|
||||
s.logger.Error(E.Cause(err, "batch read packet"))
|
||||
}
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
packetSize := packetSizes[i]
|
||||
if packetSize < header.IPv4MinimumSize {
|
||||
continue
|
||||
}
|
||||
packetBuffer := packetBuffers[i]
|
||||
packet := packetBuffer[s.frontHeadroom : s.frontHeadroom+packetSize]
|
||||
if s.processPacket(packet) {
|
||||
writeBuffers = append(writeBuffers, packetBuffer[:s.frontHeadroom+packetSize])
|
||||
}
|
||||
}
|
||||
if len(writeBuffers) > 0 {
|
||||
_, err = linuxTUN.BatchWrite(writeBuffers, s.frontHeadroom)
|
||||
if err != nil {
|
||||
s.logger.Trace(E.Cause(err, "batch write packet"))
|
||||
}
|
||||
writeBuffers = writeBuffers[:0]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *System) processPacket(packet []byte) bool {
|
||||
var (
|
||||
writeBack bool
|
||||
err error
|
||||
)
|
||||
switch ipVersion := header.IPVersion(packet); ipVersion {
|
||||
case header.IPv4Version:
|
||||
writeBack, err = s.processIPv4(packet)
|
||||
case header.IPv6Version:
|
||||
writeBack, err = s.processIPv6(packet)
|
||||
default:
|
||||
err = E.New("ip: unknown version: ", ipVersion)
|
||||
}
|
||||
if err != nil {
|
||||
s.logger.Trace(err)
|
||||
return false
|
||||
}
|
||||
return writeBack
|
||||
}
|
||||
|
||||
func (s *System) acceptLoop(listener net.Listener) {
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
connPort := M.SocksaddrFromNet(conn.RemoteAddr()).Port
|
||||
session := s.tcpNat.LookupBack(connPort)
|
||||
if session == nil {
|
||||
s.logger.Trace(E.New("unknown session with port ", connPort))
|
||||
continue
|
||||
}
|
||||
destination := M.SocksaddrFromNetIP(session.Destination)
|
||||
if destination.Addr.Is4() {
|
||||
for _, prefix := range s.inet4Prefixes {
|
||||
if prefix.Contains(destination.Addr) {
|
||||
destination.Addr = netip.AddrFrom4([4]byte{127, 0, 0, 1})
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, prefix := range s.inet6Prefixes {
|
||||
if prefix.Contains(destination.Addr) {
|
||||
destination.Addr = netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1})
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
go s.handler.NewConnectionEx(s.ctx, conn, M.SocksaddrFromNetIP(session.Source), destination, nil)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *System) processIPv4(ipHdr header.IPv4) (writeBack bool, err error) {
|
||||
destination := ipHdr.DestinationAddr()
|
||||
if destination == s.broadcastAddr || !destination.IsGlobalUnicast() {
|
||||
return
|
||||
}
|
||||
writeBack = true
|
||||
switch ipHdr.TransportProtocol() {
|
||||
case header.TCPProtocolNumber:
|
||||
writeBack, err = s.processIPv4TCP(ipHdr, ipHdr.Payload())
|
||||
case header.UDPProtocolNumber:
|
||||
writeBack = false
|
||||
err = s.processIPv4UDP(ipHdr, ipHdr.Payload())
|
||||
case header.ICMPv4ProtocolNumber:
|
||||
err = s.processIPv4ICMP(ipHdr, ipHdr.Payload())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *System) processIPv6(ipHdr header.IPv6) (writeBack bool, err error) {
|
||||
if !ipHdr.DestinationAddr().IsGlobalUnicast() {
|
||||
return
|
||||
}
|
||||
writeBack = true
|
||||
switch ipHdr.TransportProtocol() {
|
||||
case header.TCPProtocolNumber:
|
||||
writeBack, err = s.processIPv6TCP(ipHdr, ipHdr.Payload())
|
||||
case header.UDPProtocolNumber:
|
||||
err = s.processIPv6UDP(ipHdr, ipHdr.Payload())
|
||||
case header.ICMPv6ProtocolNumber:
|
||||
err = s.processIPv6ICMP(ipHdr, ipHdr.Payload())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *System) processIPv4TCP(ipHdr header.IPv4, tcpHdr header.TCP) (bool, error) {
|
||||
source := netip.AddrPortFrom(ipHdr.SourceAddr(), tcpHdr.SourcePort())
|
||||
destination := netip.AddrPortFrom(ipHdr.DestinationAddr(), tcpHdr.DestinationPort())
|
||||
if !destination.Addr().IsGlobalUnicast() {
|
||||
return false, nil
|
||||
} else if source.Addr() == s.inet4ServerAddress && source.Port() == s.tcpPort {
|
||||
session := s.tcpNat.LookupBack(destination.Port())
|
||||
if session == nil {
|
||||
return false, E.New("ipv4: tcp: session not found: ", destination.Port())
|
||||
}
|
||||
ipHdr.SetSourceAddr(session.Destination.Addr())
|
||||
tcpHdr.SetSourcePort(session.Destination.Port())
|
||||
ipHdr.SetDestinationAddr(session.Source.Addr())
|
||||
tcpHdr.SetDestinationPort(session.Source.Port())
|
||||
} else {
|
||||
natPort, err := s.tcpNat.Lookup(source, destination, s.handler)
|
||||
if err != nil {
|
||||
if err == ErrDrop {
|
||||
return false, nil
|
||||
} else {
|
||||
return false, s.resetIPv4TCP(ipHdr, tcpHdr)
|
||||
}
|
||||
}
|
||||
ipHdr.SetSourceAddr(s.inet4Address)
|
||||
tcpHdr.SetSourcePort(natPort)
|
||||
ipHdr.SetDestinationAddr(s.inet4ServerAddress)
|
||||
tcpHdr.SetDestinationPort(s.tcpPort)
|
||||
}
|
||||
if !s.txChecksumOffload {
|
||||
tcpHdr.SetChecksum(0)
|
||||
tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum(
|
||||
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
|
||||
)))
|
||||
} else {
|
||||
tcpHdr.SetChecksum(0)
|
||||
}
|
||||
ipHdr.SetChecksum(0)
|
||||
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *System) resetIPv4TCP(origIPHdr header.IPv4, origTCPHdr header.TCP) error {
|
||||
frontHeadroom := s.frontHeadroom + PacketOffset
|
||||
newPacket := buf.NewSize(frontHeadroom + header.IPv4MinimumSize + header.TCPMinimumSize)
|
||||
defer newPacket.Release()
|
||||
newPacket.Resize(frontHeadroom, header.IPv4MinimumSize+header.TCPMinimumSize)
|
||||
ipHdr := header.IPv4(newPacket.Bytes())
|
||||
ipHdr.Encode(&header.IPv4Fields{
|
||||
TotalLength: uint16(newPacket.Len()),
|
||||
Protocol: uint8(header.TCPProtocolNumber),
|
||||
SrcAddr: origIPHdr.DestinationAddr(),
|
||||
DstAddr: origIPHdr.SourceAddr(),
|
||||
})
|
||||
tcpHdr := header.TCP(ipHdr.Payload())
|
||||
fields := header.TCPFields{
|
||||
SrcPort: origTCPHdr.DestinationPort(),
|
||||
DstPort: origTCPHdr.SourcePort(),
|
||||
DataOffset: header.TCPMinimumSize,
|
||||
Flags: header.TCPFlagRst,
|
||||
}
|
||||
if origTCPHdr.Flags()&header.TCPFlagAck != 0 {
|
||||
fields.SeqNum = origTCPHdr.AckNumber()
|
||||
} else {
|
||||
fields.Flags |= header.TCPFlagAck
|
||||
ackNum := origTCPHdr.SequenceNumber() + uint32(len(origTCPHdr.Payload()))
|
||||
if origTCPHdr.Flags()&header.TCPFlagSyn != 0 {
|
||||
ackNum++
|
||||
}
|
||||
if origTCPHdr.Flags()&header.TCPFlagFin != 0 {
|
||||
ackNum++
|
||||
}
|
||||
fields.AckNum = ackNum
|
||||
}
|
||||
tcpHdr.Encode(&fields)
|
||||
if !s.txChecksumOffload {
|
||||
tcpHdr.SetChecksum(^tcpHdr.CalculateChecksum(header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), header.TCPMinimumSize)))
|
||||
}
|
||||
ipHdr.SetChecksum(0)
|
||||
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||
if PacketOffset > 0 {
|
||||
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version)
|
||||
} else {
|
||||
newPacket.Advance(-s.frontHeadroom)
|
||||
}
|
||||
return common.Error(s.tun.Write(newPacket.Bytes()))
|
||||
}
|
||||
|
||||
func (s *System) processIPv6TCP(ipHdr header.IPv6, tcpHdr header.TCP) (bool, error) {
|
||||
source := netip.AddrPortFrom(ipHdr.SourceAddr(), tcpHdr.SourcePort())
|
||||
destination := netip.AddrPortFrom(ipHdr.DestinationAddr(), tcpHdr.DestinationPort())
|
||||
if !destination.Addr().IsGlobalUnicast() {
|
||||
return false, nil
|
||||
} else if source.Addr() == s.inet6ServerAddress && source.Port() == s.tcpPort6 {
|
||||
session := s.tcpNat.LookupBack(destination.Port())
|
||||
if session == nil {
|
||||
return false, E.New("ipv6: tcp: session not found: ", destination.Port())
|
||||
}
|
||||
ipHdr.SetSourceAddr(session.Destination.Addr())
|
||||
tcpHdr.SetSourcePort(session.Destination.Port())
|
||||
ipHdr.SetDestinationAddr(session.Source.Addr())
|
||||
tcpHdr.SetDestinationPort(session.Source.Port())
|
||||
} else {
|
||||
natPort, err := s.tcpNat.Lookup(source, destination, s.handler)
|
||||
if err != nil {
|
||||
if err == ErrDrop {
|
||||
return false, nil
|
||||
} else {
|
||||
return false, s.resetIPv6TCP(ipHdr, tcpHdr)
|
||||
}
|
||||
}
|
||||
ipHdr.SetSourceAddr(s.inet6Address)
|
||||
tcpHdr.SetSourcePort(natPort)
|
||||
ipHdr.SetDestinationAddr(s.inet6ServerAddress)
|
||||
tcpHdr.SetDestinationPort(s.tcpPort6)
|
||||
}
|
||||
if !s.txChecksumOffload {
|
||||
tcpHdr.SetChecksum(0)
|
||||
tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum(
|
||||
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
|
||||
)))
|
||||
} else {
|
||||
tcpHdr.SetChecksum(0)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *System) resetIPv6TCP(origIPHdr header.IPv6, origTCPHdr header.TCP) error {
|
||||
frontHeadroom := s.frontHeadroom + PacketOffset
|
||||
newPacket := buf.NewSize(frontHeadroom + header.IPv6MinimumSize + header.TCPMinimumSize)
|
||||
defer newPacket.Release()
|
||||
newPacket.Resize(frontHeadroom, header.IPv6MinimumSize+header.TCPMinimumSize)
|
||||
ipHdr := header.IPv6(newPacket.Bytes())
|
||||
ipHdr.Encode(&header.IPv6Fields{
|
||||
PayloadLength: uint16(header.TCPMinimumSize),
|
||||
TransportProtocol: header.TCPProtocolNumber,
|
||||
SrcAddr: origIPHdr.DestinationAddr(),
|
||||
DstAddr: origIPHdr.SourceAddr(),
|
||||
})
|
||||
tcpHdr := header.TCP(ipHdr.Payload())
|
||||
fields := header.TCPFields{
|
||||
SrcPort: origTCPHdr.DestinationPort(),
|
||||
DstPort: origTCPHdr.SourcePort(),
|
||||
DataOffset: header.TCPMinimumSize,
|
||||
Flags: header.TCPFlagRst,
|
||||
}
|
||||
if origTCPHdr.Flags()&header.TCPFlagAck != 0 {
|
||||
fields.SeqNum = origTCPHdr.AckNumber()
|
||||
} else {
|
||||
fields.Flags |= header.TCPFlagAck
|
||||
ackNum := origTCPHdr.SequenceNumber() + uint32(len(origTCPHdr.Payload()))
|
||||
if origTCPHdr.Flags()&header.TCPFlagSyn != 0 {
|
||||
ackNum++
|
||||
}
|
||||
if origTCPHdr.Flags()&header.TCPFlagFin != 0 {
|
||||
ackNum++
|
||||
}
|
||||
fields.AckNum = ackNum
|
||||
}
|
||||
tcpHdr.Encode(&fields)
|
||||
if !s.txChecksumOffload {
|
||||
tcpHdr.SetChecksum(^tcpHdr.CalculateChecksum(header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), header.TCPMinimumSize)))
|
||||
}
|
||||
if PacketOffset > 0 {
|
||||
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv6Version)
|
||||
} else {
|
||||
newPacket.Advance(-s.frontHeadroom)
|
||||
}
|
||||
return common.Error(s.tun.Write(newPacket.Bytes()))
|
||||
}
|
||||
|
||||
func (s *System) processIPv4UDP(ipHdr header.IPv4, udpHdr header.UDP) error {
|
||||
if ipHdr.Flags()&header.IPv4FlagMoreFragments != 0 {
|
||||
return E.New("ipv4: fragment dropped")
|
||||
}
|
||||
if ipHdr.FragmentOffset() != 0 {
|
||||
return E.New("ipv4: udp: fragment dropped")
|
||||
}
|
||||
source := M.SocksaddrFrom(ipHdr.SourceAddr(), udpHdr.SourcePort())
|
||||
destination := M.SocksaddrFrom(ipHdr.DestinationAddr(), udpHdr.DestinationPort())
|
||||
if !destination.Addr.IsGlobalUnicast() {
|
||||
return nil
|
||||
}
|
||||
s.udpNat.NewPacket([][]byte{udpHdr.Payload()}, source, destination, ipHdr)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *System) processIPv6UDP(ipHdr header.IPv6, udpHdr header.UDP) error {
|
||||
source := M.SocksaddrFrom(ipHdr.SourceAddr(), udpHdr.SourcePort())
|
||||
destination := M.SocksaddrFrom(ipHdr.DestinationAddr(), udpHdr.DestinationPort())
|
||||
if !destination.Addr.IsGlobalUnicast() {
|
||||
return nil
|
||||
}
|
||||
s.udpNat.NewPacket([][]byte{udpHdr.Payload()}, source, destination, ipHdr)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *System) preparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) {
|
||||
pErr := s.handler.PrepareConnection(N.NetworkUDP, source, destination)
|
||||
if pErr != nil {
|
||||
if pErr != ErrDrop {
|
||||
if source.IsIPv4() {
|
||||
ipHdr := userData.(header.IPv4)
|
||||
s.rejectIPv4WithICMP(ipHdr, header.ICMPv4PortUnreachable)
|
||||
} else {
|
||||
ipHdr := userData.(header.IPv6)
|
||||
s.rejectIPv6WithICMP(ipHdr, header.ICMPv6PortUnreachable)
|
||||
}
|
||||
}
|
||||
return false, nil, nil, nil
|
||||
}
|
||||
var writer N.PacketWriter
|
||||
if source.IsIPv4() {
|
||||
packet := userData.(header.IPv4)
|
||||
headerLen := packet.HeaderLength() + header.UDPMinimumSize
|
||||
headerCopy := make([]byte, headerLen)
|
||||
copy(headerCopy, packet[:headerLen])
|
||||
writer = &systemUDPPacketWriter4{
|
||||
s.tun,
|
||||
s.frontHeadroom + PacketOffset,
|
||||
headerCopy,
|
||||
source.AddrPort(),
|
||||
s.txChecksumOffload,
|
||||
}
|
||||
} else {
|
||||
packet := userData.(header.IPv6)
|
||||
headerLen := len(packet) - int(packet.PayloadLength()) + header.UDPMinimumSize
|
||||
headerCopy := make([]byte, headerLen)
|
||||
copy(headerCopy, packet[:headerLen])
|
||||
writer = &systemUDPPacketWriter6{
|
||||
s.tun,
|
||||
s.frontHeadroom + PacketOffset,
|
||||
headerCopy,
|
||||
source.AddrPort(),
|
||||
s.txChecksumOffload,
|
||||
}
|
||||
}
|
||||
return true, s.ctx, writer, nil
|
||||
}
|
||||
|
||||
func (s *System) processIPv4ICMP(ipHdr header.IPv4, icmpHdr header.ICMPv4) error {
|
||||
if icmpHdr.Type() != header.ICMPv4Echo || icmpHdr.Code() != 0 {
|
||||
return nil
|
||||
}
|
||||
icmpHdr.SetType(header.ICMPv4EchoReply)
|
||||
sourceAddress := ipHdr.SourceAddr()
|
||||
ipHdr.SetSourceAddr(ipHdr.DestinationAddr())
|
||||
ipHdr.SetDestinationAddr(sourceAddress)
|
||||
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
|
||||
ipHdr.SetChecksum(0)
|
||||
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *System) rejectIPv4WithICMP(ipHdr header.IPv4, code header.ICMPv4Code) error {
|
||||
frontHeadroom := s.frontHeadroom + PacketOffset
|
||||
mtu := s.mtu
|
||||
const maxIPData = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize
|
||||
if mtu > maxIPData {
|
||||
mtu = maxIPData
|
||||
}
|
||||
available := mtu - header.ICMPv4MinimumSize
|
||||
if available < len(ipHdr)+header.ICMPv4MinimumErrorPayloadSize {
|
||||
return nil
|
||||
}
|
||||
payload := ipHdr
|
||||
if len(payload) > available {
|
||||
payload = payload[:available]
|
||||
}
|
||||
newPacket := buf.NewSize(frontHeadroom + header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(payload))
|
||||
defer newPacket.Release()
|
||||
newPacket.Resize(frontHeadroom, header.IPv4MinimumSize+header.ICMPv4MinimumSize+len(payload))
|
||||
newIPHdr := header.IPv4(newPacket.Bytes())
|
||||
newIPHdr.Encode(&header.IPv4Fields{
|
||||
TotalLength: uint16(newPacket.Len()),
|
||||
Protocol: uint8(header.ICMPv4ProtocolNumber),
|
||||
SrcAddr: ipHdr.DestinationAddr(),
|
||||
DstAddr: ipHdr.SourceAddr(),
|
||||
})
|
||||
newIPHdr.SetChecksum(^newIPHdr.CalculateChecksum())
|
||||
icmpHdr := header.ICMPv4(newIPHdr.Payload())
|
||||
icmpHdr.SetType(header.ICMPv4DstUnreachable)
|
||||
icmpHdr.SetCode(code)
|
||||
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(payload, 0)))
|
||||
copy(icmpHdr.Payload(), payload)
|
||||
if PacketOffset > 0 {
|
||||
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET
|
||||
} else {
|
||||
newPacket.Advance(-s.frontHeadroom)
|
||||
}
|
||||
return common.Error(s.tun.Write(newPacket.Bytes()))
|
||||
}
|
||||
|
||||
func (s *System) processIPv6ICMP(ipHdr header.IPv6, icmpHdr header.ICMPv6) error {
|
||||
if icmpHdr.Type() != header.ICMPv6EchoRequest || icmpHdr.Code() != 0 {
|
||||
return nil
|
||||
}
|
||||
icmpHdr.SetType(header.ICMPv6EchoReply)
|
||||
sourceAddress := ipHdr.SourceAddr()
|
||||
ipHdr.SetSourceAddr(ipHdr.DestinationAddr())
|
||||
ipHdr.SetDestinationAddr(sourceAddress)
|
||||
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
||||
Header: icmpHdr,
|
||||
Src: ipHdr.SourceAddress(),
|
||||
Dst: ipHdr.DestinationAddress(),
|
||||
}))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *System) rejectIPv6WithICMP(ipHdr header.IPv6, code header.ICMPv6Code) error {
|
||||
frontHeadroom := s.frontHeadroom + PacketOffset
|
||||
mtu := s.mtu
|
||||
const maxIPv6Data = header.IPv6MinimumMTU - header.IPv6FixedHeaderSize
|
||||
if mtu > maxIPv6Data {
|
||||
mtu = maxIPv6Data
|
||||
}
|
||||
available := mtu - header.ICMPv6ErrorHeaderSize
|
||||
if available < header.IPv6MinimumSize {
|
||||
return nil
|
||||
}
|
||||
payload := ipHdr
|
||||
if len(payload) > available {
|
||||
payload = payload[:available]
|
||||
}
|
||||
newPacket := buf.NewSize(frontHeadroom + header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + len(payload))
|
||||
defer newPacket.Release()
|
||||
newPacket.Resize(frontHeadroom, header.IPv6MinimumSize+header.ICMPv6DstUnreachableMinimumSize+len(payload))
|
||||
newIPHdr := header.IPv6(newPacket.Bytes())
|
||||
newIPHdr.Encode(&header.IPv6Fields{
|
||||
PayloadLength: uint16(header.ICMPv6DstUnreachableMinimumSize + len(payload)),
|
||||
TransportProtocol: header.ICMPv6ProtocolNumber,
|
||||
SrcAddr: ipHdr.DestinationAddr(),
|
||||
DstAddr: ipHdr.SourceAddr(),
|
||||
})
|
||||
icmpHdr := header.ICMPv6(newIPHdr.Payload())
|
||||
icmpHdr.SetType(header.ICMPv6DstUnreachable)
|
||||
icmpHdr.SetCode(code)
|
||||
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
|
||||
Header: icmpHdr[:header.ICMPv6DstUnreachableMinimumSize],
|
||||
Src: newIPHdr.SourceAddress(),
|
||||
Dst: newIPHdr.DestinationAddress(),
|
||||
PayloadCsum: checksum.Checksum(payload, 0),
|
||||
PayloadLen: len(payload),
|
||||
}))
|
||||
copy(icmpHdr.Payload(), payload)
|
||||
if PacketOffset > 0 {
|
||||
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv6Version)
|
||||
} else {
|
||||
newPacket.Advance(-s.frontHeadroom)
|
||||
}
|
||||
return common.Error(s.tun.Write(newPacket.Bytes()))
|
||||
}
|
||||
|
||||
type systemUDPPacketWriter4 struct {
|
||||
tun Tun
|
||||
frontHeadroom int
|
||||
header []byte
|
||||
source netip.AddrPort
|
||||
txChecksumOffload bool
|
||||
}
|
||||
|
||||
func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
newPacket := buf.NewSize(w.frontHeadroom + len(w.header) + buffer.Len())
|
||||
defer newPacket.Release()
|
||||
newPacket.Resize(w.frontHeadroom, 0)
|
||||
newPacket.Write(w.header)
|
||||
newPacket.Write(buffer.Bytes())
|
||||
ipHdr := header.IPv4(newPacket.Bytes())
|
||||
ipHdr.SetTotalLength(uint16(newPacket.Len()))
|
||||
ipHdr.SetDestinationAddress(ipHdr.SourceAddress())
|
||||
ipHdr.SetSourceAddr(destination.Addr)
|
||||
udpHdr := header.UDP(ipHdr.Payload())
|
||||
udpHdr.SetDestinationPort(udpHdr.SourcePort())
|
||||
udpHdr.SetSourcePort(destination.Port)
|
||||
udpHdr.SetLength(uint16(buffer.Len() + header.UDPMinimumSize))
|
||||
if !w.txChecksumOffload {
|
||||
udpHdr.SetChecksum(0)
|
||||
udpHdr.SetChecksum(^checksum.Checksum(udpHdr.Payload(), udpHdr.CalculateChecksum(
|
||||
header.PseudoHeaderChecksum(header.UDPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
|
||||
)))
|
||||
} else {
|
||||
udpHdr.SetChecksum(0)
|
||||
}
|
||||
ipHdr.SetChecksum(0)
|
||||
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
|
||||
if PacketOffset > 0 {
|
||||
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version)
|
||||
} else {
|
||||
newPacket.Advance(-w.frontHeadroom)
|
||||
}
|
||||
return common.Error(w.tun.Write(newPacket.Bytes()))
|
||||
}
|
||||
|
||||
type systemUDPPacketWriter6 struct {
|
||||
tun Tun
|
||||
frontHeadroom int
|
||||
header []byte
|
||||
source netip.AddrPort
|
||||
txChecksumOffload bool
|
||||
}
|
||||
|
||||
func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||
newPacket := buf.NewSize(w.frontHeadroom + len(w.header) + buffer.Len())
|
||||
defer newPacket.Release()
|
||||
newPacket.Resize(w.frontHeadroom, 0)
|
||||
newPacket.Write(w.header)
|
||||
newPacket.Write(buffer.Bytes())
|
||||
ipHdr := header.IPv6(newPacket.Bytes())
|
||||
udpLen := uint16(header.UDPMinimumSize + buffer.Len())
|
||||
ipHdr.SetPayloadLength(udpLen)
|
||||
ipHdr.SetDestinationAddress(ipHdr.SourceAddress())
|
||||
ipHdr.SetSourceAddr(destination.Addr)
|
||||
udpHdr := header.UDP(ipHdr.Payload())
|
||||
udpHdr.SetDestinationPort(udpHdr.SourcePort())
|
||||
udpHdr.SetSourcePort(destination.Port)
|
||||
udpHdr.SetLength(udpLen)
|
||||
if !w.txChecksumOffload {
|
||||
udpHdr.SetChecksum(0)
|
||||
udpHdr.SetChecksum(^checksum.Checksum(udpHdr.Payload(), udpHdr.CalculateChecksum(
|
||||
header.PseudoHeaderChecksum(header.UDPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
|
||||
)))
|
||||
} else {
|
||||
udpHdr.SetChecksum(0)
|
||||
}
|
||||
if PacketOffset > 0 {
|
||||
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv6Version)
|
||||
} else {
|
||||
newPacket.Advance(-w.frontHeadroom)
|
||||
}
|
||||
return common.Error(w.tun.Write(newPacket.Bytes()))
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package tun
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func fixWindowsFirewall() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func retryableListenError(err error) bool {
|
||||
return errors.Is(err, unix.EADDRNOTAVAIL)
|
||||
}
|
||||
@@ -1,34 +0,0 @@
|
||||
package tun
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing-tun/internal/gtcpip/header"
|
||||
)
|
||||
|
||||
func PacketIPVersion(packet []byte) int {
|
||||
return header.IPVersion(packet)
|
||||
}
|
||||
|
||||
func PacketFillHeader(packet []byte, ipVersion int) {
|
||||
if PacketOffset > 0 {
|
||||
switch ipVersion {
|
||||
case header.IPv4Version:
|
||||
packet[3] = syscall.AF_INET
|
||||
case header.IPv6Version:
|
||||
packet[3] = syscall.AF_INET6
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func PacketDestination(packet []byte) netip.Addr {
|
||||
switch ipVersion := header.IPVersion(packet); ipVersion {
|
||||
case header.IPv4Version:
|
||||
return header.IPv4(packet).DestinationAddr()
|
||||
case header.IPv6Version:
|
||||
return header.IPv6(packet).DestinationAddr()
|
||||
default:
|
||||
return netip.Addr{}
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user