Compare commits

..

1 Commits

Author SHA1 Message Date
世界
ead31a5b8d Update gVisor to 20230621.0 2023-06-26 19:32:18 +08:00
117 changed files with 3069 additions and 15674 deletions

19
.github/renovate.json vendored
View File

@@ -1,19 +0,0 @@
{
"$schema": "https://docs.renovatebot.com/renovate-schema.json",
"commitMessagePrefix": "[dependencies]",
"extends": [
"config:base",
":disableRateLimiting"
],
"golang": {
"enabled": false
},
"packageRules": [
{
"matchManagers": [
"github-actions"
],
"groupName": "github-actions"
}
]
}

44
.github/workflows/debug.yml vendored Normal file
View File

@@ -0,0 +1,44 @@
name: Debug build
on:
push:
branches:
- main
- dev
paths-ignore:
- '**.md'
- '.github/**'
- '!.github/workflows/debug.yml'
pull_request:
branches:
- dev
jobs:
build:
name: Debug build
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Get latest go version
id: version
run: |
echo ::set-output name=go_version::$(curl -s https://raw.githubusercontent.com/actions/go-versions/main/versions-manifest.json | grep -oE '"version": "[0-9]{1}.[0-9]{1,}(.[0-9]{1,})?"' | head -1 | cut -d':' -f2 | sed 's/ //g; s/"//g')
- name: Setup Go
uses: actions/setup-go@v2
with:
go-version: ${{ steps.version.outputs.go_version }}
- name: Add cache to Go proxy
run: |
version=`git rev-parse HEAD`
mkdir build
pushd build
go mod init build
go get -v github.com/sagernet/sing-tun@$version
popd
continue-on-error: true
- name: Build
run: |
go build -v .

View File

@@ -1,40 +0,0 @@
name: lint
on:
push:
branches:
- main
- dev
paths-ignore:
- '**.md'
- '.github/**'
- '!.github/workflows/lint.yml'
pull_request:
branches:
- main
- dev
jobs:
build:
name: Build
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ^1.23
- name: Cache go module
uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
key: go-${{ hashFiles('**/go.sum') }}
- name: golangci-lint
uses: golangci/golangci-lint-action@v6
with:
version: latest
args: .

View File

@@ -1,112 +0,0 @@
name: test
on:
push:
branches:
- main
- dev
paths-ignore:
- '**.md'
- '.github/**'
- '!.github/workflows/debug.yml'
pull_request:
branches:
- main
- dev
jobs:
build:
name: Linux
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ^1.23
- name: Build
run: |
make test
build_go120:
name: Linux (Go 1.20)
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ~1.20
continue-on-error: true
- name: Build
run: |
make test
build_go121:
name: Linux (Go 1.21)
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ~1.21
continue-on-error: true
- name: Build
run: |
make test
build_go122:
name: Linux (Go 1.22)
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ~1.22
continue-on-error: true
- name: Build
run: |
make test
build_windows:
name: Windows
runs-on: windows-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ^1.23
continue-on-error: true
- name: Build
run: |
make test
build_darwin:
name: macOS
runs-on: macos-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: ^1.23
continue-on-error: true
- name: Build
run: |
make test

1
.gitignore vendored
View File

@@ -1,3 +1,2 @@
/.idea/
/vendor/
.DS_Store

View File

@@ -3,18 +3,14 @@ linters:
enable:
- gofumpt
- govet
- gci
# - gci
- staticcheck
- paralleltest
- ineffassign
linters-settings:
gci:
custom-order: true
sections:
- standard
- prefix(github.com/sagernet/)
- default
run:
go: "1.23"
# gci:
# sections:
# - standard
# - prefix(github.com/sagernet/sing)
# - default
staticcheck:
go: '1.19'

View File

@@ -5,7 +5,6 @@ build:
GOOS=linux GOARCH=arm64 go build -v -tags with_gvisor .
GOOS=linux GOARCH=386 go build -v -tags with_gvisor .
GOOS=linux GOARCH=arm go build -v -tags with_gvisor .
GOOS=android GOARCH=arm64 go build -v -tags with_gvisor .
GOOS=windows GOARCH=amd64 go build -v -tags with_gvisor .
fmt:
@@ -28,6 +27,4 @@ lint_install:
go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
test:
go build -v .
go test -bench=. ./internal/checksum_test
#go test -v .
go test -v .

33
go.mod
View File

@@ -1,28 +1,21 @@
module github.com/sagernet/sing-tun
go 1.20
go 1.18
require (
github.com/go-ole/go-ole v1.3.0
github.com/google/btree v1.1.3
github.com/sagernet/fswatch v0.1.1
github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a
github.com/sagernet/nftables v0.3.0-beta.4
github.com/sagernet/sing v0.6.0-beta.2
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8
golang.org/x/net v0.26.0
golang.org/x/sys v0.26.0
github.com/fsnotify/fsnotify v1.6.0
github.com/go-ole/go-ole v1.2.6
github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61
github.com/sagernet/gvisor v0.0.0-20230626113001-7d91dec8c61c
github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97
github.com/sagernet/sing v0.2.7
github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9
golang.org/x/net v0.11.0
golang.org/x/sys v0.9.0
)
require (
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
github.com/mdlayher/socket v0.4.1 // indirect
github.com/vishvananda/netns v0.0.4 // indirect
golang.org/x/sync v0.7.0 // indirect
golang.org/x/time v0.7.0 // indirect
github.com/google/btree v1.1.2 // indirect
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 // indirect
golang.org/x/time v0.3.0 // indirect
)

72
go.sum
View File

@@ -1,43 +1,29 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78=
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA=
github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w=
github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g=
github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw=
github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U=
github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/sagernet/fswatch v0.1.1 h1:YqID+93B7VRfqIH3PArW/XpJv5H4OLEVWDfProGoRQs=
github.com/sagernet/fswatch v0.1.1/go.mod h1:nz85laH0mkQqJfaOrqPpkwtU1znMFNVTpT/5oRsVz/o=
github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff h1:mlohw3360Wg1BNGook/UHnISXhUx4Gd/3tVLs5T0nSs=
github.com/sagernet/gvisor v0.0.0-20241123041152-536d05261cff/go.mod h1:ehZwnT2UpmOWAHFL48XdBhnd4Qu4hN2O3Ji0us3ZHMw=
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZNjr6sGeT00J8uU7JF4cNUdb44/Duis=
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM=
github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I=
github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8=
github.com/sagernet/sing v0.6.0-beta.2 h1:Dcutp3kxrsZes9q3oTiHQhYYjQvDn5rwp1OI9fDLYwQ=
github.com/sagernet/sing v0.6.0-beta.2/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M=
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y=
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 h1:yixxcjnhBmY0nkL253HFVIm0JsFHwrHdT3Yh6szTnfY=
golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8/go.mod h1:jj3sYF3dwk5D+ghuXyeI3r5MFf+NT2An6/9dOA95KSI=
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=
golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M=
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ=
golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61 h1:5+m7c6AkmAylhauulqN/c5dnh8/KssrE9c93TQrXldA=
github.com/sagernet/go-tun2socks v1.16.12-0.20220818015926-16cb67876a61/go.mod h1:QUQ4RRHD6hGGHdFMEtR8T2P6GS6R3D/CXKdaYHKKXms=
github.com/sagernet/gvisor v0.0.0-20230626113001-7d91dec8c61c h1:YzXn4fVpRR9SL/l8S/30iYwNZwlBZY7FZvg8eHRNMdc=
github.com/sagernet/gvisor v0.0.0-20230626113001-7d91dec8c61c/go.mod h1:1JUiV7nGuf++YFm9eWZ8q2lrwHmhcUGzptMl/vL1+LA=
github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97 h1:iL5gZI3uFp0X6EslacyapiRz7LLSJyr4RajF/BhMVyE=
github.com/sagernet/netlink v0.0.0-20220905062125-8043b4a9aa97/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM=
github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY=
github.com/sagernet/sing v0.2.7 h1:cOy0FfPS8q7m0aJ51wS7LRQAGc9wF+fWhHtBDj99wy8=
github.com/sagernet/sing v0.2.7/go.mod h1:Ta8nHnDLAwqySzKhGoKk4ZIB+vJ3GTKj7UPrWYvM+4w=
github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9 h1:rc/CcqLH3lh8n+csdOuDfP+NuykE0U6AeYSJJHKDgSg=
github.com/scjalliance/comshim v0.0.0-20230315213746-5e51f40bd3b9/go.mod h1:a/83NAfUXvEuLpmxDssAXxgUgrEy12MId3Wd7OTs76s=
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74 h1:gga7acRE695APm9hlsSMoOoE65U4/TcqNj90mc69Rlg=
github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
golang.org/x/net v0.11.0 h1:Gi2tvZIJyBtO9SDr1q9h5hEQCp/4L2RQ+ar0qjx2oNU=
golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=

209
gvisor.go Normal file
View File

@@ -0,0 +1,209 @@
//go:build with_gvisor
package tun
import (
"context"
"net/netip"
"time"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
"github.com/sagernet/gvisor/pkg/waiter"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/canceler"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
)
const WithGVisor = true
const defaultNIC tcpip.NICID = 1
type GVisor struct {
ctx context.Context
tun GVisorTun
tunMtu uint32
endpointIndependentNat bool
udpTimeout int64
handler Handler
logger logger.Logger
stack *stack.Stack
endpoint stack.LinkEndpoint
}
type GVisorTun interface {
Tun
NewEndpoint() (stack.LinkEndpoint, error)
}
func NewGVisor(
options StackOptions,
) (Stack, error) {
gTun, isGTun := options.Tun.(GVisorTun)
if !isGTun {
return nil, E.New("gVisor stack is unsupported on current platform")
}
gStack := &GVisor{
ctx: options.Context,
tun: gTun,
tunMtu: options.MTU,
endpointIndependentNat: options.EndpointIndependentNat,
udpTimeout: options.UDPTimeout,
handler: options.Handler,
logger: options.Logger,
}
return gStack, nil
}
func (t *GVisor) Start() error {
linkEndpoint, err := t.tun.NewEndpoint()
if err != nil {
return err
}
ipStack := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,
ipv6.NewProtocol,
},
TransportProtocols: []stack.TransportProtocolFactory{
tcp.NewProtocol,
udp.NewProtocol,
icmp.NewProtocol4,
icmp.NewProtocol6,
},
})
tErr := ipStack.CreateNIC(defaultNIC, linkEndpoint)
if tErr != nil {
return E.New("create nic: ", wrapStackError(tErr))
}
ipStack.SetRouteTable([]tcpip.Route{
{Destination: header.IPv4EmptySubnet, NIC: defaultNIC},
{Destination: header.IPv6EmptySubnet, NIC: defaultNIC},
})
ipStack.SetSpoofing(defaultNIC, true)
ipStack.SetPromiscuousMode(defaultNIC, true)
bufSize := 20 * 1024
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPReceiveBufferSizeRangeOption{
Min: 1,
Default: bufSize,
Max: bufSize,
})
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpip.TCPSendBufferSizeRangeOption{
Min: 1,
Default: bufSize,
Max: bufSize,
})
sOpt := tcpip.TCPSACKEnabled(true)
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
mOpt := tcpip.TCPModerateReceiveBufferOption(true)
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &mOpt)
tcpForwarder := tcp.NewForwarder(ipStack, 0, 1024, func(r *tcp.ForwarderRequest) {
var wq waiter.Queue
handshakeCtx, cancel := context.WithCancel(context.Background())
go func() {
select {
case <-t.ctx.Done():
wq.Notify(wq.Events())
case <-handshakeCtx.Done():
}
}()
endpoint, err := r.CreateEndpoint(&wq)
cancel()
if err != nil {
r.Complete(true)
return
}
r.Complete(false)
endpoint.SocketOptions().SetKeepAlive(true)
keepAliveIdle := tcpip.KeepaliveIdleOption(15 * time.Second)
endpoint.SetSockOpt(&keepAliveIdle)
keepAliveInterval := tcpip.KeepaliveIntervalOption(15 * time.Second)
endpoint.SetSockOpt(&keepAliveInterval)
tcpConn := gonet.NewTCPConn(&wq, endpoint)
lAddr := tcpConn.RemoteAddr()
rAddr := tcpConn.LocalAddr()
if lAddr == nil || rAddr == nil {
tcpConn.Close()
return
}
go func() {
var metadata M.Metadata
metadata.Source = M.SocksaddrFromNet(lAddr)
metadata.Destination = M.SocksaddrFromNet(rAddr)
hErr := t.handler.NewConnection(t.ctx, &gTCPConn{tcpConn}, metadata)
if hErr != nil {
endpoint.Abort()
}
}()
})
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
if !t.endpointIndependentNat {
udpForwarder := udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) {
var wq waiter.Queue
endpoint, err := request.CreateEndpoint(&wq)
if err != nil {
return
}
udpConn := gonet.NewUDPConn(ipStack, &wq, endpoint)
lAddr := udpConn.RemoteAddr()
rAddr := udpConn.LocalAddr()
if lAddr == nil || rAddr == nil {
endpoint.Abort()
return
}
go func() {
var metadata M.Metadata
metadata.Source = M.SocksaddrFromNet(lAddr)
metadata.Destination = M.SocksaddrFromNet(rAddr)
ctx, conn := canceler.NewPacketConn(t.ctx, bufio.NewPacketConn(&bufio.UnbindPacketConn{ExtendedConn: bufio.NewExtendedConn(&gUDPConn{udpConn}), Addr: M.SocksaddrFromNet(rAddr)}), time.Duration(t.udpTimeout)*time.Second)
hErr := t.handler.NewPacketConnection(ctx, conn, metadata)
if hErr != nil {
endpoint.Abort()
}
}()
})
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
} else {
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket)
}
t.stack = ipStack
t.endpoint = linkEndpoint
return nil
}
func (t *GVisor) Close() error {
t.endpoint.Attach(nil)
t.stack.Close()
for _, endpoint := range t.stack.CleanupEndpoints() {
endpoint.Abort()
}
return nil
}
func AddressFromAddr(destination netip.Addr) tcpip.Address {
if destination.Is6() {
return tcpip.AddrFrom16(destination.As16())
} else {
return tcpip.AddrFrom4(destination.As4())
}
}
func AddrFromAddress(address tcpip.Address) netip.Addr {
if address.Len() == 16 {
return netip.AddrFrom16(address.As16())
} else {
return netip.AddrFrom4(address.As4())
}
}

72
gvisor_err.go Normal file
View File

@@ -0,0 +1,72 @@
//go:build with_gvisor
package tun
import (
"net"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
E "github.com/sagernet/sing/common/exceptions"
)
type gTCPConn struct {
*gonet.TCPConn
}
func (c *gTCPConn) Upstream() any {
return c.TCPConn
}
func (c *gTCPConn) Write(b []byte) (n int, err error) {
n, err = c.TCPConn.Write(b)
if err == nil {
return
}
err = wrapError(err)
return
}
type gUDPConn struct {
*gonet.UDPConn
}
func (c *gUDPConn) Read(b []byte) (n int, err error) {
n, err = c.UDPConn.Read(b)
if err == nil {
return
}
err = wrapError(err)
return
}
func (c *gUDPConn) Write(b []byte) (n int, err error) {
n, err = c.UDPConn.Write(b)
if err == nil {
return
}
err = wrapError(err)
return
}
func wrapStackError(err tcpip.Error) error {
switch err.(type) {
case *tcpip.ErrClosedForSend,
*tcpip.ErrClosedForReceive,
*tcpip.ErrAborted:
return net.ErrClosed
}
return E.New(err.String())
}
func wrapError(err error) error {
if opErr, isOpErr := err.(*net.OpError); isOpErr {
switch opErr.Err.Error() {
case "endpoint is closed for send",
"endpoint is closed for receive",
"operation aborted":
return net.ErrClosed
}
}
return err
}

View File

@@ -13,9 +13,3 @@ func NewGVisor(
) (Stack, error) {
return nil, ErrGVisorNotIncluded
}
func NewMixed(
options StackOptions,
) (Stack, error) {
return nil, ErrGVisorNotIncluded
}

125
gvisor_udp.go Normal file
View File

@@ -0,0 +1,125 @@
//go:build with_gvisor
package tun
import (
"context"
"math"
"net/netip"
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/checksum"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/udpnat"
)
type UDPForwarder struct {
ctx context.Context
stack *stack.Stack
udpNat *udpnat.Service[netip.AddrPort]
}
func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, udpTimeout int64) *UDPForwarder {
return &UDPForwarder{
ctx: ctx,
stack: stack,
udpNat: udpnat.New[netip.AddrPort](udpTimeout, handler),
}
}
func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
var upstreamMetadata M.Metadata
upstreamMetadata.Source = M.SocksaddrFrom(AddrFromAddress(id.RemoteAddress), id.RemotePort)
upstreamMetadata.Destination = M.SocksaddrFrom(AddrFromAddress(id.LocalAddress), id.LocalPort)
var netProto tcpip.NetworkProtocolNumber
if upstreamMetadata.Source.IsIPv4() {
netProto = header.IPv4ProtocolNumber
} else {
netProto = header.IPv6ProtocolNumber
}
f.udpNat.NewPacket(
f.ctx,
upstreamMetadata.Source.AddrPort(),
buf.As(pkt.Data().AsRange().ToSlice()),
upstreamMetadata,
func(natConn N.PacketConn) N.PacketWriter {
return &UDPBackWriter{f.stack, id.RemoteAddress, id.RemotePort, netProto}
},
)
return true
}
type UDPBackWriter struct {
stack *stack.Stack
source tcpip.Address
sourcePort uint16
sourceNetwork tcpip.NetworkProtocolNumber
}
func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Socksaddr) error {
if destination.IsIPv4() && w.sourceNetwork == header.IPv6ProtocolNumber {
destination = M.SocksaddrFrom(netip.AddrFrom16(destination.Addr.As16()), destination.Port)
} else if destination.IsIPv6() && (w.sourceNetwork == header.IPv4AddressSizeBits) {
return E.New("send IPv6 packet to IPv4 connection")
}
defer packetBuffer.Release()
route, err := w.stack.FindRoute(
defaultNIC,
AddressFromAddr(destination.Addr),
w.source,
w.sourceNetwork,
false,
)
if err != nil {
return wrapStackError(err)
}
defer route.Release()
packet := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: header.UDPMinimumSize + int(route.MaxHeaderLength()),
Payload: buffer.MakeWithData(packetBuffer.Bytes()),
})
defer packet.DecRef()
packet.TransportProtocolNumber = header.UDPProtocolNumber
udpHdr := header.UDP(packet.TransportHeader().Push(header.UDPMinimumSize))
pLen := uint16(packet.Size())
udpHdr.Encode(&header.UDPFields{
SrcPort: destination.Port,
DstPort: w.sourcePort,
Length: pLen,
})
if route.RequiresTXTransportChecksum() && w.sourceNetwork == header.IPv6ProtocolNumber {
xsum := udpHdr.CalculateChecksum(checksum.Combine(
route.PseudoHeaderChecksum(header.UDPProtocolNumber, pLen),
packet.Data().Checksum(),
))
if xsum != math.MaxUint16 {
xsum = ^xsum
}
udpHdr.SetChecksum(xsum)
}
err = route.WritePacket(stack.NetworkHeaderParams{
Protocol: header.UDPProtocolNumber,
TTL: route.DefaultTTL(),
TOS: 0,
}, packet)
if err != nil {
route.Stats().UDP.PacketSendErrors.Increment()
return wrapStackError(err)
}
route.Stats().UDP.PacketsSent.Increment()
return nil
}

View File

@@ -1,33 +0,0 @@
package checksum_test
import (
"crypto/rand"
"testing"
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
"github.com/sagernet/sing-tun/internal/tschecksum"
)
func BenchmarkTsChecksum(b *testing.B) {
packet := make([][]byte, 1000)
for i := 0; i < 1000; i++ {
packet[i] = make([]byte, 1500)
rand.Read(packet[i])
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
tschecksum.Checksum(packet[i%1000], 0)
}
}
func BenchmarkGChecksum(b *testing.B) {
packet := make([][]byte, 1000)
for i := 0; i < 1000; i++ {
packet[i] = make([]byte, 1500)
rand.Read(packet[i])
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
checksum.ChecksumDefault(packet[i%1000], 0)
}
}

View File

@@ -0,0 +1,40 @@
package clashtcpip
import (
"encoding/binary"
)
type ICMPType = byte
const (
ICMPTypePingRequest byte = 0x8
ICMPTypePingResponse byte = 0x0
)
type ICMPPacket []byte
func (p ICMPPacket) Type() ICMPType {
return p[0]
}
func (p ICMPPacket) SetType(v ICMPType) {
p[0] = v
}
func (p ICMPPacket) Code() byte {
return p[1]
}
func (p ICMPPacket) Checksum() uint16 {
return binary.BigEndian.Uint16(p[2:])
}
func (p ICMPPacket) SetChecksum(sum [2]byte) {
p[2] = sum[0]
p[3] = sum[1]
}
func (p ICMPPacket) ResetChecksum() {
p.SetChecksum(zeroChecksum)
p.SetChecksum(Checksum(0, p))
}

View File

@@ -0,0 +1,172 @@
package clashtcpip
import (
"encoding/binary"
)
type ICMPv6Packet []byte
const (
ICMPv6HeaderSize = 4
ICMPv6MinimumSize = 8
ICMPv6PayloadOffset = 8
ICMPv6EchoMinimumSize = 8
ICMPv6ErrorHeaderSize = 8
ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize
ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize
ICMPv6ChecksumOffset = 2
icmpv6PointerOffset = 4
icmpv6MTUOffset = 4
icmpv6IdentOffset = 4
icmpv6SequenceOffset = 6
NDPHopLimit = 255
)
type ICMPv6Type byte
const (
ICMPv6DstUnreachable ICMPv6Type = 1
ICMPv6PacketTooBig ICMPv6Type = 2
ICMPv6TimeExceeded ICMPv6Type = 3
ICMPv6ParamProblem ICMPv6Type = 4
ICMPv6EchoRequest ICMPv6Type = 128
ICMPv6EchoReply ICMPv6Type = 129
ICMPv6RouterSolicit ICMPv6Type = 133
ICMPv6RouterAdvert ICMPv6Type = 134
ICMPv6NeighborSolicit ICMPv6Type = 135
ICMPv6NeighborAdvert ICMPv6Type = 136
ICMPv6RedirectMsg ICMPv6Type = 137
ICMPv6MulticastListenerQuery ICMPv6Type = 130
ICMPv6MulticastListenerReport ICMPv6Type = 131
ICMPv6MulticastListenerDone ICMPv6Type = 132
)
func (typ ICMPv6Type) IsErrorType() bool {
return typ&0x80 == 0
}
type ICMPv6Code byte
const (
ICMPv6NetworkUnreachable ICMPv6Code = 0
ICMPv6Prohibited ICMPv6Code = 1
ICMPv6BeyondScope ICMPv6Code = 2
ICMPv6AddressUnreachable ICMPv6Code = 3
ICMPv6PortUnreachable ICMPv6Code = 4
ICMPv6Policy ICMPv6Code = 5
ICMPv6RejectRoute ICMPv6Code = 6
)
const (
ICMPv6HopLimitExceeded ICMPv6Code = 0
ICMPv6ReassemblyTimeout ICMPv6Code = 1
)
const (
ICMPv6ErroneousHeader ICMPv6Code = 0
ICMPv6UnknownHeader ICMPv6Code = 1
ICMPv6UnknownOption ICMPv6Code = 2
)
const ICMPv6UnusedCode ICMPv6Code = 0
func (b ICMPv6Packet) Type() ICMPv6Type {
return ICMPv6Type(b[0])
}
func (b ICMPv6Packet) SetType(t ICMPv6Type) {
b[0] = byte(t)
}
func (b ICMPv6Packet) Code() ICMPv6Code {
return ICMPv6Code(b[1])
}
func (b ICMPv6Packet) SetCode(c ICMPv6Code) {
b[1] = byte(c)
}
func (b ICMPv6Packet) TypeSpecific() uint32 {
return binary.BigEndian.Uint32(b[icmpv6PointerOffset:])
}
func (b ICMPv6Packet) SetTypeSpecific(val uint32) {
binary.BigEndian.PutUint32(b[icmpv6PointerOffset:], val)
}
func (b ICMPv6Packet) Checksum() uint16 {
return binary.BigEndian.Uint16(b[ICMPv6ChecksumOffset:])
}
func (b ICMPv6Packet) SetChecksum(sum [2]byte) {
_ = b[ICMPv6ChecksumOffset+1]
b[ICMPv6ChecksumOffset] = sum[0]
b[ICMPv6ChecksumOffset+1] = sum[1]
}
func (ICMPv6Packet) SourcePort() uint16 {
return 0
}
func (ICMPv6Packet) DestinationPort() uint16 {
return 0
}
func (ICMPv6Packet) SetSourcePort(uint16) {
}
func (ICMPv6Packet) SetDestinationPort(uint16) {
}
func (b ICMPv6Packet) MTU() uint32 {
return binary.BigEndian.Uint32(b[icmpv6MTUOffset:])
}
func (b ICMPv6Packet) SetMTU(mtu uint32) {
binary.BigEndian.PutUint32(b[icmpv6MTUOffset:], mtu)
}
func (b ICMPv6Packet) Ident() uint16 {
return binary.BigEndian.Uint16(b[icmpv6IdentOffset:])
}
func (b ICMPv6Packet) SetIdent(ident uint16) {
binary.BigEndian.PutUint16(b[icmpv6IdentOffset:], ident)
}
func (b ICMPv6Packet) Sequence() uint16 {
return binary.BigEndian.Uint16(b[icmpv6SequenceOffset:])
}
func (b ICMPv6Packet) SetSequence(sequence uint16) {
binary.BigEndian.PutUint16(b[icmpv6SequenceOffset:], sequence)
}
func (b ICMPv6Packet) MessageBody() []byte {
return b[ICMPv6HeaderSize:]
}
func (b ICMPv6Packet) Payload() []byte {
return b[ICMPv6PayloadOffset:]
}
func (b ICMPv6Packet) ResetChecksum(psum uint32) {
b.SetChecksum(zeroChecksum)
b.SetChecksum(Checksum(psum, b))
}

209
internal/clashtcpip/ip.go Normal file
View File

@@ -0,0 +1,209 @@
package clashtcpip
import (
"encoding/binary"
"errors"
"net/netip"
)
type IPProtocol = byte
type IP interface {
Payload() []byte
SourceIP() netip.Addr
DestinationIP() netip.Addr
SetSourceIP(ip netip.Addr)
SetDestinationIP(ip netip.Addr)
Protocol() IPProtocol
DecTimeToLive()
ResetChecksum()
PseudoSum() uint32
}
// IPProtocol type
const (
ICMP IPProtocol = 0x01
TCP IPProtocol = 0x06
UDP IPProtocol = 0x11
ICMPv6 IPProtocol = 0x3a
)
const (
FlagDontFragment = 1 << 1
FlagMoreFragment = 1 << 2
)
const (
IPv4HeaderSize = 20
IPv4Version = 4
IPv4OptionsOffset = 20
IPv4PacketMinLength = IPv4OptionsOffset
)
var (
ErrInvalidLength = errors.New("invalid packet length")
ErrInvalidIPVersion = errors.New("invalid ip version")
ErrInvalidChecksum = errors.New("invalid checksum")
)
type IPv4Packet []byte
func (p IPv4Packet) TotalLen() uint16 {
return binary.BigEndian.Uint16(p[2:])
}
func (p IPv4Packet) SetTotalLength(length uint16) {
binary.BigEndian.PutUint16(p[2:], length)
}
func (p IPv4Packet) HeaderLen() uint16 {
return uint16(p[0]&0xf) * 4
}
func (p IPv4Packet) SetHeaderLen(length uint16) {
p[0] &= 0xF0
p[0] |= byte(length / 4)
}
func (p IPv4Packet) TypeOfService() byte {
return p[1]
}
func (p IPv4Packet) SetTypeOfService(tos byte) {
p[1] = tos
}
func (p IPv4Packet) Identification() uint16 {
return binary.BigEndian.Uint16(p[4:])
}
func (p IPv4Packet) SetIdentification(id uint16) {
binary.BigEndian.PutUint16(p[4:], id)
}
func (p IPv4Packet) FragmentOffset() uint16 {
return binary.BigEndian.Uint16([]byte{p[6] & 0x7, p[7]}) * 8
}
func (p IPv4Packet) SetFragmentOffset(offset uint32) {
flags := p.Flags()
binary.BigEndian.PutUint16(p[6:], uint16(offset/8))
p.SetFlags(flags)
}
func (p IPv4Packet) DataLen() uint16 {
return p.TotalLen() - p.HeaderLen()
}
func (p IPv4Packet) Payload() []byte {
return p[p.HeaderLen():p.TotalLen()]
}
func (p IPv4Packet) Protocol() IPProtocol {
return p[9]
}
func (p IPv4Packet) SetProtocol(protocol IPProtocol) {
p[9] = protocol
}
func (p IPv4Packet) Flags() byte {
return p[6] >> 5
}
func (p IPv4Packet) SetFlags(flags byte) {
p[6] &= 0x1F
p[6] |= flags << 5
}
func (p IPv4Packet) SourceIP() netip.Addr {
return netip.AddrFrom4([4]byte{p[12], p[13], p[14], p[15]})
}
func (p IPv4Packet) SetSourceIP(ip netip.Addr) {
if ip.Is4() {
copy(p[12:16], ip.AsSlice())
}
}
func (p IPv4Packet) DestinationIP() netip.Addr {
return netip.AddrFrom4([4]byte{p[16], p[17], p[18], p[19]})
}
func (p IPv4Packet) SetDestinationIP(ip netip.Addr) {
if ip.Is4() {
copy(p[16:20], ip.AsSlice())
}
}
func (p IPv4Packet) Checksum() uint16 {
return binary.BigEndian.Uint16(p[10:])
}
func (p IPv4Packet) SetChecksum(sum [2]byte) {
p[10] = sum[0]
p[11] = sum[1]
}
func (p IPv4Packet) TimeToLive() uint8 {
return p[8]
}
func (p IPv4Packet) SetTimeToLive(ttl uint8) {
p[8] = ttl
}
func (p IPv4Packet) DecTimeToLive() {
p[8] = p[8] - uint8(1)
}
func (p IPv4Packet) ResetChecksum() {
p.SetChecksum(zeroChecksum)
p.SetChecksum(Checksum(0, p[:p.HeaderLen()]))
}
// PseudoSum for tcp checksum
func (p IPv4Packet) PseudoSum() uint32 {
sum := Sum(p[12:20])
sum += uint32(p.Protocol())
sum += uint32(p.DataLen())
return sum
}
func (p IPv4Packet) Valid() bool {
return len(p) >= IPv4HeaderSize && p.TotalLen() >= p.HeaderLen() && uint16(len(p)) >= p.TotalLen()
}
func (p IPv4Packet) Verify() error {
if len(p) < IPv4PacketMinLength {
return ErrInvalidLength
}
checksum := []byte{p[10], p[11]}
headerLength := uint16(p[0]&0xF) * 4
packetLength := binary.BigEndian.Uint16(p[2:])
if p[0]>>4 != 4 {
return ErrInvalidIPVersion
}
if uint16(len(p)) < packetLength || packetLength < headerLength {
return ErrInvalidLength
}
p[10] = 0
p[11] = 0
defer copy(p[10:12], checksum)
answer := Checksum(0, p[:headerLength])
if answer[0] != checksum[0] || answer[1] != checksum[1] {
return ErrInvalidChecksum
}
return nil
}
var _ IP = (*IPv4Packet)(nil)

141
internal/clashtcpip/ipv6.go Normal file
View File

@@ -0,0 +1,141 @@
package clashtcpip
import (
"encoding/binary"
"net/netip"
)
const (
versTCFL = 0
IPv6PayloadLenOffset = 4
IPv6NextHeaderOffset = 6
hopLimit = 7
v6SrcAddr = 8
v6DstAddr = v6SrcAddr + IPv6AddressSize
IPv6FixedHeaderSize = v6DstAddr + IPv6AddressSize
)
const (
versIHL = 0
tos = 1
ipVersionShift = 4
ipIHLMask = 0x0f
IPv4IHLStride = 4
)
type IPv6Packet []byte
const (
IPv6MinimumSize = IPv6FixedHeaderSize
IPv6AddressSize = 16
IPv6Version = 6
IPv6MinimumMTU = 1280
)
func (b IPv6Packet) PayloadLength() uint16 {
return binary.BigEndian.Uint16(b[IPv6PayloadLenOffset:])
}
func (b IPv6Packet) HopLimit() uint8 {
return b[hopLimit]
}
func (b IPv6Packet) NextHeader() byte {
return b[IPv6NextHeaderOffset]
}
func (b IPv6Packet) Protocol() IPProtocol {
return b.NextHeader()
}
func (b IPv6Packet) Payload() []byte {
return b[IPv6MinimumSize:][:b.PayloadLength()]
}
func (b IPv6Packet) SourceIP() netip.Addr {
addr, _ := netip.AddrFromSlice(b[v6SrcAddr:][:IPv6AddressSize])
return addr
}
func (b IPv6Packet) DestinationIP() netip.Addr {
addr, _ := netip.AddrFromSlice(b[v6DstAddr:][:IPv6AddressSize])
return addr
}
func (IPv6Packet) Checksum() uint16 {
return 0
}
func (b IPv6Packet) TOS() (uint8, uint32) {
v := binary.BigEndian.Uint32(b[versTCFL:])
return uint8(v >> 20), v & 0xfffff
}
func (b IPv6Packet) SetTOS(t uint8, l uint32) {
vtf := (6 << 28) | (uint32(t) << 20) | (l & 0xfffff)
binary.BigEndian.PutUint32(b[versTCFL:], vtf)
}
func (b IPv6Packet) SetPayloadLength(payloadLength uint16) {
binary.BigEndian.PutUint16(b[IPv6PayloadLenOffset:], payloadLength)
}
func (b IPv6Packet) SetSourceIP(addr netip.Addr) {
if addr.Is6() {
copy(b[v6SrcAddr:][:IPv6AddressSize], addr.AsSlice())
}
}
func (b IPv6Packet) SetDestinationIP(addr netip.Addr) {
if addr.Is6() {
copy(b[v6DstAddr:][:IPv6AddressSize], addr.AsSlice())
}
}
func (b IPv6Packet) SetHopLimit(v uint8) {
b[hopLimit] = v
}
func (b IPv6Packet) SetNextHeader(v byte) {
b[IPv6NextHeaderOffset] = v
}
func (b IPv6Packet) SetProtocol(p IPProtocol) {
b.SetNextHeader(p)
}
func (b IPv6Packet) DecTimeToLive() {
b[hopLimit] = b[hopLimit] - uint8(1)
}
func (IPv6Packet) SetChecksum(uint16) {
}
func (IPv6Packet) ResetChecksum() {
}
func (b IPv6Packet) PseudoSum() uint32 {
sum := Sum(b[v6SrcAddr:IPv6FixedHeaderSize])
sum += uint32(b.Protocol())
sum += uint32(b.PayloadLength())
return sum
}
func (b IPv6Packet) Valid() bool {
return len(b) >= IPv6MinimumSize && len(b) >= int(b.PayloadLength())+IPv6MinimumSize
}
func IPVersion(b []byte) int {
if len(b) < versIHL+1 {
return -1
}
return int(b[versIHL] >> ipVersionShift)
}
var _ IP = (*IPv6Packet)(nil)

View File

@@ -0,0 +1,90 @@
package clashtcpip
import (
"encoding/binary"
"net"
)
const (
TCPFin uint16 = 1 << 0
TCPSyn uint16 = 1 << 1
TCPRst uint16 = 1 << 2
TCPPuh uint16 = 1 << 3
TCPAck uint16 = 1 << 4
TCPUrg uint16 = 1 << 5
TCPEce uint16 = 1 << 6
TCPEwr uint16 = 1 << 7
TCPNs uint16 = 1 << 8
)
const TCPHeaderSize = 20
type TCPPacket []byte
func (p TCPPacket) SourcePort() uint16 {
return binary.BigEndian.Uint16(p)
}
func (p TCPPacket) SetSourcePort(port uint16) {
binary.BigEndian.PutUint16(p, port)
}
func (p TCPPacket) DestinationPort() uint16 {
return binary.BigEndian.Uint16(p[2:])
}
func (p TCPPacket) SetDestinationPort(port uint16) {
binary.BigEndian.PutUint16(p[2:], port)
}
func (p TCPPacket) Flags() uint16 {
return uint16(p[13] | (p[12] & 0x1))
}
func (p TCPPacket) Checksum() uint16 {
return binary.BigEndian.Uint16(p[16:])
}
func (p TCPPacket) SetChecksum(sum [2]byte) {
p[16] = sum[0]
p[17] = sum[1]
}
func (p TCPPacket) ResetChecksum(psum uint32) {
p.SetChecksum(zeroChecksum)
p.SetChecksum(Checksum(psum, p))
}
func (p TCPPacket) Valid() bool {
return len(p) >= TCPHeaderSize
}
func (p TCPPacket) Verify(sourceAddress net.IP, targetAddress net.IP) error {
var checksum [2]byte
checksum[0] = p[16]
checksum[1] = p[17]
// reset checksum
p[16] = 0
p[17] = 0
// restore checksum
defer func() {
p[16] = checksum[0]
p[17] = checksum[1]
}()
// check checksum
s := uint32(0)
s += Sum(sourceAddress)
s += Sum(targetAddress)
s += uint32(TCP)
s += uint32(len(p))
check := Checksum(s, p)
if checksum[0] != check[0] || checksum[1] != check[1] {
return ErrInvalidChecksum
}
return nil
}

View File

@@ -0,0 +1,24 @@
package clashtcpip
var zeroChecksum = [2]byte{0x00, 0x00}
var SumFnc = SumCompat
func Sum(b []byte) uint32 {
return SumFnc(b)
}
// Checksum for Internet Protocol family headers
func Checksum(sum uint32, b []byte) (answer [2]byte) {
sum += Sum(b)
sum = (sum >> 16) + (sum & 0xffff)
sum += sum >> 16
sum = ^sum
answer[0] = byte(sum >> 8)
answer[1] = byte(sum)
return
}
func SetIPv4(packet []byte) {
packet[0] = (packet[0] & 0x0f) | (4 << 4)
}

View File

@@ -0,0 +1,26 @@
//go:build !noasm
package clashtcpip
import (
"unsafe"
"golang.org/x/sys/cpu"
)
//go:noescape
func sumAsmAvx2(data unsafe.Pointer, length uintptr) uintptr
func SumAVX2(data []byte) uint32 {
if len(data) == 0 {
return 0
}
return uint32(sumAsmAvx2(unsafe.Pointer(&data[0]), uintptr(len(data))))
}
func init() {
if cpu.X86.HasAVX2 {
SumFnc = SumAVX2
}
}

View File

@@ -0,0 +1,140 @@
#include "textflag.h"
DATA endian_swap_mask<>+0(SB)/8, $0x607040502030001
DATA endian_swap_mask<>+8(SB)/8, $0xE0F0C0D0A0B0809
DATA endian_swap_mask<>+16(SB)/8, $0x607040502030001
DATA endian_swap_mask<>+24(SB)/8, $0xE0F0C0D0A0B0809
GLOBL endian_swap_mask<>(SB), RODATA, $32
// func sumAsmAvx2(data unsafe.Pointer, length uintptr) uintptr
//
// args (8 bytes aligned):
// data unsafe.Pointer - 8 bytes - 0 offset
// length uintptr - 8 bytes - 8 offset
// result uintptr - 8 bytes - 16 offset
#define PDATA AX
#define LENGTH CX
#define RESULT BX
TEXT ·sumAsmAvx2(SB),NOSPLIT,$0-24
MOVQ data+0(FP), PDATA
MOVQ length+8(FP), LENGTH
XORQ RESULT, RESULT
#define VSUM Y0
#define ENDIAN_SWAP_MASK Y1
BEGIN:
VMOVDQU endian_swap_mask<>(SB), ENDIAN_SWAP_MASK
VPXOR VSUM, VSUM, VSUM
#define LOADED_0 Y2
#define LOADED_1 Y3
#define LOADED_2 Y4
#define LOADED_3 Y5
BATCH_64:
CMPQ LENGTH, $64
JB BATCH_32
VPMOVZXWD (PDATA), LOADED_0
VPMOVZXWD 16(PDATA), LOADED_1
VPMOVZXWD 32(PDATA), LOADED_2
VPMOVZXWD 48(PDATA), LOADED_3
VPSHUFB ENDIAN_SWAP_MASK, LOADED_0, LOADED_0
VPSHUFB ENDIAN_SWAP_MASK, LOADED_1, LOADED_1
VPSHUFB ENDIAN_SWAP_MASK, LOADED_2, LOADED_2
VPSHUFB ENDIAN_SWAP_MASK, LOADED_3, LOADED_3
VPADDD LOADED_0, VSUM, VSUM
VPADDD LOADED_1, VSUM, VSUM
VPADDD LOADED_2, VSUM, VSUM
VPADDD LOADED_3, VSUM, VSUM
ADDQ $-64, LENGTH
ADDQ $64, PDATA
JMP BATCH_64
#undef LOADED_0
#undef LOADED_1
#undef LOADED_2
#undef LOADED_3
#define LOADED_0 Y2
#define LOADED_1 Y3
BATCH_32:
CMPQ LENGTH, $32
JB BATCH_16
VPMOVZXWD (PDATA), LOADED_0
VPMOVZXWD 16(PDATA), LOADED_1
VPSHUFB ENDIAN_SWAP_MASK, LOADED_0, LOADED_0
VPSHUFB ENDIAN_SWAP_MASK, LOADED_1, LOADED_1
VPADDD LOADED_0, VSUM, VSUM
VPADDD LOADED_1, VSUM, VSUM
ADDQ $-32, LENGTH
ADDQ $32, PDATA
JMP BATCH_32
#undef LOADED_0
#undef LOADED_1
#define LOADED Y2
BATCH_16:
CMPQ LENGTH, $16
JB COLLECT
VPMOVZXWD (PDATA), LOADED
VPSHUFB ENDIAN_SWAP_MASK, LOADED, LOADED
VPADDD LOADED, VSUM, VSUM
ADDQ $-16, LENGTH
ADDQ $16, PDATA
JMP BATCH_16
#undef LOADED
#define EXTRACTED Y2
#define EXTRACTED_128 X2
#define TEMP_64 DX
COLLECT:
VEXTRACTI128 $0, VSUM, EXTRACTED_128
VPEXTRD $0, EXTRACTED_128, TEMP_64
ADDL TEMP_64, RESULT
VPEXTRD $1, EXTRACTED_128, TEMP_64
ADDL TEMP_64, RESULT
VPEXTRD $2, EXTRACTED_128, TEMP_64
ADDL TEMP_64, RESULT
VPEXTRD $3, EXTRACTED_128, TEMP_64
ADDL TEMP_64, RESULT
VEXTRACTI128 $1, VSUM, EXTRACTED_128
VPEXTRD $0, EXTRACTED_128, TEMP_64
ADDL TEMP_64, RESULT
VPEXTRD $1, EXTRACTED_128, TEMP_64
ADDL TEMP_64, RESULT
VPEXTRD $2, EXTRACTED_128, TEMP_64
ADDL TEMP_64, RESULT
VPEXTRD $3, EXTRACTED_128, TEMP_64
ADDL TEMP_64, RESULT
#undef EXTRACTED
#undef EXTRACTED_128
#undef TEMP_64
#define TEMP DX
#define TEMP2 SI
BATCH_2:
CMPQ LENGTH, $2
JB BATCH_1
XORQ TEMP, TEMP
MOVW (PDATA), TEMP
MOVQ TEMP, TEMP2
SHRW $8, TEMP2
SHLW $8, TEMP
ORW TEMP2, TEMP
ADDL TEMP, RESULT
ADDQ $-2, LENGTH
ADDQ $2, PDATA
JMP BATCH_2
#undef TEMP
#define TEMP DX
BATCH_1:
CMPQ LENGTH, $0
JZ RETURN
XORQ TEMP, TEMP
MOVB (PDATA), TEMP
SHLW $8, TEMP
ADDL TEMP, RESULT
#undef TEMP
RETURN:
MOVQ RESULT, result+16(FP)
RET

View File

@@ -0,0 +1,51 @@
package clashtcpip
import (
"crypto/rand"
"testing"
"golang.org/x/sys/cpu"
)
func Test_SumAVX2(t *testing.T) {
if !cpu.X86.HasAVX2 {
t.Skipf("AVX2 unavailable")
}
bytes := make([]byte, chunkSize)
for size := 0; size <= chunkSize; size++ {
for count := 0; count < chunkCount; count++ {
_, err := rand.Reader.Read(bytes[:size])
if err != nil {
t.Skipf("Rand read failed: %v", err)
}
compat := SumCompat(bytes[:size])
avx := SumAVX2(bytes[:size])
if compat != avx {
t.Errorf("Sum of length=%d mismatched", size)
}
}
}
}
func Benchmark_SumAVX2(b *testing.B) {
if !cpu.X86.HasAVX2 {
b.Skipf("AVX2 unavailable")
}
bytes := make([]byte, chunkSize)
_, err := rand.Reader.Read(bytes)
if err != nil {
b.Skipf("Rand read failed: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
SumAVX2(bytes)
}
}

View File

@@ -0,0 +1,24 @@
package clashtcpip
import (
"unsafe"
"golang.org/x/sys/cpu"
)
//go:noescape
func sumAsmNeon(data unsafe.Pointer, length uintptr) uintptr
func SumNeon(data []byte) uint32 {
if len(data) == 0 {
return 0
}
return uint32(sumAsmNeon(unsafe.Pointer(&data[0]), uintptr(len(data))))
}
func init() {
if cpu.ARM64.HasASIMD {
SumFnc = SumNeon
}
}

View File

@@ -0,0 +1,118 @@
#include "textflag.h"
// func sumAsmNeon(data unsafe.Pointer, length uintptr) uintptr
//
// args (8 bytes aligned):
// data unsafe.Pointer - 8 bytes - 0 offset
// length uintptr - 8 bytes - 8 offset
// result uintptr - 8 bytes - 16 offset
#define PDATA R0
#define LENGTH R1
#define RESULT R2
#define VSUM V0
TEXT ·sumAsmNeon(SB),NOSPLIT,$0-24
MOVD data+0(FP), PDATA
MOVD length+8(FP), LENGTH
MOVD $0, RESULT
VMOVQ $0, $0, VSUM
#define LOADED_0 V1
#define LOADED_1 V2
#define LOADED_2 V3
#define LOADED_3 V4
BATCH_32:
CMP $32, LENGTH
BLO BATCH_16
VLD1 (PDATA), [LOADED_0.B8, LOADED_1.B8, LOADED_2.B8, LOADED_3.B8]
VREV16 LOADED_0.B8, LOADED_0.B8
VREV16 LOADED_1.B8, LOADED_1.B8
VREV16 LOADED_2.B8, LOADED_2.B8
VREV16 LOADED_3.B8, LOADED_3.B8
VUSHLL $0, LOADED_0.H4, LOADED_0.S4
VUSHLL $0, LOADED_1.H4, LOADED_1.S4
VUSHLL $0, LOADED_2.H4, LOADED_2.S4
VUSHLL $0, LOADED_3.H4, LOADED_3.S4
VADD LOADED_0.S4, VSUM.S4, VSUM.S4
VADD LOADED_1.S4, VSUM.S4, VSUM.S4
VADD LOADED_2.S4, VSUM.S4, VSUM.S4
VADD LOADED_3.S4, VSUM.S4, VSUM.S4
ADD $-32, LENGTH
ADD $32, PDATA
B BATCH_32
#undef LOADED_0
#undef LOADED_1
#undef LOADED_2
#undef LOADED_3
#define LOADED_0 V1
#define LOADED_1 V2
BATCH_16:
CMP $16, LENGTH
BLO BATCH_8
VLD1 (PDATA), [LOADED_0.B8, LOADED_1.B8]
VREV16 LOADED_0.B8, LOADED_0.B8
VREV16 LOADED_1.B8, LOADED_1.B8
VUSHLL $0, LOADED_0.H4, LOADED_0.S4
VUSHLL $0, LOADED_1.H4, LOADED_1.S4
VADD LOADED_0.S4, VSUM.S4, VSUM.S4
VADD LOADED_1.S4, VSUM.S4, VSUM.S4
ADD $-16, LENGTH
ADD $16, PDATA
B BATCH_16
#undef LOADED_0
#undef LOADED_1
#define LOADED_0 V1
BATCH_8:
CMP $8, LENGTH
BLO BATCH_2
VLD1 (PDATA), [LOADED_0.B8]
VREV16 LOADED_0.B8, LOADED_0.B8
VUSHLL $0, LOADED_0.H4, LOADED_0.S4
VADD LOADED_0.S4, VSUM.S4, VSUM.S4
ADD $-8, LENGTH
ADD $8, PDATA
B BATCH_8
#undef LOADED_0
#define LOADED_L R3
#define LOADED_H R4
BATCH_2:
CMP $2, LENGTH
BLO BATCH_1
MOVBU (PDATA), LOADED_H
MOVBU 1(PDATA), LOADED_L
LSL $8, LOADED_H
ORR LOADED_H, LOADED_L, LOADED_L
ADD LOADED_L, RESULT, RESULT
ADD $2, PDATA
ADD $-2, LENGTH
B BATCH_2
#undef LOADED_H
#undef LOADED_L
#define LOADED R3
BATCH_1:
CMP $1, LENGTH
BLO COLLECT
MOVBU (PDATA), LOADED
LSL $8, LOADED
ADD LOADED, RESULT, RESULT
#define EXTRACTED R3
COLLECT:
VMOV VSUM.S[0], EXTRACTED
ADD EXTRACTED, RESULT
VMOV VSUM.S[1], EXTRACTED
ADD EXTRACTED, RESULT
VMOV VSUM.S[2], EXTRACTED
ADD EXTRACTED, RESULT
VMOV VSUM.S[3], EXTRACTED
ADD EXTRACTED, RESULT
#undef VSUM
#undef PDATA
#undef LENGTH
RETURN:
MOVD RESULT, result+16(FP)
RET

View File

@@ -0,0 +1,51 @@
package clashtcpip
import (
"crypto/rand"
"testing"
"golang.org/x/sys/cpu"
)
func Test_SumNeon(t *testing.T) {
if !cpu.ARM64.HasASIMD {
t.Skipf("Neon unavailable")
}
bytes := make([]byte, chunkSize)
for size := 0; size <= chunkSize; size++ {
for count := 0; count < chunkCount; count++ {
_, err := rand.Reader.Read(bytes[:size])
if err != nil {
t.Skipf("Rand read failed: %v", err)
}
compat := SumCompat(bytes[:size])
neon := SumNeon(bytes[:size])
if compat != neon {
t.Errorf("Sum of length=%d mismatched", size)
}
}
}
}
func Benchmark_SumNeon(b *testing.B) {
if !cpu.ARM64.HasASIMD {
b.Skipf("Neon unavailable")
}
bytes := make([]byte, chunkSize)
_, err := rand.Reader.Read(bytes)
if err != nil {
b.Skipf("Rand read failed: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
SumNeon(bytes)
}
}

View File

@@ -0,0 +1,14 @@
package clashtcpip
func SumCompat(b []byte) (sum uint32) {
n := len(b)
if n&1 != 0 {
n--
sum += uint32(b[n]) << 8
}
for i := 0; i < n; i += 2 {
sum += (uint32(b[i]) << 8) | uint32(b[i+1])
}
return
}

View File

@@ -0,0 +1,26 @@
package clashtcpip
import (
"crypto/rand"
"testing"
)
const (
chunkSize = 9000
chunkCount = 10
)
func Benchmark_SumCompat(b *testing.B) {
bytes := make([]byte, chunkSize)
_, err := rand.Reader.Read(bytes)
if err != nil {
b.Skipf("Rand read failed: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
SumCompat(bytes)
}
}

View File

@@ -0,0 +1,55 @@
package clashtcpip
import (
"encoding/binary"
)
const UDPHeaderSize = 8
type UDPPacket []byte
func (p UDPPacket) Length() uint16 {
return binary.BigEndian.Uint16(p[4:])
}
func (p UDPPacket) SetLength(length uint16) {
binary.BigEndian.PutUint16(p[4:], length)
}
func (p UDPPacket) SourcePort() uint16 {
return binary.BigEndian.Uint16(p)
}
func (p UDPPacket) SetSourcePort(port uint16) {
binary.BigEndian.PutUint16(p, port)
}
func (p UDPPacket) DestinationPort() uint16 {
return binary.BigEndian.Uint16(p[2:])
}
func (p UDPPacket) SetDestinationPort(port uint16) {
binary.BigEndian.PutUint16(p[2:], port)
}
func (p UDPPacket) Payload() []byte {
return p[UDPHeaderSize:p.Length()]
}
func (p UDPPacket) Checksum() uint16 {
return binary.BigEndian.Uint16(p[6:])
}
func (p UDPPacket) SetChecksum(sum [2]byte) {
p[6] = sum[0]
p[7] = sum[1]
}
func (p UDPPacket) ResetChecksum(psum uint32) {
p.SetChecksum(zeroChecksum)
p.SetChecksum(Checksum(psum, p))
}
func (p UDPPacket) Valid() bool {
return len(p) >= UDPHeaderSize && uint16(len(p)) >= p.Length()
}

View File

@@ -1,4 +0,0 @@
# gtcpip
Minimal tcpip package kanged from gvisor
Version 20241007.0

View File

@@ -1,45 +0,0 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package checksum provides the implementation of the encoding and decoding of
// network protocol headers.
package checksum
import (
"encoding/binary"
)
// Size is the size of a checksum.
//
// The checksum is held in a uint16 which is 2 bytes.
const Size = 2
// Put puts the checksum in the provided byte slice.
func Put(b []byte, xsum uint16) {
binary.BigEndian.PutUint16(b, xsum)
}
// Combine combines the two uint16 to form their checksum. This is done
// by adding them and the carry.
//
// Note that checksum a must have been computed on an even number of bytes.
func Combine(a, b uint16) uint16 {
v := uint32(a) + uint32(b)
return uint16(v + v>>16)
}
func ChecksumDefault(buf []byte, initial uint16) uint16 {
s, _ := calculateChecksum(buf, false, initial)
return s
}

View File

@@ -1,12 +0,0 @@
//go:build !amd64
package checksum
// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
// given byte array. This function uses an optimized version of the checksum
// algorithm.
//
// The initial checksum must have been computed on an even number of bytes.
func Checksum(buf []byte, initial uint16) uint16 {
return ChecksumDefault(buf, initial)
}

View File

@@ -1,9 +0,0 @@
//go:build amd64
package checksum
import "github.com/sagernet/sing-tun/internal/tschecksum"
func Checksum(buf []byte, initial uint16) uint16 {
return tschecksum.Checksum(buf, initial)
}

View File

@@ -1,182 +0,0 @@
// Copyright 2023 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package checksum
import (
"encoding/binary"
"math/bits"
"unsafe"
)
// Note: odd indicates whether initial is a partial checksum over an odd number
// of bytes.
func calculateChecksum(buf []byte, odd bool, initial uint16) (uint16, bool) {
// Use a larger-than-uint16 accumulator to benefit from parallel summation
// as described in RFC 1071 1.2.C.
acc := uint64(initial)
// Handle an odd number of previously-summed bytes, and get the return
// value for odd.
if odd {
acc += uint64(buf[0])
buf = buf[1:]
}
odd = len(buf)&1 != 0
// Aligning &buf[0] below is much simpler if len(buf) >= 8; special-case
// smaller bufs.
if len(buf) < 8 {
if len(buf) >= 4 {
acc += (uint64(buf[0]) << 8) + uint64(buf[1])
acc += (uint64(buf[2]) << 8) + uint64(buf[3])
buf = buf[4:]
}
if len(buf) >= 2 {
acc += (uint64(buf[0]) << 8) + uint64(buf[1])
buf = buf[2:]
}
if len(buf) >= 1 {
acc += uint64(buf[0]) << 8
// buf = buf[1:] is skipped because it's unused and nogo will
// complain.
}
return reduce(acc), odd
}
// On little-endian architectures, multi-byte loads from buf will load
// bytes in the wrong order. Rather than byte-swap after each load (slow),
// we byte-swap the accumulator before summing any bytes and byte-swap it
// back before returning, which still produces the correct result as
// described in RFC 1071 1.2.B "Byte Order Independence".
//
// acc is at most a uint16 + a uint8, so its upper 32 bits must be 0s. We
// preserve this property by byte-swapping only the lower 32 bits of acc,
// so that additions to acc performed during alignment can't overflow.
acc = uint64(bswapIfLittleEndian32(uint32(acc)))
// Align &buf[0] to an 8-byte boundary.
bswapped := false
if sliceAddr(buf)&1 != 0 {
// Compute the rest of the partial checksum with bytes swapped, and
// swap back before returning; see the last paragraph of
// RFC 1071 1.2.B.
acc = uint64(bits.ReverseBytes32(uint32(acc)))
bswapped = true
// No `<< 8` here due to the byte swap we just did.
acc += uint64(bswapIfLittleEndian16(uint16(buf[0])))
buf = buf[1:]
}
if sliceAddr(buf)&2 != 0 {
acc += uint64(*(*uint16)(unsafe.Pointer(&buf[0])))
buf = buf[2:]
}
if sliceAddr(buf)&4 != 0 {
acc += uint64(*(*uint32)(unsafe.Pointer(&buf[0])))
buf = buf[4:]
}
// Sum 64 bytes at a time. Beyond this point, additions to acc may
// overflow, so we have to handle carrying.
for len(buf) >= 64 {
var carry uint64
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0)
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[8])), carry)
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[16])), carry)
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[24])), carry)
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[32])), carry)
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[40])), carry)
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[48])), carry)
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[56])), carry)
acc, _ = bits.Add64(acc, 0, carry)
buf = buf[64:]
}
// Sum the remaining 0-63 bytes.
if len(buf) >= 32 {
var carry uint64
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0)
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[8])), carry)
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[16])), carry)
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[24])), carry)
acc, _ = bits.Add64(acc, 0, carry)
buf = buf[32:]
}
if len(buf) >= 16 {
var carry uint64
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0)
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[8])), carry)
acc, _ = bits.Add64(acc, 0, carry)
buf = buf[16:]
}
if len(buf) >= 8 {
var carry uint64
acc, carry = bits.Add64(acc, *(*uint64)(unsafe.Pointer(&buf[0])), 0)
acc, _ = bits.Add64(acc, 0, carry)
buf = buf[8:]
}
if len(buf) >= 4 {
var carry uint64
acc, carry = bits.Add64(acc, uint64(*(*uint32)(unsafe.Pointer(&buf[0]))), 0)
acc, _ = bits.Add64(acc, 0, carry)
buf = buf[4:]
}
if len(buf) >= 2 {
var carry uint64
acc, carry = bits.Add64(acc, uint64(*(*uint16)(unsafe.Pointer(&buf[0]))), 0)
acc, _ = bits.Add64(acc, 0, carry)
buf = buf[2:]
}
if len(buf) >= 1 {
// bswapIfBigEndian16(buf[0]) == bswapIfLittleEndian16(buf[0]<<8).
var carry uint64
acc, carry = bits.Add64(acc, uint64(bswapIfBigEndian16(uint16(buf[0]))), 0)
acc, _ = bits.Add64(acc, 0, carry)
// buf = buf[1:] is skipped because it's unused and nogo will complain.
}
// Reduce the checksum to 16 bits and undo byte swaps before returning.
acc16 := bswapIfLittleEndian16(reduce(acc))
if bswapped {
acc16 = bits.ReverseBytes16(acc16)
}
return acc16, odd
}
func reduce(acc uint64) uint16 {
// Ideally we would do:
// return uint16(acc>>48) +' uint16(acc>>32) +' uint16(acc>>16) +' uint16(acc)
// for more instruction-level parallelism; however, there is no
// bits.Add16().
acc = (acc >> 32) + (acc & 0xffff_ffff) // at most 0x1_ffff_fffe
acc32 := uint32(acc>>32 + acc) // at most 0xffff_ffff
acc32 = (acc32 >> 16) + (acc32 & 0xffff) // at most 0x1_fffe
return uint16(acc32>>16 + acc32) // at most 0xffff
}
func bswapIfLittleEndian32(val uint32) uint32 {
return binary.BigEndian.Uint32((*[4]byte)(unsafe.Pointer(&val))[:])
}
func bswapIfLittleEndian16(val uint16) uint16 {
return binary.BigEndian.Uint16((*[2]byte)(unsafe.Pointer(&val))[:])
}
func bswapIfBigEndian16(val uint16) uint16 {
return binary.LittleEndian.Uint16((*[2]byte)(unsafe.Pointer(&val))[:])
}
func sliceAddr(buf []byte) uintptr {
return uintptr(unsafe.Pointer(unsafe.SliceData(buf)))
}

View File

@@ -1,46 +0,0 @@
// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tcpip
import (
"fmt"
)
// Error represents an error in the netstack error space.
//
// The error interface is intentionally omitted to avoid loss of type
// information that would occur if these errors were passed as error.
type Error interface {
isError()
// IgnoreStats indicates whether this error should be included in failure
// counts in tcpip.Stats structs.
IgnoreStats() bool
fmt.Stringer
}
// ErrBadAddress indicates a bad address was provided.
//
// +stateify savable
type ErrBadAddress struct{}
func (*ErrBadAddress) isError() {}
// IgnoreStats implements Error.
func (*ErrBadAddress) IgnoreStats() bool {
return false
}
func (*ErrBadAddress) String() string { return "bad address" }

View File

@@ -1,107 +0,0 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package header provides the implementation of the encoding and decoding of
// network protocol headers.
package header
import (
"encoding/binary"
"fmt"
"github.com/sagernet/sing-tun/internal/gtcpip"
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
)
// PseudoHeaderChecksum calculates the pseudo-header checksum for the given
// destination protocol and network address. Pseudo-headers are needed by
// transport layers when calculating their own checksum.
func PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, srcAddr []byte, dstAddr []byte, totalLen uint16) uint16 {
xsum := checksum.Checksum(srcAddr, 0)
xsum = checksum.Checksum(dstAddr, xsum)
// Add the length portion of the checksum to the pseudo-checksum.
var tmp [2]byte
binary.BigEndian.PutUint16(tmp[:], totalLen)
xsum = checksum.Checksum(tmp[:], xsum)
return checksum.Checksum([]byte{0, uint8(protocol)}, xsum)
}
// checksumUpdate2ByteAlignedUint16 updates a uint16 value in a calculated
// checksum.
//
// The value MUST begin at a 2-byte boundary in the original buffer.
func checksumUpdate2ByteAlignedUint16(xsum, old, new uint16) uint16 {
// As per RFC 1071 page 4,
// (4) Incremental Update
//
// ...
//
// To update the checksum, simply add the differences of the
// sixteen bit integers that have been changed. To see why this
// works, observe that every 16-bit integer has an additive inverse
// and that addition is associative. From this it follows that
// given the original value m, the new value m', and the old
// checksum C, the new checksum C' is:
//
// C' = C + (-m) + m' = C + (m' - m)
if old == new {
return xsum
}
return checksum.Combine(xsum, checksum.Combine(new, ^old))
}
// checksumUpdate2ByteAlignedAddress updates an address in a calculated
// checksum.
//
// The addresses must have the same length and must contain an even number
// of bytes. The address MUST begin at a 2-byte boundary in the original buffer.
func checksumUpdate2ByteAlignedAddress(xsum uint16, old, new tcpip.Address) uint16 {
const uint16Bytes = 2
if old.BitLen() != new.BitLen() {
panic(fmt.Sprintf("buffer lengths are different; old = %d, new = %d", old.BitLen()/8, new.BitLen()/8))
}
if oldBytes := old.BitLen() % 16; oldBytes != 0 {
panic(fmt.Sprintf("buffer has an odd number of bytes; got = %d", oldBytes))
}
oldAddr := old.AsSlice()
newAddr := new.AsSlice()
// As per RFC 1071 page 4,
// (4) Incremental Update
//
// ...
//
// To update the checksum, simply add the differences of the
// sixteen bit integers that have been changed. To see why this
// works, observe that every 16-bit integer has an additive inverse
// and that addition is associative. From this it follows that
// given the original value m, the new value m', and the old
// checksum C, the new checksum C' is:
//
// C' = C + (-m) + m' = C + (m' - m)
for len(oldAddr) != 0 {
// Convert the 2 byte sequences to uint16 values then apply the increment
// update.
xsum = checksumUpdate2ByteAlignedUint16(xsum, (uint16(oldAddr[0])<<8)+uint16(oldAddr[1]), (uint16(newAddr[0])<<8)+uint16(newAddr[1]))
oldAddr = oldAddr[uint16Bytes:]
newAddr = newAddr[uint16Bytes:]
}
return xsum
}

View File

@@ -1,192 +0,0 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"encoding/binary"
"github.com/sagernet/sing-tun/internal/gtcpip"
)
const (
dstMAC = 0
srcMAC = 6
ethType = 12
)
// EthernetFields contains the fields of an ethernet frame header. It is used to
// describe the fields of a frame that needs to be encoded.
type EthernetFields struct {
// SrcAddr is the "MAC source" field of an ethernet frame header.
SrcAddr tcpip.LinkAddress
// DstAddr is the "MAC destination" field of an ethernet frame header.
DstAddr tcpip.LinkAddress
// Type is the "ethertype" field of an ethernet frame header.
Type tcpip.NetworkProtocolNumber
}
// Ethernet represents an ethernet frame header stored in a byte array.
type Ethernet []byte
const (
// EthernetMinimumSize is the minimum size of a valid ethernet frame.
EthernetMinimumSize = 14
// EthernetMaximumSize is the maximum size of a valid ethernet frame.
EthernetMaximumSize = 18
// EthernetAddressSize is the size, in bytes, of an ethernet address.
EthernetAddressSize = 6
// UnspecifiedEthernetAddress is the unspecified ethernet address
// (all bits set to 0).
UnspecifiedEthernetAddress = tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00")
// EthernetBroadcastAddress is an ethernet address that addresses every node
// on a local link.
EthernetBroadcastAddress = tcpip.LinkAddress("\xff\xff\xff\xff\xff\xff")
// unicastMulticastFlagMask is the mask of the least significant bit in
// the first octet (in network byte order) of an ethernet address that
// determines whether the ethernet address is a unicast or multicast. If
// the masked bit is a 1, then the address is a multicast, unicast
// otherwise.
//
// See the IEEE Std 802-2001 document for more details. Specifically,
// section 9.2.1 of http://ieee802.org/secmail/pdfocSP2xXA6d.pdf:
// "A 48-bit universal address consists of two parts. The first 24 bits
// correspond to the OUI as assigned by the IEEE, expect that the
// assignee may set the LSB of the first octet to 1 for group addresses
// or set it to 0 for individual addresses."
unicastMulticastFlagMask = 1
// unicastMulticastFlagByteIdx is the byte that holds the
// unicast/multicast flag. See unicastMulticastFlagMask.
unicastMulticastFlagByteIdx = 0
)
const (
// EthernetProtocolAll is a catch-all for all protocols carried inside
// an ethernet frame. It is mainly used to create packet sockets that
// capture all traffic.
EthernetProtocolAll tcpip.NetworkProtocolNumber = 0x0003
// EthernetProtocolPUP is the PARC Universal Packet protocol ethertype.
EthernetProtocolPUP tcpip.NetworkProtocolNumber = 0x0200
)
// Ethertypes holds the protocol numbers describing the payload of an ethernet
// frame. These types aren't necessarily supported by netstack, but can be used
// to catch all traffic of a type via packet endpoints.
var Ethertypes = []tcpip.NetworkProtocolNumber{
EthernetProtocolAll,
EthernetProtocolPUP,
}
// SourceAddress returns the "MAC source" field of the ethernet frame header.
func (b Ethernet) SourceAddress() tcpip.LinkAddress {
return tcpip.LinkAddress(b[srcMAC:][:EthernetAddressSize])
}
// DestinationAddress returns the "MAC destination" field of the ethernet frame
// header.
func (b Ethernet) DestinationAddress() tcpip.LinkAddress {
return tcpip.LinkAddress(b[dstMAC:][:EthernetAddressSize])
}
// Type returns the "ethertype" field of the ethernet frame header.
func (b Ethernet) Type() tcpip.NetworkProtocolNumber {
return tcpip.NetworkProtocolNumber(binary.BigEndian.Uint16(b[ethType:]))
}
// Encode encodes all the fields of the ethernet frame header.
func (b Ethernet) Encode(e *EthernetFields) {
binary.BigEndian.PutUint16(b[ethType:], uint16(e.Type))
copy(b[srcMAC:][:EthernetAddressSize], e.SrcAddr)
copy(b[dstMAC:][:EthernetAddressSize], e.DstAddr)
}
// IsMulticastEthernetAddress returns true if the address is a multicast
// ethernet address.
func IsMulticastEthernetAddress(addr tcpip.LinkAddress) bool {
if len(addr) != EthernetAddressSize {
return false
}
return addr[unicastMulticastFlagByteIdx]&unicastMulticastFlagMask != 0
}
// IsValidUnicastEthernetAddress returns true if the address is a unicast
// ethernet address.
func IsValidUnicastEthernetAddress(addr tcpip.LinkAddress) bool {
if len(addr) != EthernetAddressSize {
return false
}
if addr == UnspecifiedEthernetAddress {
return false
}
if addr[unicastMulticastFlagByteIdx]&unicastMulticastFlagMask != 0 {
return false
}
return true
}
// EthernetAddressFromMulticastIPv4Address returns a multicast Ethernet address
// for a multicast IPv4 address.
//
// addr MUST be a multicast IPv4 address.
func EthernetAddressFromMulticastIPv4Address(addr tcpip.Address) tcpip.LinkAddress {
var linkAddrBytes [EthernetAddressSize]byte
// RFC 1112 Host Extensions for IP Multicasting
//
// 6.4. Extensions to an Ethernet Local Network Module:
//
// An IP host group address is mapped to an Ethernet multicast
// address by placing the low-order 23-bits of the IP address
// into the low-order 23 bits of the Ethernet multicast address
// 01-00-5E-00-00-00 (hex).
addrBytes := addr.As4()
linkAddrBytes[0] = 0x1
linkAddrBytes[2] = 0x5e
linkAddrBytes[3] = addrBytes[1] & 0x7F
copy(linkAddrBytes[4:], addrBytes[IPv4AddressSize-2:])
return tcpip.LinkAddress(linkAddrBytes[:])
}
// EthernetAddressFromMulticastIPv6Address returns a multicast Ethernet address
// for a multicast IPv6 address.
//
// addr MUST be a multicast IPv6 address.
func EthernetAddressFromMulticastIPv6Address(addr tcpip.Address) tcpip.LinkAddress {
// RFC 2464 Transmission of IPv6 Packets over Ethernet Networks
//
// 7. Address Mapping -- Multicast
//
// An IPv6 packet with a multicast destination address DST,
// consisting of the sixteen octets DST[1] through DST[16], is
// transmitted to the Ethernet multicast address whose first
// two octets are the value 3333 hexadecimal and whose last
// four octets are the last four octets of DST.
addrBytes := addr.As16()
linkAddrBytes := []byte(addrBytes[IPv6AddressSize-EthernetAddressSize:])
linkAddrBytes[0] = 0x33
linkAddrBytes[1] = 0x33
return tcpip.LinkAddress(linkAddrBytes[:])
}

View File

@@ -1,228 +0,0 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"encoding/binary"
"github.com/sagernet/sing-tun/internal/gtcpip"
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
)
// ICMPv4 represents an ICMPv4 header stored in a byte array.
type ICMPv4 []byte
const (
// ICMPv4PayloadOffset defines the start of ICMP payload.
ICMPv4PayloadOffset = 8
// ICMPv4MinimumSize is the minimum size of a valid ICMP packet.
ICMPv4MinimumSize = 8
// ICMPv4MinimumErrorPayloadSize Is the smallest number of bytes of an
// errant packet's transport layer that an ICMP error type packet should
// attempt to send as per RFC 792 (see each type) and RFC 1122
// section 3.2.2 which states:
// Every ICMP error message includes the Internet header and at
// least the first 8 data octets of the datagram that triggered
// the error; more than 8 octets MAY be sent; this header and data
// MUST be unchanged from the received datagram.
//
// RFC 792 shows:
// 0 1 2 3
// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
// | Type | Code | Checksum |
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
// | unused |
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
// | Internet Header + 64 bits of Original Data Datagram |
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
ICMPv4MinimumErrorPayloadSize = 8
// ICMPv4ProtocolNumber is the ICMP transport protocol number.
ICMPv4ProtocolNumber tcpip.TransportProtocolNumber = 1
// icmpv4ChecksumOffset is the offset of the checksum field
// in an ICMPv4 message.
icmpv4ChecksumOffset = 2
// icmpv4MTUOffset is the offset of the MTU field
// in an ICMPv4FragmentationNeeded message.
icmpv4MTUOffset = 6
// icmpv4IdentOffset is the offset of the ident field
// in an ICMPv4EchoRequest/Reply message.
icmpv4IdentOffset = 4
// icmpv4PointerOffset is the offset of the pointer field
// in an ICMPv4ParamProblem message.
icmpv4PointerOffset = 4
// icmpv4SequenceOffset is the offset of the sequence field
// in an ICMPv4EchoRequest/Reply message.
icmpv4SequenceOffset = 6
)
// ICMPv4Type is the ICMP type field described in RFC 792.
type ICMPv4Type byte
// ICMPv4Code is the ICMP code field described in RFC 792.
type ICMPv4Code byte
// Typical values of ICMPv4Type defined in RFC 792.
const (
ICMPv4EchoReply ICMPv4Type = 0
ICMPv4DstUnreachable ICMPv4Type = 3
ICMPv4SrcQuench ICMPv4Type = 4
ICMPv4Redirect ICMPv4Type = 5
ICMPv4Echo ICMPv4Type = 8
ICMPv4TimeExceeded ICMPv4Type = 11
ICMPv4ParamProblem ICMPv4Type = 12
ICMPv4Timestamp ICMPv4Type = 13
ICMPv4TimestampReply ICMPv4Type = 14
ICMPv4InfoRequest ICMPv4Type = 15
ICMPv4InfoReply ICMPv4Type = 16
)
// ICMP codes for ICMPv4 Time Exceeded messages as defined in RFC 792.
const (
ICMPv4TTLExceeded ICMPv4Code = 0
ICMPv4ReassemblyTimeout ICMPv4Code = 1
)
// ICMP codes for ICMPv4 Destination Unreachable messages as defined in RFC 792,
// RFC 1122 section 3.2.2.1 and RFC 1812 section 5.2.7.1.
const (
ICMPv4NetUnreachable ICMPv4Code = 0
ICMPv4HostUnreachable ICMPv4Code = 1
ICMPv4ProtoUnreachable ICMPv4Code = 2
ICMPv4PortUnreachable ICMPv4Code = 3
ICMPv4FragmentationNeeded ICMPv4Code = 4
ICMPv4SourceRouteFailed ICMPv4Code = 5
ICMPv4DestinationNetworkUnknown ICMPv4Code = 6
ICMPv4DestinationHostUnknown ICMPv4Code = 7
ICMPv4SourceHostIsolated ICMPv4Code = 8
ICMPv4NetProhibited ICMPv4Code = 9
ICMPv4HostProhibited ICMPv4Code = 10
ICMPv4NetUnreachableForTos ICMPv4Code = 11
ICMPv4HostUnreachableForTos ICMPv4Code = 12
ICMPv4AdminProhibited ICMPv4Code = 13
ICMPv4HostPrecedenceViolation ICMPv4Code = 14
ICMPv4PrecedenceCutInEffect ICMPv4Code = 15
)
// ICMPv4UnusedCode is a code to use in ICMP messages where no code is needed.
const ICMPv4UnusedCode ICMPv4Code = 0
// Type is the ICMP type field.
func (b ICMPv4) Type() ICMPv4Type { return ICMPv4Type(b[0]) }
// SetType sets the ICMP type field.
func (b ICMPv4) SetType(t ICMPv4Type) { b[0] = byte(t) }
// Code is the ICMP code field. Its meaning depends on the value of Type.
func (b ICMPv4) Code() ICMPv4Code { return ICMPv4Code(b[1]) }
// SetCode sets the ICMP code field.
func (b ICMPv4) SetCode(c ICMPv4Code) { b[1] = byte(c) }
// Pointer returns the pointer field in a Parameter Problem packet.
func (b ICMPv4) Pointer() byte { return b[icmpv4PointerOffset] }
// SetPointer sets the pointer field in a Parameter Problem packet.
func (b ICMPv4) SetPointer(c byte) { b[icmpv4PointerOffset] = c }
// Checksum is the ICMP checksum field.
func (b ICMPv4) Checksum() uint16 {
return binary.BigEndian.Uint16(b[icmpv4ChecksumOffset:])
}
// SetChecksum sets the ICMP checksum field.
func (b ICMPv4) SetChecksum(cs uint16) {
checksum.Put(b[icmpv4ChecksumOffset:], cs)
}
// SourcePort implements Transport.SourcePort.
func (ICMPv4) SourcePort() uint16 {
return 0
}
// DestinationPort implements Transport.DestinationPort.
func (ICMPv4) DestinationPort() uint16 {
return 0
}
// SetSourcePort implements Transport.SetSourcePort.
func (ICMPv4) SetSourcePort(uint16) {
}
// SetDestinationPort implements Transport.SetDestinationPort.
func (ICMPv4) SetDestinationPort(uint16) {
}
// Payload implements Transport.Payload.
func (b ICMPv4) Payload() []byte {
return b[ICMPv4PayloadOffset:]
}
// MTU retrieves the MTU field from an ICMPv4 message.
func (b ICMPv4) MTU() uint16 {
return binary.BigEndian.Uint16(b[icmpv4MTUOffset:])
}
// SetMTU sets the MTU field from an ICMPv4 message.
func (b ICMPv4) SetMTU(mtu uint16) {
binary.BigEndian.PutUint16(b[icmpv4MTUOffset:], mtu)
}
// Ident retrieves the Ident field from an ICMPv4 message.
func (b ICMPv4) Ident() uint16 {
return binary.BigEndian.Uint16(b[icmpv4IdentOffset:])
}
// SetIdent sets the Ident field from an ICMPv4 message.
func (b ICMPv4) SetIdent(ident uint16) {
binary.BigEndian.PutUint16(b[icmpv4IdentOffset:], ident)
}
// SetIdentWithChecksumUpdate sets the Ident field and updates the checksum.
func (b ICMPv4) SetIdentWithChecksumUpdate(new uint16) {
old := b.Ident()
b.SetIdent(new)
b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
}
// Sequence retrieves the Sequence field from an ICMPv4 message.
func (b ICMPv4) Sequence() uint16 {
return binary.BigEndian.Uint16(b[icmpv4SequenceOffset:])
}
// SetSequence sets the Sequence field from an ICMPv4 message.
func (b ICMPv4) SetSequence(sequence uint16) {
binary.BigEndian.PutUint16(b[icmpv4SequenceOffset:], sequence)
}
// ICMPv4Checksum calculates the ICMP checksum over the provided ICMP header,
// and payload.
func ICMPv4Checksum(h ICMPv4, payloadCsum uint16) uint16 {
xsum := payloadCsum
// h[2:4] is the checksum itself, skip it to avoid checksumming the checksum.
xsum = checksum.Checksum(h[:2], xsum)
xsum = checksum.Checksum(h[4:], xsum)
return ^xsum
}

View File

@@ -1,304 +0,0 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"encoding/binary"
"github.com/sagernet/sing-tun/internal/gtcpip"
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
)
// ICMPv6 represents an ICMPv6 header stored in a byte array.
type ICMPv6 []byte
const (
// ICMPv6HeaderSize is the size of the ICMPv6 header. That is, the
// sum of the size of the ICMPv6 Type, Code and Checksum fields, as
// per RFC 4443 section 2.1. After the ICMPv6 header, the ICMPv6
// message body begins.
ICMPv6HeaderSize = 4
// ICMPv6MinimumSize is the minimum size of a valid ICMP packet.
ICMPv6MinimumSize = 8
// ICMPv6PayloadOffset is the offset of the payload in an
// ICMP packet.
ICMPv6PayloadOffset = 8
// ICMPv6ProtocolNumber is the ICMP transport protocol number.
ICMPv6ProtocolNumber tcpip.TransportProtocolNumber = 58
// ICMPv6NeighborSolicitMinimumSize is the minimum size of a
// neighbor solicitation packet.
ICMPv6NeighborSolicitMinimumSize = ICMPv6HeaderSize + NDPNSMinimumSize
// ICMPv6NeighborAdvertMinimumSize is the minimum size of a
// neighbor advertisement packet.
ICMPv6NeighborAdvertMinimumSize = ICMPv6HeaderSize + NDPNAMinimumSize
// ICMPv6EchoMinimumSize is the minimum size of a valid echo packet.
ICMPv6EchoMinimumSize = 8
// ICMPv6ErrorHeaderSize is the size of an ICMP error packet header,
// as per RFC 4443, Appendix A, item 4 and the errata.
// ... all ICMP error messages shall have exactly
// 32 bits of type-specific data, so that receivers can reliably find
// the embedded invoking packet even when they don't recognize the
// ICMP message Type.
ICMPv6ErrorHeaderSize = 8
// ICMPv6DstUnreachableMinimumSize is the minimum size of a valid ICMP
// destination unreachable packet.
ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize
// ICMPv6PacketTooBigMinimumSize is the minimum size of a valid ICMP
// packet-too-big packet.
ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize
// ICMPv6ChecksumOffset is the offset of the checksum field
// in an ICMPv6 message.
ICMPv6ChecksumOffset = 2
// icmpv6PointerOffset is the offset of the pointer
// in an ICMPv6 Parameter problem message.
icmpv6PointerOffset = 4
// icmpv6MTUOffset is the offset of the MTU field in an ICMPv6
// PacketTooBig message.
icmpv6MTUOffset = 4
// icmpv6IdentOffset is the offset of the ident field
// in a ICMPv6 Echo Request/Reply message.
icmpv6IdentOffset = 4
// icmpv6SequenceOffset is the offset of the sequence field
// in a ICMPv6 Echo Request/Reply message.
icmpv6SequenceOffset = 6
// NDPHopLimit is the expected IP hop limit value of 255 for received
// NDP packets, as per RFC 4861 sections 4.1 - 4.5, 6.1.1, 6.1.2, 7.1.1,
// 7.1.2 and 8.1. If the hop limit value is not 255, nodes MUST silently
// drop the NDP packet. All outgoing NDP packets must use this value for
// its IP hop limit field.
NDPHopLimit = 255
)
// ICMPv6Type is the ICMP type field described in RFC 4443.
type ICMPv6Type byte
// Values for use in the Type field of ICMPv6 packet from RFC 4433.
const (
ICMPv6DstUnreachable ICMPv6Type = 1
ICMPv6PacketTooBig ICMPv6Type = 2
ICMPv6TimeExceeded ICMPv6Type = 3
ICMPv6ParamProblem ICMPv6Type = 4
ICMPv6EchoRequest ICMPv6Type = 128
ICMPv6EchoReply ICMPv6Type = 129
// Neighbor Discovery Protocol (NDP) messages, see RFC 4861.
ICMPv6RouterSolicit ICMPv6Type = 133
ICMPv6RouterAdvert ICMPv6Type = 134
ICMPv6NeighborSolicit ICMPv6Type = 135
ICMPv6NeighborAdvert ICMPv6Type = 136
ICMPv6RedirectMsg ICMPv6Type = 137
// Multicast Listener Discovery (MLD) messages, see RFC 2710.
ICMPv6MulticastListenerQuery ICMPv6Type = 130
ICMPv6MulticastListenerReport ICMPv6Type = 131
ICMPv6MulticastListenerDone ICMPv6Type = 132
// Multicast Listener Discovert Version 2 (MLDv2) messages, see RFC 3810.
ICMPv6MulticastListenerV2Report ICMPv6Type = 143
)
// IsErrorType returns true if the receiver is an ICMP error type.
func (typ ICMPv6Type) IsErrorType() bool {
// Per RFC 4443 section 2.1:
// ICMPv6 messages are grouped into two classes: error messages and
// informational messages. Error messages are identified as such by a
// zero in the high-order bit of their message Type field values. Thus,
// error messages have message types from 0 to 127; informational
// messages have message types from 128 to 255.
return typ&0x80 == 0
}
// ICMPv6Code is the ICMP Code field described in RFC 4443.
type ICMPv6Code byte
// ICMP codes used with Destination Unreachable (Type 1). As per RFC 4443
// section 3.1.
const (
ICMPv6NetworkUnreachable ICMPv6Code = 0
ICMPv6Prohibited ICMPv6Code = 1
ICMPv6BeyondScope ICMPv6Code = 2
ICMPv6AddressUnreachable ICMPv6Code = 3
ICMPv6PortUnreachable ICMPv6Code = 4
ICMPv6Policy ICMPv6Code = 5
ICMPv6RejectRoute ICMPv6Code = 6
)
// ICMP codes used with Time Exceeded (Type 3). As per RFC 4443 section 3.3.
const (
ICMPv6HopLimitExceeded ICMPv6Code = 0
ICMPv6ReassemblyTimeout ICMPv6Code = 1
)
// ICMP codes used with Parameter Problem (Type 4). As per RFC 4443 section 3.4.
const (
// ICMPv6ErroneousHeader indicates an erroneous header field was encountered.
ICMPv6ErroneousHeader ICMPv6Code = 0
// ICMPv6UnknownHeader indicates an unrecognized Next Header type encountered.
ICMPv6UnknownHeader ICMPv6Code = 1
// ICMPv6UnknownOption indicates an unrecognized IPv6 option was encountered.
ICMPv6UnknownOption ICMPv6Code = 2
)
// ICMPv6UnusedCode is the code value used with ICMPv6 messages which don't use
// the code field. (Types not mentioned above.)
const ICMPv6UnusedCode ICMPv6Code = 0
// Type is the ICMP type field.
func (b ICMPv6) Type() ICMPv6Type { return ICMPv6Type(b[0]) }
// SetType sets the ICMP type field.
func (b ICMPv6) SetType(t ICMPv6Type) { b[0] = byte(t) }
// Code is the ICMP code field. Its meaning depends on the value of Type.
func (b ICMPv6) Code() ICMPv6Code { return ICMPv6Code(b[1]) }
// SetCode sets the ICMP code field.
func (b ICMPv6) SetCode(c ICMPv6Code) { b[1] = byte(c) }
// TypeSpecific returns the type specific data field.
func (b ICMPv6) TypeSpecific() uint32 {
return binary.BigEndian.Uint32(b[icmpv6PointerOffset:])
}
// SetTypeSpecific sets the type specific data field.
func (b ICMPv6) SetTypeSpecific(val uint32) {
binary.BigEndian.PutUint32(b[icmpv6PointerOffset:], val)
}
// Checksum is the ICMP checksum field.
func (b ICMPv6) Checksum() uint16 {
return binary.BigEndian.Uint16(b[ICMPv6ChecksumOffset:])
}
// SetChecksum sets the ICMP checksum field.
func (b ICMPv6) SetChecksum(cs uint16) {
checksum.Put(b[ICMPv6ChecksumOffset:], cs)
}
// SourcePort implements Transport.SourcePort.
func (ICMPv6) SourcePort() uint16 {
return 0
}
// DestinationPort implements Transport.DestinationPort.
func (ICMPv6) DestinationPort() uint16 {
return 0
}
// SetSourcePort implements Transport.SetSourcePort.
func (ICMPv6) SetSourcePort(uint16) {
}
// SetDestinationPort implements Transport.SetDestinationPort.
func (ICMPv6) SetDestinationPort(uint16) {
}
// MTU retrieves the MTU field from an ICMPv6 message.
func (b ICMPv6) MTU() uint32 {
return binary.BigEndian.Uint32(b[icmpv6MTUOffset:])
}
// SetMTU sets the MTU field from an ICMPv6 message.
func (b ICMPv6) SetMTU(mtu uint32) {
binary.BigEndian.PutUint32(b[icmpv6MTUOffset:], mtu)
}
// Ident retrieves the Ident field from an ICMPv6 message.
func (b ICMPv6) Ident() uint16 {
return binary.BigEndian.Uint16(b[icmpv6IdentOffset:])
}
// SetIdent sets the Ident field from an ICMPv6 message.
func (b ICMPv6) SetIdent(ident uint16) {
binary.BigEndian.PutUint16(b[icmpv6IdentOffset:], ident)
}
// SetIdentWithChecksumUpdate sets the Ident field and updates the checksum.
func (b ICMPv6) SetIdentWithChecksumUpdate(new uint16) {
old := b.Ident()
b.SetIdent(new)
b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
}
// Sequence retrieves the Sequence field from an ICMPv6 message.
func (b ICMPv6) Sequence() uint16 {
return binary.BigEndian.Uint16(b[icmpv6SequenceOffset:])
}
// SetSequence sets the Sequence field from an ICMPv6 message.
func (b ICMPv6) SetSequence(sequence uint16) {
binary.BigEndian.PutUint16(b[icmpv6SequenceOffset:], sequence)
}
// MessageBody returns the message body as defined by RFC 4443 section 2.1; the
// portion of the ICMPv6 buffer after the first ICMPv6HeaderSize bytes.
func (b ICMPv6) MessageBody() []byte {
return b[ICMPv6HeaderSize:]
}
// Payload implements Transport.Payload.
func (b ICMPv6) Payload() []byte {
return b[ICMPv6PayloadOffset:]
}
// ICMPv6ChecksumParams contains parameters to calculate ICMPv6 checksum.
type ICMPv6ChecksumParams struct {
Header ICMPv6
Src tcpip.Address
Dst tcpip.Address
PayloadCsum uint16
PayloadLen int
}
// ICMPv6Checksum calculates the ICMP checksum over the provided ICMPv6 header,
// IPv6 src/dst addresses and the payload.
func ICMPv6Checksum(params ICMPv6ChecksumParams) uint16 {
h := params.Header
xsum := PseudoHeaderChecksum(ICMPv6ProtocolNumber, params.Src.AsSlice(), params.Dst.AsSlice(), uint16(len(h)+params.PayloadLen))
xsum = checksum.Combine(xsum, params.PayloadCsum)
// h[2:4] is the checksum itself, skip it to avoid checksumming the checksum.
xsum = checksum.Checksum(h[:2], xsum)
xsum = checksum.Checksum(h[4:], xsum)
return ^xsum
}
// UpdateChecksumPseudoHeaderAddress updates the checksum to reflect an
// updated address in the pseudo header.
func (b ICMPv6) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address) {
b.SetChecksum(^checksumUpdate2ByteAlignedAddress(^b.Checksum(), old, new))
}

View File

@@ -1,136 +0,0 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"net/netip"
tcpip "github.com/sagernet/sing-tun/internal/gtcpip"
)
const (
// MaxIPPacketSize is the maximum supported IP packet size, excluding
// jumbograms. The maximum IPv4 packet size is 64k-1 (total size must fit
// in 16 bits). For IPv6, the payload max size (excluding jumbograms) is
// 64k-1 (also needs to fit in 16 bits). So we use 64k - 1 + 2 * m, where
// m is the minimum IPv6 header size; we leave room for some potential
// IP options.
MaxIPPacketSize = 0xffff + 2*IPv6MinimumSize
)
// Transport offers generic methods to query and/or update the fields of the
// header of a transport protocol buffer.
type Transport interface {
// SourcePort returns the value of the "source port" field.
SourcePort() uint16
// Destination returns the value of the "destination port" field.
DestinationPort() uint16
// Checksum returns the value of the "checksum" field.
Checksum() uint16
// SetSourcePort sets the value of the "source port" field.
SetSourcePort(uint16)
// SetDestinationPort sets the value of the "destination port" field.
SetDestinationPort(uint16)
// SetChecksum sets the value of the "checksum" field.
SetChecksum(uint16)
// Payload returns the data carried in the transport buffer.
Payload() []byte
}
// ChecksummableTransport is a Transport that supports checksumming.
type ChecksummableTransport interface {
Transport
// SetSourcePortWithChecksumUpdate sets the source port and updates
// the checksum.
//
// The receiver's checksum must be a fully calculated checksum.
SetSourcePortWithChecksumUpdate(port uint16)
// SetDestinationPortWithChecksumUpdate sets the destination port and updates
// the checksum.
//
// The receiver's checksum must be a fully calculated checksum.
SetDestinationPortWithChecksumUpdate(port uint16)
// UpdateChecksumPseudoHeaderAddress updates the checksum to reflect an
// updated address in the pseudo header.
//
// If fullChecksum is true, the receiver's checksum field is assumed to hold a
// fully calculated checksum. Otherwise, it is assumed to hold a partially
// calculated checksum which only reflects the pseudo header.
UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool)
}
// Network offers generic methods to query and/or update the fields of the
// header of a network protocol buffer.
type Network interface {
// SourceAddress returns the value of the "source address" field.
SourceAddress() tcpip.Address
// DestinationAddress returns the value of the "destination address"
// field.
DestinationAddress() tcpip.Address
DestinationAddr() netip.Addr
// Checksum returns the value of the "checksum" field.
Checksum() uint16
// SetSourceAddress sets the value of the "source address" field.
SetSourceAddress(tcpip.Address)
// SetDestinationAddress sets the value of the "destination address"
// field.
SetDestinationAddress(tcpip.Address)
SetDestinationAddr(addr netip.Addr)
// SetChecksum sets the value of the "checksum" field.
SetChecksum(uint16)
// TransportProtocol returns the number of the transport protocol
// stored in the payload.
TransportProtocol() tcpip.TransportProtocolNumber
// Payload returns a byte slice containing the payload of the network
// packet.
Payload() []byte
// TOS returns the values of the "type of service" and "flow label" fields.
TOS() (uint8, uint32)
// SetTOS sets the values of the "type of service" and "flow label" fields.
SetTOS(t uint8, l uint32)
}
// ChecksummableNetwork is a Network that supports checksumming.
type ChecksummableNetwork interface {
Network
// SetSourceAddressAndChecksum sets the source address and updates the
// checksum to reflect the new address.
SetSourceAddressWithChecksumUpdate(tcpip.Address)
// SetDestinationAddressAndChecksum sets the destination address and
// updates the checksum to reflect the new address.
SetDestinationAddressWithChecksumUpdate(tcpip.Address)
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,578 +0,0 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"crypto/sha256"
"encoding/binary"
"fmt"
"net/netip"
"github.com/sagernet/sing-tun/internal/gtcpip"
)
const (
versTCFL = 0
// IPv6PayloadLenOffset is the offset of the PayloadLength field in
// IPv6 header.
IPv6PayloadLenOffset = 4
// IPv6NextHeaderOffset is the offset of the NextHeader field in
// IPv6 header.
IPv6NextHeaderOffset = 6
hopLimit = 7
v6SrcAddr = 8
v6DstAddr = v6SrcAddr + IPv6AddressSize
// IPv6FixedHeaderSize is the size of the fixed header.
IPv6FixedHeaderSize = v6DstAddr + IPv6AddressSize
)
// IPv6Fields contains the fields of an IPv6 packet. It is used to describe the
// fields of a packet that needs to be encoded.
type IPv6Fields struct {
// TrafficClass is the "traffic class" field of an IPv6 packet.
TrafficClass uint8
// FlowLabel is the "flow label" field of an IPv6 packet.
FlowLabel uint32
// PayloadLength is the "payload length" field of an IPv6 packet, including
// the length of all extension headers.
PayloadLength uint16
// TransportProtocol is the transport layer protocol number. Serialized in the
// last "next header" field of the IPv6 header + extension headers.
TransportProtocol tcpip.TransportProtocolNumber
// HopLimit is the "Hop Limit" field of an IPv6 packet.
HopLimit uint8
// SrcAddr is the "source ip address" of an IPv6 packet.
SrcAddr netip.Addr
// DstAddr is the "destination ip address" of an IPv6 packet.
DstAddr netip.Addr
// ExtensionHeaders are the extension headers following the IPv6 header.
ExtensionHeaders IPv6ExtHdrSerializer
}
// IPv6 represents an ipv6 header stored in a byte array.
// Most of the methods of IPv6 access to the underlying slice without
// checking the boundaries and could panic because of 'index out of range'.
// Always call IsValid() to validate an instance of IPv6 before using other methods.
type IPv6 []byte
const (
// IPv6MinimumSize is the minimum size of a valid IPv6 packet.
IPv6MinimumSize = IPv6FixedHeaderSize
// IPv6AddressSize is the size, in bytes, of an IPv6 address.
IPv6AddressSize = 16
// IPv6AddressSizeBits is the size, in bits, of an IPv6 address.
IPv6AddressSizeBits = 128
// IPv6MaximumPayloadSize is the maximum size of a valid IPv6 payload per
// RFC 8200 Section 4.5.
IPv6MaximumPayloadSize = 65535
// IPv6ProtocolNumber is IPv6's network protocol number.
IPv6ProtocolNumber tcpip.NetworkProtocolNumber = 0x86dd
// IPv6Version is the version of the ipv6 protocol.
IPv6Version = 6
// IIDSize is the size of an interface identifier (IID), in bytes, as
// defined by RFC 4291 section 2.5.1.
IIDSize = 8
// IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 8200,
// section 5:
// IPv6 requires that every link in the Internet have an MTU of 1280 octets
// or greater. This is known as the IPv6 minimum link MTU.
IPv6MinimumMTU = 1280
// IIDOffsetInIPv6Address is the offset, in bytes, from the start
// of an IPv6 address to the beginning of the interface identifier
// (IID) for auto-generated addresses. That is, all bytes before
// the IIDOffsetInIPv6Address-th byte are the prefix bytes, and all
// bytes including and after the IIDOffsetInIPv6Address-th byte are
// for the IID.
IIDOffsetInIPv6Address = 8
// OpaqueIIDSecretKeyMinBytes is the recommended minimum number of bytes
// for the secret key used to generate an opaque interface identifier as
// outlined by RFC 7217.
OpaqueIIDSecretKeyMinBytes = 16
// ipv6MulticastAddressScopeByteIdx is the byte where the scope (scop) field
// is located within a multicast IPv6 address, as per RFC 4291 section 2.7.
ipv6MulticastAddressScopeByteIdx = 1
// ipv6MulticastAddressScopeMask is the mask for the scope (scop) field,
// within the byte holding the field, as per RFC 4291 section 2.7.
ipv6MulticastAddressScopeMask = 0xF
)
var (
// IPv6AllNodesMulticastAddress is a link-local multicast group that
// all IPv6 nodes MUST join, as per RFC 4291, section 2.8. Packets
// destined to this address will reach all nodes on a link.
//
// The address is ff02::1.
IPv6AllNodesMulticastAddress = tcpip.AddrFrom16([16]byte{0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01})
// IPv6AllRoutersInterfaceLocalMulticastAddress is an interface-local
// multicast group that all IPv6 routers MUST join, as per RFC 4291, section
// 2.8. Packets destined to this address will reach the router on an
// interface.
//
// The address is ff01::2.
IPv6AllRoutersInterfaceLocalMulticastAddress = tcpip.AddrFrom16([16]byte{0xff, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02})
// IPv6AllRoutersLinkLocalMulticastAddress is a link-local multicast group
// that all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets
// destined to this address will reach all routers on a link.
//
// The address is ff02::2.
IPv6AllRoutersLinkLocalMulticastAddress = tcpip.AddrFrom16([16]byte{0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02})
// IPv6AllRoutersSiteLocalMulticastAddress is a site-local multicast group
// that all IPv6 routers MUST join, as per RFC 4291, section 2.8. Packets
// destined to this address will reach all routers in a site.
//
// The address is ff05::2.
IPv6AllRoutersSiteLocalMulticastAddress = tcpip.AddrFrom16([16]byte{0xff, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02})
// IPv6Loopback is the IPv6 Loopback address.
IPv6Loopback = tcpip.AddrFrom16([16]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01})
// IPv6Any is the non-routable IPv6 "any" meta address. It is also
// known as the unspecified address.
IPv6Any = tcpip.AddrFrom16([16]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})
)
// IPv6EmptySubnet is the empty IPv6 subnet. It may also be known as the
// catch-all or wildcard subnet. That is, all IPv6 addresses are considered to
// be contained within this subnet.
var IPv6EmptySubnet = tcpip.AddressWithPrefix{
Address: IPv6Any,
PrefixLen: 0,
}.Subnet()
// IPv4MappedIPv6Subnet is the prefix for an IPv4 mapped IPv6 address as defined
// by RFC 4291 section 2.5.5.
var IPv4MappedIPv6Subnet = tcpip.AddressWithPrefix{
Address: tcpip.AddrFrom16([16]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00}),
PrefixLen: 96,
}.Subnet()
// IPv6LinkLocalPrefix is the prefix for IPv6 link-local addresses, as defined
// by RFC 4291 section 2.5.6.
//
// The prefix is fe80::/64
var IPv6LinkLocalPrefix = tcpip.AddressWithPrefix{
Address: tcpip.AddrFrom16([16]byte{0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}),
PrefixLen: 64,
}
// PayloadLength returns the value of the "payload length" field of the ipv6
// header.
func (b IPv6) PayloadLength() uint16 {
return binary.BigEndian.Uint16(b[IPv6PayloadLenOffset:])
}
// HopLimit returns the value of the "Hop Limit" field of the ipv6 header.
func (b IPv6) HopLimit() uint8 {
return b[hopLimit]
}
// NextHeader returns the value of the "next header" field of the ipv6 header.
func (b IPv6) NextHeader() uint8 {
return b[IPv6NextHeaderOffset]
}
// TransportProtocol implements Network.TransportProtocol.
func (b IPv6) TransportProtocol() tcpip.TransportProtocolNumber {
return tcpip.TransportProtocolNumber(b.NextHeader())
}
// Payload implements Network.Payload.
func (b IPv6) Payload() []byte {
return b[IPv6MinimumSize:][:b.PayloadLength()]
}
// SourceAddress returns the "source address" field of the ipv6 header.
func (b IPv6) SourceAddress() tcpip.Address {
return tcpip.AddrFrom16([16]byte(b[v6SrcAddr:][:IPv6AddressSize]))
}
// DestinationAddress returns the "destination address" field of the ipv6
// header.
func (b IPv6) DestinationAddress() tcpip.Address {
return tcpip.AddrFrom16([16]byte(b[v6DstAddr:][:IPv6AddressSize]))
}
// SourceAddressSlice returns the "source address" field of the ipv6 header as a
// byte slice.
func (b IPv6) SourceAddressSlice() []byte {
return []byte(b[v6SrcAddr:][:IPv6AddressSize])
}
// DestinationAddressSlice returns the "destination address" field of the ipv6
// header as a byte slice.
func (b IPv6) DestinationAddressSlice() []byte {
return []byte(b[v6DstAddr:][:IPv6AddressSize])
}
// Checksum implements Network.Checksum. Given that IPv6 doesn't have a
// checksum, it just returns 0.
func (IPv6) Checksum() uint16 {
return 0
}
// TOS returns the "traffic class" and "flow label" fields of the ipv6 header.
func (b IPv6) TOS() (uint8, uint32) {
v := binary.BigEndian.Uint32(b[versTCFL:])
return uint8(v >> 20), v & 0xfffff
}
// SetTOS sets the "traffic class" and "flow label" fields of the ipv6 header.
func (b IPv6) SetTOS(t uint8, l uint32) {
vtf := (6 << 28) | (uint32(t) << 20) | (l & 0xfffff)
binary.BigEndian.PutUint32(b[versTCFL:], vtf)
}
// SetPayloadLength sets the "payload length" field of the ipv6 header.
func (b IPv6) SetPayloadLength(payloadLength uint16) {
binary.BigEndian.PutUint16(b[IPv6PayloadLenOffset:], payloadLength)
}
// SetSourceAddress sets the "source address" field of the ipv6 header.
func (b IPv6) SetSourceAddress(addr tcpip.Address) {
copy(b[v6SrcAddr:][:IPv6AddressSize], addr.AsSlice())
}
// SetDestinationAddress sets the "destination address" field of the ipv6
// header.
func (b IPv6) SetDestinationAddress(addr tcpip.Address) {
copy(b[v6DstAddr:][:IPv6AddressSize], addr.AsSlice())
}
// SetHopLimit sets the value of the "Hop Limit" field.
func (b IPv6) SetHopLimit(v uint8) {
b[hopLimit] = v
}
// SetNextHeader sets the value of the "next header" field of the ipv6 header.
func (b IPv6) SetNextHeader(v uint8) {
b[IPv6NextHeaderOffset] = v
}
// SetChecksum implements Network.SetChecksum. Given that IPv6 doesn't have a
// checksum, it is empty.
func (IPv6) SetChecksum(uint16) {
}
// Encode encodes all the fields of the ipv6 header.
func (b IPv6) Encode(i *IPv6Fields) {
extHdr := b[IPv6MinimumSize:]
b.SetTOS(i.TrafficClass, i.FlowLabel)
b.SetPayloadLength(i.PayloadLength)
b[hopLimit] = i.HopLimit
b.SetSourceAddr(i.SrcAddr)
b.SetDestinationAddr(i.DstAddr)
nextHeader, _ := i.ExtensionHeaders.Serialize(i.TransportProtocol, extHdr)
b[IPv6NextHeaderOffset] = nextHeader
}
// IsValid performs basic validation on the packet.
func (b IPv6) IsValid(pktSize int) bool {
if len(b) < IPv6MinimumSize {
return false
}
dlen := int(b.PayloadLength())
if dlen > pktSize-IPv6MinimumSize {
return false
}
if IPVersion(b) != IPv6Version {
return false
}
return true
}
// IsV4MappedAddress determines if the provided address is an IPv4 mapped
// address by checking if its prefix is 0:0:0:0:0:ffff::/96.
func IsV4MappedAddress(addr tcpip.Address) bool {
if addr.BitLen() != IPv6AddressSizeBits {
return false
}
return IPv4MappedIPv6Subnet.Contains(addr)
}
// IsV6MulticastAddress determines if the provided address is an IPv6
// multicast address (anything starting with FF).
func IsV6MulticastAddress(addr tcpip.Address) bool {
if addr.BitLen() != IPv6AddressSizeBits {
return false
}
return addr.As16()[0] == 0xff
}
// IsV6UnicastAddress determines if the provided address is a valid IPv6
// unicast (and specified) address. That is, IsV6UnicastAddress returns
// true if addr contains IPv6AddressSize bytes, is not the unspecified
// address and is not a multicast address.
func IsV6UnicastAddress(addr tcpip.Address) bool {
if addr.BitLen() != IPv6AddressSizeBits {
return false
}
// Must not be unspecified
if addr == IPv6Any {
return false
}
// Return if not a multicast.
return addr.As16()[0] != 0xff
}
var solicitedNodeMulticastPrefix = [13]byte{0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0xff}
// SolicitedNodeAddr computes the solicited-node multicast address. This is
// used for NDP. Described in RFC 4291. The argument must be a full-length IPv6
// address.
func SolicitedNodeAddr(addr tcpip.Address) tcpip.Address {
addrBytes := addr.As16()
return tcpip.AddrFrom16([16]byte(append(solicitedNodeMulticastPrefix[:], addrBytes[len(addrBytes)-3:]...)))
}
// IsSolicitedNodeAddr determines whether the address is a solicited-node
// multicast address.
func IsSolicitedNodeAddr(addr tcpip.Address) bool {
addrBytes := addr.As16()
return solicitedNodeMulticastPrefix == [13]byte(addrBytes[:len(addrBytes)-3])
}
// EthernetAdddressToModifiedEUI64IntoBuf populates buf with a modified EUI-64
// from a 48-bit Ethernet/MAC address, as per RFC 4291 section 2.5.1.
//
// buf MUST be at least 8 bytes.
func EthernetAdddressToModifiedEUI64IntoBuf(linkAddr tcpip.LinkAddress, buf []byte) {
buf[0] = linkAddr[0] ^ 2
buf[1] = linkAddr[1]
buf[2] = linkAddr[2]
buf[3] = 0xFF
buf[4] = 0xFE
buf[5] = linkAddr[3]
buf[6] = linkAddr[4]
buf[7] = linkAddr[5]
}
// EthernetAddressToModifiedEUI64 computes a modified EUI-64 from a 48-bit
// Ethernet/MAC address, as per RFC 4291 section 2.5.1.
func EthernetAddressToModifiedEUI64(linkAddr tcpip.LinkAddress) [IIDSize]byte {
var buf [IIDSize]byte
EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, buf[:])
return buf
}
// LinkLocalAddr computes the default IPv6 link-local address from a link-layer
// (MAC) address.
func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address {
// Convert a 48-bit MAC to a modified EUI-64 and then prepend the
// link-local header, FE80::.
//
// The conversion is very nearly:
// aa:bb:cc:dd:ee:ff => FE80::Aabb:ccFF:FEdd:eeff
// Note the capital A. The conversion aa->Aa involves a bit flip.
lladdrb := [IPv6AddressSize]byte{
0: 0xFE,
1: 0x80,
}
EthernetAdddressToModifiedEUI64IntoBuf(linkAddr, lladdrb[IIDOffsetInIPv6Address:])
return tcpip.AddrFrom16(lladdrb)
}
// IsV6LinkLocalUnicastAddress returns true iff the provided address is an IPv6
// link-local unicast address, as defined by RFC 4291 section 2.5.6.
func IsV6LinkLocalUnicastAddress(addr tcpip.Address) bool {
if addr.BitLen() != IPv6AddressSizeBits {
return false
}
addrBytes := addr.As16()
return addrBytes[0] == 0xfe && (addrBytes[1]&0xc0) == 0x80
}
// IsV6LoopbackAddress returns true iff the provided address is an IPv6 loopback
// address, as defined by RFC 4291 section 2.5.3.
func IsV6LoopbackAddress(addr tcpip.Address) bool {
return addr == IPv6Loopback
}
// IsV6LinkLocalMulticastAddress returns true iff the provided address is an
// IPv6 link-local multicast address, as defined by RFC 4291 section 2.7.
func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool {
return IsV6MulticastAddress(addr) && V6MulticastScope(addr) == IPv6LinkLocalMulticastScope
}
// AppendOpaqueInterfaceIdentifier appends a 64 bit opaque interface identifier
// (IID) to buf as outlined by RFC 7217 and returns the extended buffer.
//
// The opaque IID is generated from the cryptographic hash of the concatenation
// of the prefix, NIC's name, DAD counter (DAD retry counter) and the secret
// key. The secret key SHOULD be at least OpaqueIIDSecretKeyMinBytes bytes and
// MUST be generated to a pseudo-random number. See RFC 4086 for randomness
// requirements for security.
//
// If buf has enough capacity for the IID (IIDSize bytes), a new underlying
// array for the buffer will not be allocated.
func AppendOpaqueInterfaceIdentifier(buf []byte, prefix tcpip.Subnet, nicName string, dadCounter uint8, secretKey []byte) []byte {
// As per RFC 7217 section 5, the opaque identifier can be generated as a
// cryptographic hash of the concatenation of each of the function parameters.
// Note, we omit the optional Network_ID field.
h := sha256.New()
// h.Write never returns an error.
prefixID := prefix.ID()
h.Write([]byte(prefixID.AsSlice()[:IIDOffsetInIPv6Address]))
h.Write([]byte(nicName))
h.Write([]byte{dadCounter})
h.Write(secretKey)
var sumBuf [sha256.Size]byte
sum := h.Sum(sumBuf[:0])
return append(buf, sum[:IIDSize]...)
}
// LinkLocalAddrWithOpaqueIID computes the default IPv6 link-local address with
// an opaque IID.
func LinkLocalAddrWithOpaqueIID(nicName string, dadCounter uint8, secretKey []byte) tcpip.Address {
lladdrb := [IPv6AddressSize]byte{
0: 0xFE,
1: 0x80,
}
return tcpip.AddrFrom16([16]byte(AppendOpaqueInterfaceIdentifier(lladdrb[:IIDOffsetInIPv6Address], IPv6LinkLocalPrefix.Subnet(), nicName, dadCounter, secretKey)))
}
// IPv6AddressScope is the scope of an IPv6 address.
type IPv6AddressScope int
const (
// LinkLocalScope indicates a link-local address.
LinkLocalScope IPv6AddressScope = iota
// GlobalScope indicates a global address.
GlobalScope
)
// ScopeForIPv6Address returns the scope for an IPv6 address.
func ScopeForIPv6Address(addr tcpip.Address) (IPv6AddressScope, tcpip.Error) {
if addr.BitLen() != IPv6AddressSizeBits {
return GlobalScope, &tcpip.ErrBadAddress{}
}
switch {
case IsV6LinkLocalMulticastAddress(addr):
return LinkLocalScope, nil
case IsV6LinkLocalUnicastAddress(addr):
return LinkLocalScope, nil
default:
return GlobalScope, nil
}
}
// GenerateTempIPv6SLAACAddr generates a temporary SLAAC IPv6 address for an
// associated stable/permanent SLAAC address.
//
// GenerateTempIPv6SLAACAddr will update the temporary IID history value to be
// used when generating a new temporary IID.
//
// Panics if tempIIDHistory is not at least IIDSize bytes.
func GenerateTempIPv6SLAACAddr(tempIIDHistory []byte, stableAddr tcpip.Address) tcpip.AddressWithPrefix {
addrBytes := stableAddr.As16()
h := sha256.New()
h.Write(tempIIDHistory)
h.Write(addrBytes[IIDOffsetInIPv6Address:])
var sumBuf [sha256.Size]byte
sum := h.Sum(sumBuf[:0])
// The rightmost 64 bits of sum are saved for the next iteration.
if n := copy(tempIIDHistory, sum[sha256.Size-IIDSize:]); n != IIDSize {
panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IIDSize))
}
// The leftmost 64 bits of sum is used as the IID.
if n := copy(addrBytes[IIDOffsetInIPv6Address:], sum); n != IIDSize {
panic(fmt.Sprintf("copied %d IID bytes, expected %d bytes", n, IIDSize))
}
return tcpip.AddressWithPrefix{
Address: tcpip.AddrFrom16(addrBytes),
PrefixLen: IIDOffsetInIPv6Address * 8,
}
}
// IPv6MulticastScope is the scope of a multicast IPv6 address, as defined by
// RFC 7346 section 2.
type IPv6MulticastScope uint8
// The various values for IPv6 multicast scopes, as per RFC 7346 section 2:
//
// +------+--------------------------+-------------------------+
// | scop | NAME | REFERENCE |
// +------+--------------------------+-------------------------+
// | 0 | Reserved | [RFC4291], RFC 7346 |
// | 1 | Interface-Local scope | [RFC4291], RFC 7346 |
// | 2 | Link-Local scope | [RFC4291], RFC 7346 |
// | 3 | Realm-Local scope | [RFC4291], RFC 7346 |
// | 4 | Admin-Local scope | [RFC4291], RFC 7346 |
// | 5 | Site-Local scope | [RFC4291], RFC 7346 |
// | 6 | Unassigned | |
// | 7 | Unassigned | |
// | 8 | Organization-Local scope | [RFC4291], RFC 7346 |
// | 9 | Unassigned | |
// | A | Unassigned | |
// | B | Unassigned | |
// | C | Unassigned | |
// | D | Unassigned | |
// | E | Global scope | [RFC4291], RFC 7346 |
// | F | Reserved | [RFC4291], RFC 7346 |
// +------+--------------------------+-------------------------+
const (
IPv6Reserved0MulticastScope = IPv6MulticastScope(0x0)
IPv6InterfaceLocalMulticastScope = IPv6MulticastScope(0x1)
IPv6LinkLocalMulticastScope = IPv6MulticastScope(0x2)
IPv6RealmLocalMulticastScope = IPv6MulticastScope(0x3)
IPv6AdminLocalMulticastScope = IPv6MulticastScope(0x4)
IPv6SiteLocalMulticastScope = IPv6MulticastScope(0x5)
IPv6OrganizationLocalMulticastScope = IPv6MulticastScope(0x8)
IPv6GlobalMulticastScope = IPv6MulticastScope(0xE)
IPv6ReservedFMulticastScope = IPv6MulticastScope(0xF)
)
// V6MulticastScope returns the scope of a multicast address.
func V6MulticastScope(addr tcpip.Address) IPv6MulticastScope {
addrBytes := addr.As16()
return IPv6MulticastScope(addrBytes[ipv6MulticastAddressScopeByteIdx] & ipv6MulticastAddressScopeMask)
}

View File

@@ -1,507 +0,0 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"encoding/binary"
"errors"
"fmt"
"math"
"github.com/sagernet/sing-tun/internal/gtcpip"
"github.com/sagernet/sing/common"
)
// IPv6ExtensionHeaderIdentifier is an IPv6 extension header identifier.
type IPv6ExtensionHeaderIdentifier uint8
const (
// IPv6HopByHopOptionsExtHdrIdentifier is the header identifier of a Hop by
// Hop Options extension header, as per RFC 8200 section 4.3.
IPv6HopByHopOptionsExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 0
// IPv6RoutingExtHdrIdentifier is the header identifier of a Routing extension
// header, as per RFC 8200 section 4.4.
IPv6RoutingExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 43
// IPv6FragmentExtHdrIdentifier is the header identifier of a Fragment
// extension header, as per RFC 8200 section 4.5.
IPv6FragmentExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 44
// IPv6DestinationOptionsExtHdrIdentifier is the header identifier of a
// Destination Options extension header, as per RFC 8200 section 4.6.
IPv6DestinationOptionsExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 60
// IPv6NoNextHeaderIdentifier is the header identifier used to signify the end
// of an IPv6 payload, as per RFC 8200 section 4.7.
IPv6NoNextHeaderIdentifier IPv6ExtensionHeaderIdentifier = 59
// IPv6UnknownExtHdrIdentifier is reserved by IANA.
// https://www.iana.org/assignments/ipv6-parameters/ipv6-parameters.xhtml#extension-header
// "254 Use for experimentation and testing [RFC3692][RFC4727]"
IPv6UnknownExtHdrIdentifier IPv6ExtensionHeaderIdentifier = 254
)
const (
// ipv6UnknownExtHdrOptionActionMask is the mask of the action to take when
// a node encounters an unrecognized option.
ipv6UnknownExtHdrOptionActionMask = 192
// ipv6UnknownExtHdrOptionActionShift is the least significant bits to discard
// from the action value for an unrecognized option identifier.
ipv6UnknownExtHdrOptionActionShift = 6
// ipv6RoutingExtHdrSegmentsLeftIdx is the index to the Segments Left field
// within an IPv6RoutingExtHdr.
ipv6RoutingExtHdrSegmentsLeftIdx = 1
// IPv6FragmentExtHdrLength is the length of an IPv6 extension header, in
// bytes.
IPv6FragmentExtHdrLength = 8
// ipv6FragmentExtHdrFragmentOffsetOffset is the offset to the start of the
// Fragment Offset field within an IPv6FragmentExtHdr.
ipv6FragmentExtHdrFragmentOffsetOffset = 0
// ipv6FragmentExtHdrFragmentOffsetShift is the bit offset of the Fragment
// Offset field within an IPv6FragmentExtHdr.
ipv6FragmentExtHdrFragmentOffsetShift = 3
// ipv6FragmentExtHdrFlagsIdx is the index to the flags field within an
// IPv6FragmentExtHdr.
ipv6FragmentExtHdrFlagsIdx = 1
// ipv6FragmentExtHdrMFlagMask is the mask of the More (M) flag within the
// flags field of an IPv6FragmentExtHdr.
ipv6FragmentExtHdrMFlagMask = 1
// ipv6FragmentExtHdrIdentificationOffset is the offset to the Identification
// field within an IPv6FragmentExtHdr.
ipv6FragmentExtHdrIdentificationOffset = 2
// ipv6ExtHdrLenBytesPerUnit is the unit size of an extension header's length
// field. That is, given a Length field of 2, the extension header expects
// 16 bytes following the first 8 bytes (see ipv6ExtHdrLenBytesExcluded for
// details about the first 8 bytes' exclusion from the Length field).
ipv6ExtHdrLenBytesPerUnit = 8
// ipv6ExtHdrLenBytesExcluded is the number of bytes excluded from an
// extension header's Length field following the Length field.
//
// The Length field excludes the first 8 bytes, but the Next Header and Length
// field take up the first 2 of the 8 bytes so we expect (at minimum) 6 bytes
// after the Length field.
//
// This ensures that every extension header is at least 8 bytes.
ipv6ExtHdrLenBytesExcluded = 6
// IPv6FragmentExtHdrFragmentOffsetBytesPerUnit is the unit size of a Fragment
// extension header's Fragment Offset field. That is, given a Fragment Offset
// of 2, the extension header is indicating that the fragment's payload
// starts at the 16th byte in the reassembled packet.
IPv6FragmentExtHdrFragmentOffsetBytesPerUnit = 8
)
// padIPv6OptionsLength returns the total length for IPv6 options of length l
// considering the 8-octet alignment as stated in RFC 8200 Section 4.2.
func padIPv6OptionsLength(length int) int {
return (length + ipv6ExtHdrLenBytesPerUnit - 1) & ^(ipv6ExtHdrLenBytesPerUnit - 1)
}
// padIPv6Option fills b with the appropriate padding options depending on its
// length.
func padIPv6Option(b []byte) {
switch len(b) {
case 0: // No padding needed.
case 1: // Pad with Pad1.
b[ipv6ExtHdrOptionTypeOffset] = uint8(ipv6Pad1ExtHdrOptionIdentifier)
default: // Pad with PadN.
s := b[ipv6ExtHdrOptionPayloadOffset:]
common.ClearArray(s)
b[ipv6ExtHdrOptionTypeOffset] = uint8(ipv6PadNExtHdrOptionIdentifier)
b[ipv6ExtHdrOptionLengthOffset] = uint8(len(s))
}
}
// ipv6OptionsAlignmentPadding returns the number of padding bytes needed to
// serialize an option at headerOffset with alignment requirements
// [align]n + alignOffset.
func ipv6OptionsAlignmentPadding(headerOffset int, align int, alignOffset int) int {
padLen := headerOffset - alignOffset
return ((padLen + align - 1) & ^(align - 1)) - padLen
}
// IPv6OptionUnknownAction is the action that must be taken if the processing
// IPv6 node does not recognize the option, as outlined in RFC 8200 section 4.2.
type IPv6OptionUnknownAction int
const (
// IPv6OptionUnknownActionSkip indicates that the unrecognized option must
// be skipped and the node should continue processing the header.
IPv6OptionUnknownActionSkip IPv6OptionUnknownAction = 0
// IPv6OptionUnknownActionDiscard indicates that the packet must be silently
// discarded.
IPv6OptionUnknownActionDiscard IPv6OptionUnknownAction = 1
// IPv6OptionUnknownActionDiscardSendICMP indicates that the packet must be
// discarded and the node must send an ICMP Parameter Problem, Code 2, message
// to the packet's source, regardless of whether or not the packet's
// Destination was a multicast address.
IPv6OptionUnknownActionDiscardSendICMP IPv6OptionUnknownAction = 2
// IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest indicates that the
// packet must be discarded and the node must send an ICMP Parameter Problem,
// Code 2, message to the packet's source only if the packet's Destination was
// not a multicast address.
IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest IPv6OptionUnknownAction = 3
)
// IPv6ExtHdrOption is implemented by the various IPv6 extension header options.
type IPv6ExtHdrOption interface {
// UnknownAction returns the action to take in response to an unrecognized
// option.
UnknownAction() IPv6OptionUnknownAction
// isIPv6ExtHdrOption is used to "lock" this interface so it is not
// implemented by other packages.
isIPv6ExtHdrOption()
}
// IPv6ExtHdrOptionIdentifier is an IPv6 extension header option identifier.
type IPv6ExtHdrOptionIdentifier uint8
const (
// ipv6Pad1ExtHdrOptionIdentifier is the identifier for a padding option that
// provides 1 byte padding, as outlined in RFC 8200 section 4.2.
ipv6Pad1ExtHdrOptionIdentifier IPv6ExtHdrOptionIdentifier = 0
// ipv6PadNExtHdrOptionIdentifier is the identifier for a padding option that
// provides variable length byte padding, as outlined in RFC 8200 section 4.2.
ipv6PadNExtHdrOptionIdentifier IPv6ExtHdrOptionIdentifier = 1
// ipv6RouterAlertHopByHopOptionIdentifier is the identifier for the Router
// Alert Hop by Hop option as defined in RFC 2711 section 2.1.
ipv6RouterAlertHopByHopOptionIdentifier IPv6ExtHdrOptionIdentifier = 5
// ipv6ExtHdrOptionTypeOffset is the option type offset in an extension header
// option as defined in RFC 8200 section 4.2.
ipv6ExtHdrOptionTypeOffset = 0
// ipv6ExtHdrOptionLengthOffset is the option length offset in an extension
// header option as defined in RFC 8200 section 4.2.
ipv6ExtHdrOptionLengthOffset = 1
// ipv6ExtHdrOptionPayloadOffset is the option payload offset in an extension
// header option as defined in RFC 8200 section 4.2.
ipv6ExtHdrOptionPayloadOffset = 2
)
// ipv6UnknownActionFromIdentifier maps an extension header option's
// identifier's high bits to the action to take when the identifier is unknown.
func ipv6UnknownActionFromIdentifier(id IPv6ExtHdrOptionIdentifier) IPv6OptionUnknownAction {
return IPv6OptionUnknownAction((id & ipv6UnknownExtHdrOptionActionMask) >> ipv6UnknownExtHdrOptionActionShift)
}
// ErrMalformedIPv6ExtHdrOption indicates that an IPv6 extension header option
// is malformed.
var ErrMalformedIPv6ExtHdrOption = errors.New("malformed IPv6 extension header option")
// IPv6FragmentExtHdr is a buffer holding the Fragment extension header specific
// data as outlined in RFC 8200 section 4.5.
//
// Note, the buffer does not include the Next Header and Reserved fields.
type IPv6FragmentExtHdr [6]byte
// isIPv6PayloadHeader implements IPv6PayloadHeader.isIPv6PayloadHeader.
func (IPv6FragmentExtHdr) isIPv6PayloadHeader() {}
// Release implements IPv6PayloadHeader.Release.
func (IPv6FragmentExtHdr) Release() {}
// FragmentOffset returns the Fragment Offset field.
//
// This value indicates where the buffer following the Fragment extension header
// starts in the target (reassembled) packet.
func (b IPv6FragmentExtHdr) FragmentOffset() uint16 {
return binary.BigEndian.Uint16(b[ipv6FragmentExtHdrFragmentOffsetOffset:]) >> ipv6FragmentExtHdrFragmentOffsetShift
}
// More returns the More (M) flag.
//
// This indicates whether any fragments are expected to succeed b.
func (b IPv6FragmentExtHdr) More() bool {
return b[ipv6FragmentExtHdrFlagsIdx]&ipv6FragmentExtHdrMFlagMask != 0
}
// ID returns the Identification field.
//
// This value is used to uniquely identify the packet, between a
// source and destination.
func (b IPv6FragmentExtHdr) ID() uint32 {
return binary.BigEndian.Uint32(b[ipv6FragmentExtHdrIdentificationOffset:])
}
// IsAtomic returns whether the fragment header indicates an atomic fragment. An
// atomic fragment is a fragment that contains all the data required to
// reassemble a full packet.
func (b IPv6FragmentExtHdr) IsAtomic() bool {
return !b.More() && b.FragmentOffset() == 0
}
// IPv6SerializableExtHdr provides serialization for IPv6 extension
// headers.
type IPv6SerializableExtHdr interface {
// identifier returns the assigned IPv6 header identifier for this extension
// header.
identifier() IPv6ExtensionHeaderIdentifier
// length returns the total serialized length in bytes of this extension
// header, including the common next header and length fields.
length() int
// serializeInto serializes the receiver into the provided byte
// buffer and with the provided nextHeader value.
//
// Note, the caller MUST provide a byte buffer with size of at least
// length. Implementers of this function may assume that the byte buffer
// is of sufficient size. serializeInto MAY panic if the provided byte
// buffer is not of sufficient size.
//
// serializeInto returns the number of bytes that was used to serialize the
// receiver. Implementers must only use the number of bytes required to
// serialize the receiver. Callers MAY provide a larger buffer than required
// to serialize into.
serializeInto(nextHeader uint8, b []byte) int
}
var _ IPv6SerializableExtHdr = (*IPv6SerializableHopByHopExtHdr)(nil)
// IPv6SerializableHopByHopExtHdr implements serialization of the Hop by Hop
// options extension header.
type IPv6SerializableHopByHopExtHdr []IPv6SerializableHopByHopOption
const (
// ipv6HopByHopExtHdrNextHeaderOffset is the offset of the next header field
// in a hop by hop extension header as defined in RFC 8200 section 4.3.
ipv6HopByHopExtHdrNextHeaderOffset = 0
// ipv6HopByHopExtHdrLengthOffset is the offset of the length field in a hop
// by hop extension header as defined in RFC 8200 section 4.3.
ipv6HopByHopExtHdrLengthOffset = 1
// ipv6HopByHopExtHdrPayloadOffset is the offset of the options in a hop by
// hop extension header as defined in RFC 8200 section 4.3.
ipv6HopByHopExtHdrOptionsOffset = 2
// ipv6HopByHopExtHdrUnaccountedLenWords is the implicit number of 8-octet
// words in a hop by hop extension header's length field, as stated in RFC
// 8200 section 4.3:
// Length of the Hop-by-Hop Options header in 8-octet units,
// not including the first 8 octets.
ipv6HopByHopExtHdrUnaccountedLenWords = 1
)
// identifier implements IPv6SerializableExtHdr.
func (IPv6SerializableHopByHopExtHdr) identifier() IPv6ExtensionHeaderIdentifier {
return IPv6HopByHopOptionsExtHdrIdentifier
}
// length implements IPv6SerializableExtHdr.
func (h IPv6SerializableHopByHopExtHdr) length() int {
var total int
for _, opt := range h {
align, alignOffset := opt.alignment()
total += ipv6OptionsAlignmentPadding(total, align, alignOffset)
total += ipv6ExtHdrOptionPayloadOffset + int(opt.length())
}
// Account for next header and total length fields and add padding.
return padIPv6OptionsLength(ipv6HopByHopExtHdrOptionsOffset + total)
}
// serializeInto implements IPv6SerializableExtHdr.
func (h IPv6SerializableHopByHopExtHdr) serializeInto(nextHeader uint8, b []byte) int {
optBuffer := b[ipv6HopByHopExtHdrOptionsOffset:]
totalLength := ipv6HopByHopExtHdrOptionsOffset
for _, opt := range h {
// Calculate alignment requirements and pad buffer if necessary.
align, alignOffset := opt.alignment()
padLen := ipv6OptionsAlignmentPadding(totalLength, align, alignOffset)
if padLen != 0 {
padIPv6Option(optBuffer[:padLen])
totalLength += padLen
optBuffer = optBuffer[padLen:]
}
l := opt.serializeInto(optBuffer[ipv6ExtHdrOptionPayloadOffset:])
optBuffer[ipv6ExtHdrOptionTypeOffset] = uint8(opt.identifier())
optBuffer[ipv6ExtHdrOptionLengthOffset] = l
l += ipv6ExtHdrOptionPayloadOffset
totalLength += int(l)
optBuffer = optBuffer[l:]
}
padded := padIPv6OptionsLength(totalLength)
if padded != totalLength {
padIPv6Option(optBuffer[:padded-totalLength])
totalLength = padded
}
wordsLen := totalLength/ipv6ExtHdrLenBytesPerUnit - ipv6HopByHopExtHdrUnaccountedLenWords
if wordsLen > math.MaxUint8 {
panic(fmt.Sprintf("IPv6 hop by hop options too large: %d+1 64-bit words", wordsLen))
}
b[ipv6HopByHopExtHdrNextHeaderOffset] = nextHeader
b[ipv6HopByHopExtHdrLengthOffset] = uint8(wordsLen)
return totalLength
}
// IPv6SerializableHopByHopOption provides serialization for hop by hop options.
type IPv6SerializableHopByHopOption interface {
// identifier returns the option identifier of this Hop by Hop option.
identifier() IPv6ExtHdrOptionIdentifier
// length returns the *payload* size of the option (not considering the type
// and length fields).
length() uint8
// alignment returns the alignment requirements from this option.
//
// Alignment requirements take the form [align]n + offset as specified in
// RFC 8200 section 4.2. The alignment requirement is on the offset between
// the option type byte and the start of the hop by hop header.
//
// align must be a power of 2.
alignment() (align int, offset int)
// serializeInto serializes the receiver into the provided byte
// buffer.
//
// Note, the caller MUST provide a byte buffer with size of at least
// length. Implementers of this function may assume that the byte buffer
// is of sufficient size. serializeInto MAY panic if the provided byte
// buffer is not of sufficient size.
//
// serializeInto will return the number of bytes that was used to
// serialize the receiver. Implementers must only use the number of
// bytes required to serialize the receiver. Callers MAY provide a
// larger buffer than required to serialize into.
serializeInto([]byte) uint8
}
var _ IPv6SerializableHopByHopOption = (*IPv6RouterAlertOption)(nil)
// IPv6RouterAlertOption is the IPv6 Router alert Hop by Hop option defined in
// RFC 2711 section 2.1.
type IPv6RouterAlertOption struct {
Value IPv6RouterAlertValue
}
// IPv6RouterAlertValue is the payload of an IPv6 Router Alert option.
type IPv6RouterAlertValue uint16
const (
// IPv6RouterAlertMLD indicates a datagram containing a Multicast Listener
// Discovery message as defined in RFC 2711 section 2.1.
IPv6RouterAlertMLD IPv6RouterAlertValue = 0
// IPv6RouterAlertRSVP indicates a datagram containing an RSVP message as
// defined in RFC 2711 section 2.1.
IPv6RouterAlertRSVP IPv6RouterAlertValue = 1
// IPv6RouterAlertActiveNetworks indicates a datagram containing an Active
// Networks message as defined in RFC 2711 section 2.1.
IPv6RouterAlertActiveNetworks IPv6RouterAlertValue = 2
// ipv6RouterAlertPayloadLength is the length of the Router Alert payload
// as defined in RFC 2711.
ipv6RouterAlertPayloadLength = 2
// ipv6RouterAlertAlignmentRequirement is the alignment requirement for the
// Router Alert option defined as 2n+0 in RFC 2711.
ipv6RouterAlertAlignmentRequirement = 2
// ipv6RouterAlertAlignmentOffsetRequirement is the alignment offset
// requirement for the Router Alert option defined as 2n+0 in RFC 2711 section
// 2.1.
ipv6RouterAlertAlignmentOffsetRequirement = 0
)
// UnknownAction implements IPv6ExtHdrOption.
func (*IPv6RouterAlertOption) UnknownAction() IPv6OptionUnknownAction {
return ipv6UnknownActionFromIdentifier(ipv6RouterAlertHopByHopOptionIdentifier)
}
// isIPv6ExtHdrOption implements IPv6ExtHdrOption.
func (*IPv6RouterAlertOption) isIPv6ExtHdrOption() {}
// identifier implements IPv6SerializableHopByHopOption.
func (*IPv6RouterAlertOption) identifier() IPv6ExtHdrOptionIdentifier {
return ipv6RouterAlertHopByHopOptionIdentifier
}
// length implements IPv6SerializableHopByHopOption.
func (*IPv6RouterAlertOption) length() uint8 {
return ipv6RouterAlertPayloadLength
}
// alignment implements IPv6SerializableHopByHopOption.
func (*IPv6RouterAlertOption) alignment() (int, int) {
// From RFC 2711 section 2.1:
// Alignment requirement: 2n+0.
return ipv6RouterAlertAlignmentRequirement, ipv6RouterAlertAlignmentOffsetRequirement
}
// serializeInto implements IPv6SerializableHopByHopOption.
func (o *IPv6RouterAlertOption) serializeInto(b []byte) uint8 {
binary.BigEndian.PutUint16(b, uint16(o.Value))
return ipv6RouterAlertPayloadLength
}
// IPv6ExtHdrSerializer provides serialization of IPv6 extension headers.
type IPv6ExtHdrSerializer []IPv6SerializableExtHdr
// Serialize serializes the provided list of IPv6 extension headers into b.
//
// Note, b must be of sufficient size to hold all the headers in s. See
// IPv6ExtHdrSerializer.Length for details on the getting the total size of a
// serialized IPv6ExtHdrSerializer.
//
// Serialize may panic if b is not of sufficient size to hold all the options
// in s.
//
// Serialize takes the transportProtocol value to be used as the last extension
// header's Next Header value and returns the header identifier of the first
// serialized extension header and the total serialized length.
func (s IPv6ExtHdrSerializer) Serialize(transportProtocol tcpip.TransportProtocolNumber, b []byte) (uint8, int) {
nextHeader := uint8(transportProtocol)
if len(s) == 0 {
return nextHeader, 0
}
var totalLength int
for i, h := range s[:len(s)-1] {
length := h.serializeInto(uint8(s[i+1].identifier()), b)
b = b[length:]
totalLength += length
}
totalLength += s[len(s)-1].serializeInto(nextHeader, b)
return uint8(s[0].identifier()), totalLength
}
// Length returns the total number of bytes required to serialize the extension
// headers.
func (s IPv6ExtHdrSerializer) Length() int {
var totalLength int
for _, h := range s {
totalLength += h.length()
}
return totalLength
}

View File

@@ -1,158 +0,0 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"encoding/binary"
"github.com/sagernet/sing-tun/internal/gtcpip"
)
const (
nextHdrFrag = 0
fragOff = 2
more = 3
idV6 = 4
)
var _ IPv6SerializableExtHdr = (*IPv6SerializableFragmentExtHdr)(nil)
// IPv6SerializableFragmentExtHdr is used to serialize an IPv6 fragment
// extension header as defined in RFC 8200 section 4.5.
type IPv6SerializableFragmentExtHdr struct {
// FragmentOffset is the "fragment offset" field of an IPv6 fragment.
FragmentOffset uint16
// M is the "more" field of an IPv6 fragment.
M bool
// Identification is the "identification" field of an IPv6 fragment.
Identification uint32
}
// identifier implements IPv6SerializableFragmentExtHdr.
func (h *IPv6SerializableFragmentExtHdr) identifier() IPv6ExtensionHeaderIdentifier {
return IPv6FragmentHeader
}
// length implements IPv6SerializableFragmentExtHdr.
func (h *IPv6SerializableFragmentExtHdr) length() int {
return IPv6FragmentHeaderSize
}
// serializeInto implements IPv6SerializableFragmentExtHdr.
func (h *IPv6SerializableFragmentExtHdr) serializeInto(nextHeader uint8, b []byte) int {
// Prevent too many bounds checks.
_ = b[IPv6FragmentHeaderSize:]
binary.BigEndian.PutUint32(b[idV6:], h.Identification)
binary.BigEndian.PutUint16(b[fragOff:], h.FragmentOffset<<ipv6FragmentExtHdrFragmentOffsetShift)
b[nextHdrFrag] = nextHeader
if h.M {
b[more] |= ipv6FragmentExtHdrMFlagMask
}
return IPv6FragmentHeaderSize
}
// IPv6Fragment represents an ipv6 fragment header stored in a byte array.
// Most of the methods of IPv6Fragment access to the underlying slice without
// checking the boundaries and could panic because of 'index out of range'.
// Always call IsValid() to validate an instance of IPv6Fragment before using other methods.
type IPv6Fragment []byte
const (
// IPv6FragmentHeader header is the number used to specify that the next
// header is a fragment header, per RFC 2460.
IPv6FragmentHeader = 44
// IPv6FragmentHeaderSize is the size of the fragment header.
IPv6FragmentHeaderSize = 8
)
// IsValid performs basic validation on the fragment header.
func (b IPv6Fragment) IsValid() bool {
return len(b) >= IPv6FragmentHeaderSize
}
// NextHeader returns the value of the "next header" field of the ipv6 fragment.
func (b IPv6Fragment) NextHeader() uint8 {
return b[nextHdrFrag]
}
// FragmentOffset returns the "fragment offset" field of the ipv6 fragment.
func (b IPv6Fragment) FragmentOffset() uint16 {
return binary.BigEndian.Uint16(b[fragOff:]) >> 3
}
// More returns the "more" field of the ipv6 fragment.
func (b IPv6Fragment) More() bool {
return b[more]&1 > 0
}
// Payload implements Network.Payload.
func (b IPv6Fragment) Payload() []byte {
return b[IPv6FragmentHeaderSize:]
}
// ID returns the value of the identifier field of the ipv6 fragment.
func (b IPv6Fragment) ID() uint32 {
return binary.BigEndian.Uint32(b[idV6:])
}
// TransportProtocol implements Network.TransportProtocol.
func (b IPv6Fragment) TransportProtocol() tcpip.TransportProtocolNumber {
return tcpip.TransportProtocolNumber(b.NextHeader())
}
// The functions below have been added only to satisfy the Network interface.
// Checksum is not supported by IPv6Fragment.
func (b IPv6Fragment) Checksum() uint16 {
panic("not supported")
}
// SourceAddress is not supported by IPv6Fragment.
func (b IPv6Fragment) SourceAddress() tcpip.Address {
panic("not supported")
}
// DestinationAddress is not supported by IPv6Fragment.
func (b IPv6Fragment) DestinationAddress() tcpip.Address {
panic("not supported")
}
// SetSourceAddress is not supported by IPv6Fragment.
func (b IPv6Fragment) SetSourceAddress(tcpip.Address) {
panic("not supported")
}
// SetDestinationAddress is not supported by IPv6Fragment.
func (b IPv6Fragment) SetDestinationAddress(tcpip.Address) {
panic("not supported")
}
// SetChecksum is not supported by IPv6Fragment.
func (b IPv6Fragment) SetChecksum(uint16) {
panic("not supported")
}
// TOS is not supported by IPv6Fragment.
func (b IPv6Fragment) TOS() (uint8, uint32) {
panic("not supported")
}
// SetTOS is not supported by IPv6Fragment.
func (b IPv6Fragment) SetTOS(t uint8, l uint32) {
panic("not supported")
}

View File

@@ -1,110 +0,0 @@
// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import "github.com/sagernet/sing-tun/internal/gtcpip"
// NDPNeighborAdvert is an NDP Neighbor Advertisement message. It will
// only contain the body of an ICMPv6 packet.
//
// See RFC 4861 section 4.4 for more details.
type NDPNeighborAdvert []byte
const (
// NDPNAMinimumSize is the minimum size of a valid NDP Neighbor
// Advertisement message (body of an ICMPv6 packet).
NDPNAMinimumSize = 20
// ndpNATargetAddressOffset is the start of the Target Address
// field within an NDPNeighborAdvert.
ndpNATargetAddressOffset = 4
// ndpNAOptionsOffset is the start of the NDP options in an
// NDPNeighborAdvert.
ndpNAOptionsOffset = ndpNATargetAddressOffset + IPv6AddressSize
// ndpNAFlagsOffset is the offset of the flags within an
// NDPNeighborAdvert
ndpNAFlagsOffset = 0
// ndpNARouterFlagMask is the mask of the Router Flag field in
// the flags byte within in an NDPNeighborAdvert.
ndpNARouterFlagMask = (1 << 7)
// ndpNASolicitedFlagMask is the mask of the Solicited Flag field in
// the flags byte within in an NDPNeighborAdvert.
ndpNASolicitedFlagMask = (1 << 6)
// ndpNAOverrideFlagMask is the mask of the Override Flag field in
// the flags byte within in an NDPNeighborAdvert.
ndpNAOverrideFlagMask = (1 << 5)
)
// TargetAddress returns the value within the Target Address field.
func (b NDPNeighborAdvert) TargetAddress() tcpip.Address {
return tcpip.AddrFrom16Slice(b[ndpNATargetAddressOffset:][:IPv6AddressSize])
}
// SetTargetAddress sets the value within the Target Address field.
func (b NDPNeighborAdvert) SetTargetAddress(addr tcpip.Address) {
copy(b[ndpNATargetAddressOffset:][:IPv6AddressSize], addr.AsSlice())
}
// RouterFlag returns the value of the Router Flag field.
func (b NDPNeighborAdvert) RouterFlag() bool {
return b[ndpNAFlagsOffset]&ndpNARouterFlagMask != 0
}
// SetRouterFlag sets the value in the Router Flag field.
func (b NDPNeighborAdvert) SetRouterFlag(f bool) {
if f {
b[ndpNAFlagsOffset] |= ndpNARouterFlagMask
} else {
b[ndpNAFlagsOffset] &^= ndpNARouterFlagMask
}
}
// SolicitedFlag returns the value of the Solicited Flag field.
func (b NDPNeighborAdvert) SolicitedFlag() bool {
return b[ndpNAFlagsOffset]&ndpNASolicitedFlagMask != 0
}
// SetSolicitedFlag sets the value in the Solicited Flag field.
func (b NDPNeighborAdvert) SetSolicitedFlag(f bool) {
if f {
b[ndpNAFlagsOffset] |= ndpNASolicitedFlagMask
} else {
b[ndpNAFlagsOffset] &^= ndpNASolicitedFlagMask
}
}
// OverrideFlag returns the value of the Override Flag field.
func (b NDPNeighborAdvert) OverrideFlag() bool {
return b[ndpNAFlagsOffset]&ndpNAOverrideFlagMask != 0
}
// SetOverrideFlag sets the value in the Override Flag field.
func (b NDPNeighborAdvert) SetOverrideFlag(f bool) {
if f {
b[ndpNAFlagsOffset] |= ndpNAOverrideFlagMask
} else {
b[ndpNAFlagsOffset] &^= ndpNAOverrideFlagMask
}
}
// Options returns an NDPOptions of the options body.
func (b NDPNeighborAdvert) Options() NDPOptions {
return NDPOptions(b[ndpNAOptionsOffset:])
}

View File

@@ -1,52 +0,0 @@
// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import "github.com/sagernet/sing-tun/internal/gtcpip"
// NDPNeighborSolicit is an NDP Neighbor Solicitation message. It will only
// contain the body of an ICMPv6 packet.
//
// See RFC 4861 section 4.3 for more details.
type NDPNeighborSolicit []byte
const (
// NDPNSMinimumSize is the minimum size of a valid NDP Neighbor
// Solicitation message (body of an ICMPv6 packet).
NDPNSMinimumSize = 20
// ndpNSTargetAddessOffset is the start of the Target Address
// field within an NDPNeighborSolicit.
ndpNSTargetAddessOffset = 4
// ndpNSOptionsOffset is the start of the NDP options in an
// NDPNeighborSolicit.
ndpNSOptionsOffset = ndpNSTargetAddessOffset + IPv6AddressSize
)
// TargetAddress returns the value within the Target Address field.
func (b NDPNeighborSolicit) TargetAddress() tcpip.Address {
return tcpip.AddrFrom16Slice(b[ndpNSTargetAddessOffset:][:IPv6AddressSize])
}
// SetTargetAddress sets the value within the Target Address field.
func (b NDPNeighborSolicit) SetTargetAddress(addr tcpip.Address) {
copy(b[ndpNSTargetAddessOffset:][:IPv6AddressSize], addr.AsSlice())
}
// Options returns an NDPOptions of the options body.
func (b NDPNeighborSolicit) Options() NDPOptions {
return NDPOptions(b[ndpNSOptionsOffset:])
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,204 +0,0 @@
// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"encoding/binary"
"fmt"
"time"
)
var _ fmt.Stringer = NDPRoutePreference(0)
// NDPRoutePreference is the preference values for default routers or
// more-specific routes.
//
// As per RFC 4191 section 2.1,
//
// Default router preferences and preferences for more-specific routes
// are encoded the same way.
//
// Preference values are encoded as a two-bit signed integer, as
// follows:
//
// 01 High
// 00 Medium (default)
// 11 Low
// 10 Reserved - MUST NOT be sent
//
// Note that implementations can treat the value as a two-bit signed
// integer.
//
// Having just three values reinforces that they are not metrics and
// more values do not appear to be necessary for reasonable scenarios.
type NDPRoutePreference uint8
const (
// HighRoutePreference indicates a high preference, as per
// RFC 4191 section 2.1.
HighRoutePreference NDPRoutePreference = 0b01
// MediumRoutePreference indicates a medium preference, as per
// RFC 4191 section 2.1.
//
// This is the default preference value.
MediumRoutePreference = 0b00
// LowRoutePreference indicates a low preference, as per
// RFC 4191 section 2.1.
LowRoutePreference = 0b11
// ReservedRoutePreference is a reserved preference value, as per
// RFC 4191 section 2.1.
//
// It MUST NOT be sent.
ReservedRoutePreference = 0b10
)
// String implements fmt.Stringer.
func (p NDPRoutePreference) String() string {
switch p {
case HighRoutePreference:
return "HighRoutePreference"
case MediumRoutePreference:
return "MediumRoutePreference"
case LowRoutePreference:
return "LowRoutePreference"
case ReservedRoutePreference:
return "ReservedRoutePreference"
default:
return fmt.Sprintf("NDPRoutePreference(%d)", p)
}
}
// NDPRouterAdvert is an NDP Router Advertisement message. It will only contain
// the body of an ICMPv6 packet.
//
// See RFC 4861 section 4.2 and RFC 4191 section 2.2 for more details.
type NDPRouterAdvert []byte
// As per RFC 4191 section 2.2,
//
// 0 1 2 3
// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
// | Type | Code | Checksum |
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
// | Cur Hop Limit |M|O|H|Prf|Resvd| Router Lifetime |
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
// | Reachable Time |
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
// | Retrans Timer |
// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
// | Options ...
// +-+-+-+-+-+-+-+-+-+-+-+-
const (
// NDPRAMinimumSize is the minimum size of a valid NDP Router
// Advertisement message (body of an ICMPv6 packet).
NDPRAMinimumSize = 12
// ndpRACurrHopLimitOffset is the byte of the Curr Hop Limit field
// within an NDPRouterAdvert.
ndpRACurrHopLimitOffset = 0
// ndpRAFlagsOffset is the byte with the NDP RA bit-fields/flags
// within an NDPRouterAdvert.
ndpRAFlagsOffset = 1
// ndpRAManagedAddrConfFlagMask is the mask of the Managed Address
// Configuration flag within the bit-field/flags byte of an
// NDPRouterAdvert.
ndpRAManagedAddrConfFlagMask = (1 << 7)
// ndpRAOtherConfFlagMask is the mask of the Other Configuration flag
// within the bit-field/flags byte of an NDPRouterAdvert.
ndpRAOtherConfFlagMask = (1 << 6)
// ndpDefaultRouterPreferenceShift is the shift of the Prf (Default Router
// Preference) field within the flags byte of an NDPRouterAdvert.
ndpDefaultRouterPreferenceShift = 3
// ndpDefaultRouterPreferenceMask is the mask of the Prf (Default Router
// Preference) field within the flags byte of an NDPRouterAdvert.
ndpDefaultRouterPreferenceMask = (0b11 << ndpDefaultRouterPreferenceShift)
// ndpRARouterLifetimeOffset is the start of the 2-byte Router Lifetime
// field within an NDPRouterAdvert.
ndpRARouterLifetimeOffset = 2
// ndpRAReachableTimeOffset is the start of the 4-byte Reachable Time
// field within an NDPRouterAdvert.
ndpRAReachableTimeOffset = 4
// ndpRARetransTimerOffset is the start of the 4-byte Retrans Timer
// field within an NDPRouterAdvert.
ndpRARetransTimerOffset = 8
// ndpRAOptionsOffset is the start of the NDP options in an
// NDPRouterAdvert.
ndpRAOptionsOffset = 12
)
// CurrHopLimit returns the value of the Curr Hop Limit field.
func (b NDPRouterAdvert) CurrHopLimit() uint8 {
return b[ndpRACurrHopLimitOffset]
}
// ManagedAddrConfFlag returns the value of the Managed Address Configuration
// flag.
func (b NDPRouterAdvert) ManagedAddrConfFlag() bool {
return b[ndpRAFlagsOffset]&ndpRAManagedAddrConfFlagMask != 0
}
// OtherConfFlag returns the value of the Other Configuration flag.
func (b NDPRouterAdvert) OtherConfFlag() bool {
return b[ndpRAFlagsOffset]&ndpRAOtherConfFlagMask != 0
}
// DefaultRouterPreference returns the Default Router Preference field.
func (b NDPRouterAdvert) DefaultRouterPreference() NDPRoutePreference {
return NDPRoutePreference((b[ndpRAFlagsOffset] & ndpDefaultRouterPreferenceMask) >> ndpDefaultRouterPreferenceShift)
}
// RouterLifetime returns the lifetime associated with the default router. A
// value of 0 means the source of the Router Advertisement is not a default
// router and SHOULD NOT appear on the default router list. Note, a value of 0
// only means that the router should not be used as a default router, it does
// not apply to other information contained in the Router Advertisement.
func (b NDPRouterAdvert) RouterLifetime() time.Duration {
// The field is the time in seconds, as per RFC 4861 section 4.2.
return time.Second * time.Duration(binary.BigEndian.Uint16(b[ndpRARouterLifetimeOffset:]))
}
// ReachableTime returns the time that a node assumes a neighbor is reachable
// after having received a reachability confirmation. A value of 0 means
// that it is unspecified by the source of the Router Advertisement message.
func (b NDPRouterAdvert) ReachableTime() time.Duration {
// The field is the time in milliseconds, as per RFC 4861 section 4.2.
return time.Millisecond * time.Duration(binary.BigEndian.Uint32(b[ndpRAReachableTimeOffset:]))
}
// RetransTimer returns the time between retransmitted Neighbor Solicitation
// messages. A value of 0 means that it is unspecified by the source of the
// Router Advertisement message.
func (b NDPRouterAdvert) RetransTimer() time.Duration {
// The field is the time in milliseconds, as per RFC 4861 section 4.2.
return time.Millisecond * time.Duration(binary.BigEndian.Uint32(b[ndpRARetransTimerOffset:]))
}
// Options returns an NDPOptions of the options body.
func (b NDPRouterAdvert) Options() NDPOptions {
return NDPOptions(b[ndpRAOptionsOffset:])
}

View File

@@ -1,36 +0,0 @@
// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
// NDPRouterSolicit is an NDP Router Solicitation message. It will only contain
// the body of an ICMPv6 packet.
//
// See RFC 4861 section 4.1 for more details.
type NDPRouterSolicit []byte
const (
// NDPRSMinimumSize is the minimum size of a valid NDP Router
// Solicitation message (body of an ICMPv6 packet).
NDPRSMinimumSize = 4
// ndpRSOptionsOffset is the start of the NDP options in an
// NDPRouterSolicit.
ndpRSOptionsOffset = 4
)
// Options returns an NDPOptions of the options body.
func (b NDPRouterSolicit) Options() NDPOptions {
return NDPOptions(b[ndpRSOptionsOffset:])
}

View File

@@ -1,58 +0,0 @@
// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Code generated by "stringer -type ndpOptionIdentifier"; DO NOT EDIT.
package header
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[ndpSourceLinkLayerAddressOptionType-1]
_ = x[ndpTargetLinkLayerAddressOptionType-2]
_ = x[ndpPrefixInformationType-3]
_ = x[ndpNonceOptionType-14]
_ = x[ndpRecursiveDNSServerOptionType-25]
_ = x[ndpDNSSearchListOptionType-31]
}
const (
_ndpOptionIdentifier_name_0 = "ndpSourceLinkLayerAddressOptionTypendpTargetLinkLayerAddressOptionTypendpPrefixInformationType"
_ndpOptionIdentifier_name_1 = "ndpNonceOptionType"
_ndpOptionIdentifier_name_2 = "ndpRecursiveDNSServerOptionType"
_ndpOptionIdentifier_name_3 = "ndpDNSSearchListOptionType"
)
var (
_ndpOptionIdentifier_index_0 = [...]uint8{0, 35, 70, 94}
)
func (i ndpOptionIdentifier) String() string {
switch {
case 1 <= i && i <= 3:
i -= 1
return _ndpOptionIdentifier_name_0[_ndpOptionIdentifier_index_0[i]:_ndpOptionIdentifier_index_0[i+1]]
case i == 14:
return _ndpOptionIdentifier_name_1
case i == 25:
return _ndpOptionIdentifier_name_2
case i == 31:
return _ndpOptionIdentifier_name_3
default:
return "ndpOptionIdentifier(" + strconv.FormatInt(int64(i), 10) + ")"
}
}

View File

@@ -1,35 +0,0 @@
package header
import "net/netip"
func (b IPv4) SourceAddr() netip.Addr {
return netip.AddrFrom4([4]byte(b[srcAddr : srcAddr+IPv4AddressSize]))
}
func (b IPv4) DestinationAddr() netip.Addr {
return netip.AddrFrom4([4]byte(b[dstAddr : dstAddr+IPv4AddressSize]))
}
func (b IPv4) SetSourceAddr(addr netip.Addr) {
copy(b[srcAddr:srcAddr+IPv4AddressSize], addr.AsSlice())
}
func (b IPv4) SetDestinationAddr(addr netip.Addr) {
copy(b[dstAddr:dstAddr+IPv4AddressSize], addr.AsSlice())
}
func (b IPv6) SourceAddr() netip.Addr {
return netip.AddrFrom16([16]byte(b[v6SrcAddr:][:IPv6AddressSize]))
}
func (b IPv6) DestinationAddr() netip.Addr {
return netip.AddrFrom16([16]byte(b[v6DstAddr:][:IPv6AddressSize]))
}
func (b IPv6) SetSourceAddr(addr netip.Addr) {
copy(b[v6SrcAddr:][:IPv6AddressSize], addr.AsSlice())
}
func (b IPv6) SetDestinationAddr(addr netip.Addr) {
copy(b[v6DstAddr:][:IPv6AddressSize], addr.AsSlice())
}

View File

@@ -1,727 +0,0 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"encoding/binary"
"github.com/sagernet/sing-tun/internal/gtcpip"
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
"github.com/sagernet/sing-tun/internal/gtcpip/seqnum"
"github.com/google/btree"
)
// These constants are the offsets of the respective fields in the TCP header.
const (
TCPSrcPortOffset = 0
TCPDstPortOffset = 2
TCPSeqNumOffset = 4
TCPAckNumOffset = 8
TCPDataOffset = 12
TCPFlagsOffset = 13
TCPWinSizeOffset = 14
TCPChecksumOffset = 16
TCPUrgentPtrOffset = 18
)
const (
// MaxWndScale is maximum allowed window scaling, as described in
// RFC 1323, section 2.3, page 11.
MaxWndScale = 14
// TCPMaxSACKBlocks is the maximum number of SACK blocks that can
// be encoded in a TCP option field.
TCPMaxSACKBlocks = 4
)
// TCPFlags is the dedicated type for TCP flags.
type TCPFlags uint8
// Intersects returns true iff there are flags common to both f and o.
func (f TCPFlags) Intersects(o TCPFlags) bool {
return f&o != 0
}
// Contains returns true iff all the flags in o are contained within f.
func (f TCPFlags) Contains(o TCPFlags) bool {
return f&o == o
}
// String implements Stringer.String.
func (f TCPFlags) String() string {
flagsStr := []byte("FSRPAUEC")
for i := range flagsStr {
if f&(1<<uint(i)) == 0 {
flagsStr[i] = ' '
}
}
return string(flagsStr)
}
// Flags that may be set in a TCP segment.
const (
TCPFlagFin TCPFlags = 1 << iota
TCPFlagSyn
TCPFlagRst
TCPFlagPsh
TCPFlagAck
TCPFlagUrg
TCPFlagEce
TCPFlagCwr
)
// Options that may be present in a TCP segment.
const (
TCPOptionEOL = 0
TCPOptionNOP = 1
TCPOptionMSS = 2
TCPOptionWS = 3
TCPOptionTS = 8
TCPOptionSACKPermitted = 4
TCPOptionSACK = 5
)
// Option Lengths.
const (
TCPOptionMSSLength = 4
TCPOptionTSLength = 10
TCPOptionWSLength = 3
TCPOptionSackPermittedLength = 2
)
// TCPFields contains the fields of a TCP packet. It is used to describe the
// fields of a packet that needs to be encoded.
type TCPFields struct {
// SrcPort is the "source port" field of a TCP packet.
SrcPort uint16
// DstPort is the "destination port" field of a TCP packet.
DstPort uint16
// SeqNum is the "sequence number" field of a TCP packet.
SeqNum uint32
// AckNum is the "acknowledgement number" field of a TCP packet.
AckNum uint32
// DataOffset is the "data offset" field of a TCP packet. It is the length of
// the TCP header in bytes.
DataOffset uint8
// Flags is the "flags" field of a TCP packet.
Flags TCPFlags
// WindowSize is the "window size" field of a TCP packet.
WindowSize uint16
// Checksum is the "checksum" field of a TCP packet.
Checksum uint16
// UrgentPointer is the "urgent pointer" field of a TCP packet.
UrgentPointer uint16
}
// TCPSynOptions is used to return the parsed TCP Options in a syn
// segment.
//
// +stateify savable
type TCPSynOptions struct {
// MSS is the maximum segment size provided by the peer in the SYN.
MSS uint16
// WS is the window scale option provided by the peer in the SYN.
//
// Set to -1 if no window scale option was provided.
WS int
// TS is true if the timestamp option was provided in the syn/syn-ack.
TS bool
// TSVal is the value of the TSVal field in the timestamp option.
TSVal uint32
// TSEcr is the value of the TSEcr field in the timestamp option.
TSEcr uint32
// SACKPermitted is true if the SACK option was provided in the SYN/SYN-ACK.
SACKPermitted bool
// Flags if specified are set on the outgoing SYN. The SYN flag is
// always set.
Flags TCPFlags
}
// SACKBlock represents a single contiguous SACK block.
//
// +stateify savable
type SACKBlock struct {
// Start indicates the lowest sequence number in the block.
Start seqnum.Value
// End indicates the sequence number immediately following the last
// sequence number of this block.
End seqnum.Value
}
// Less returns true if r.Start < b.Start.
func (r SACKBlock) Less(b btree.Item) bool {
return r.Start.LessThan(b.(SACKBlock).Start)
}
// Contains returns true if b is completely contained in r.
func (r SACKBlock) Contains(b SACKBlock) bool {
return r.Start.LessThanEq(b.Start) && b.End.LessThanEq(r.End)
}
// TCPOptions are used to parse and cache the TCP segment options for a non
// syn/syn-ack segment.
//
// +stateify savable
type TCPOptions struct {
// TS is true if the TimeStamp option is enabled.
TS bool
// TSVal is the value in the TSVal field of the segment.
TSVal uint32
// TSEcr is the value in the TSEcr field of the segment.
TSEcr uint32
// SACKBlocks are the SACK blocks specified in the segment.
SACKBlocks []SACKBlock
}
// TCP represents a TCP header stored in a byte array.
type TCP []byte
const (
// TCPMinimumSize is the minimum size of a valid TCP packet.
TCPMinimumSize = 20
// TCPOptionsMaximumSize is the maximum size of TCP options.
TCPOptionsMaximumSize = 40
// TCPHeaderMaximumSize is the maximum header size of a TCP packet.
TCPHeaderMaximumSize = TCPMinimumSize + TCPOptionsMaximumSize
// TCPTotalHeaderMaximumSize is the maximum size of headers from all layers in
// a TCP packet. It analogous to MAX_TCP_HEADER in Linux.
//
// TODO(b/319936470): Investigate why this needs to be at least 140 bytes. In
// Linux this value is at least 160, but in theory we should be able to use
// 138. In practice anything less than 140 starts to break GSO on gVNIC
// hardware.
TCPTotalHeaderMaximumSize = 160
// TCPProtocolNumber is TCP's transport protocol number.
TCPProtocolNumber tcpip.TransportProtocolNumber = 6
// TCPMinimumMSS is the minimum acceptable value for MSS. This is the
// same as the value TCP_MIN_MSS defined net/tcp.h.
TCPMinimumMSS = IPv4MaximumHeaderSize + TCPHeaderMaximumSize + MinIPFragmentPayloadSize - IPv4MinimumSize - TCPMinimumSize
// TCPMinimumSendMSS is the minimum value for MSS in a sender. This is the
// same as the value TCP_MIN_SND_MSS in net/tcp.h.
TCPMinimumSendMSS = TCPOptionsMaximumSize + MinIPFragmentPayloadSize
// TCPMaximumMSS is the maximum acceptable value for MSS.
TCPMaximumMSS = 0xffff
// TCPDefaultMSS is the MSS value that should be used if an MSS option
// is not received from the peer. It's also the value returned by
// TCP_MAXSEG option for a socket in an unconnected state.
//
// Per RFC 1122, page 85: "If an MSS option is not received at
// connection setup, TCP MUST assume a default send MSS of 536."
TCPDefaultMSS = 536
)
// SourcePort returns the "source port" field of the TCP header.
func (b TCP) SourcePort() uint16 {
return binary.BigEndian.Uint16(b[TCPSrcPortOffset:])
}
// DestinationPort returns the "destination port" field of the TCP header.
func (b TCP) DestinationPort() uint16 {
return binary.BigEndian.Uint16(b[TCPDstPortOffset:])
}
// SequenceNumber returns the "sequence number" field of the TCP header.
func (b TCP) SequenceNumber() uint32 {
return binary.BigEndian.Uint32(b[TCPSeqNumOffset:])
}
// AckNumber returns the "ack number" field of the TCP header.
func (b TCP) AckNumber() uint32 {
return binary.BigEndian.Uint32(b[TCPAckNumOffset:])
}
// DataOffset returns the "data offset" field of the TCP header. The return
// value is the length of the TCP header in bytes.
func (b TCP) DataOffset() uint8 {
return (b[TCPDataOffset] >> 4) * 4
}
// Payload returns the data in the TCP packet.
func (b TCP) Payload() []byte {
return b[b.DataOffset():]
}
// Flags returns the flags field of the TCP header.
func (b TCP) Flags() TCPFlags {
return TCPFlags(b[TCPFlagsOffset])
}
// WindowSize returns the "window size" field of the TCP header.
func (b TCP) WindowSize() uint16 {
return binary.BigEndian.Uint16(b[TCPWinSizeOffset:])
}
// Checksum returns the "checksum" field of the TCP header.
func (b TCP) Checksum() uint16 {
return binary.BigEndian.Uint16(b[TCPChecksumOffset:])
}
// UrgentPointer returns the "urgent pointer" field of the TCP header.
func (b TCP) UrgentPointer() uint16 {
return binary.BigEndian.Uint16(b[TCPUrgentPtrOffset:])
}
// SetSourcePort sets the "source port" field of the TCP header.
func (b TCP) SetSourcePort(port uint16) {
binary.BigEndian.PutUint16(b[TCPSrcPortOffset:], port)
}
// SetDestinationPort sets the "destination port" field of the TCP header.
func (b TCP) SetDestinationPort(port uint16) {
binary.BigEndian.PutUint16(b[TCPDstPortOffset:], port)
}
// SetChecksum sets the checksum field of the TCP header.
func (b TCP) SetChecksum(xsum uint16) {
checksum.Put(b[TCPChecksumOffset:], xsum)
}
// SetDataOffset sets the data offset field of the TCP header. headerLen should
// be the length of the TCP header in bytes.
func (b TCP) SetDataOffset(headerLen uint8) {
b[TCPDataOffset] = (headerLen / 4) << 4
}
// SetSequenceNumber sets the sequence number field of the TCP header.
func (b TCP) SetSequenceNumber(seqNum uint32) {
binary.BigEndian.PutUint32(b[TCPSeqNumOffset:], seqNum)
}
// SetAckNumber sets the ack number field of the TCP header.
func (b TCP) SetAckNumber(ackNum uint32) {
binary.BigEndian.PutUint32(b[TCPAckNumOffset:], ackNum)
}
// SetFlags sets the flags field of the TCP header.
func (b TCP) SetFlags(flags uint8) {
b[TCPFlagsOffset] = flags
}
// SetWindowSize sets the window size field of the TCP header.
func (b TCP) SetWindowSize(rcvwnd uint16) {
binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd)
}
// SetUrgentPointer sets the window size field of the TCP header.
func (b TCP) SetUrgentPointer(urgentPointer uint16) {
binary.BigEndian.PutUint16(b[TCPUrgentPtrOffset:], urgentPointer)
}
// CalculateChecksum calculates the checksum of the TCP segment.
// partialChecksum is the checksum of the network-layer pseudo-header
// and the checksum of the segment data.
func (b TCP) CalculateChecksum(partialChecksum uint16) uint16 {
// Calculate the rest of the checksum.
return checksum.Checksum(b[:b.DataOffset()], partialChecksum)
}
// IsChecksumValid returns true iff the TCP header's checksum is valid.
func (b TCP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum, payloadLength uint16) bool {
xsum := PseudoHeaderChecksum(TCPProtocolNumber, src.AsSlice(), dst.AsSlice(), uint16(b.DataOffset())+payloadLength)
xsum = checksum.Combine(xsum, payloadChecksum)
return b.CalculateChecksum(xsum) == 0xffff
}
// Options returns a slice that holds the unparsed TCP options in the segment.
func (b TCP) Options() []byte {
return b[TCPMinimumSize:b.DataOffset()]
}
// ParsedOptions returns a TCPOptions structure which parses and caches the TCP
// option values in the TCP segment. NOTE: Invoking this function repeatedly is
// expensive as it reparses the options on each invocation.
func (b TCP) ParsedOptions() TCPOptions {
return ParseTCPOptions(b.Options())
}
func (b TCP) encodeSubset(seq, ack uint32, flags TCPFlags, rcvwnd uint16) {
binary.BigEndian.PutUint32(b[TCPSeqNumOffset:], seq)
binary.BigEndian.PutUint32(b[TCPAckNumOffset:], ack)
b[TCPFlagsOffset] = uint8(flags)
binary.BigEndian.PutUint16(b[TCPWinSizeOffset:], rcvwnd)
}
// Encode encodes all the fields of the TCP header.
func (b TCP) Encode(t *TCPFields) {
b.encodeSubset(t.SeqNum, t.AckNum, t.Flags, t.WindowSize)
b.SetSourcePort(t.SrcPort)
b.SetDestinationPort(t.DstPort)
b.SetDataOffset(t.DataOffset)
b.SetChecksum(t.Checksum)
b.SetUrgentPointer(t.UrgentPointer)
}
// EncodePartial updates a subset of the fields of the TCP header. It is useful
// in cases when similar segments are produced.
func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32, flags TCPFlags, rcvwnd uint16) {
// Add the total length and "flags" field contributions to the checksum.
// We don't use the flags field directly from the header because it's a
// one-byte field with an odd offset, so it would be accounted for
// incorrectly by the Checksum routine.
tmp := make([]byte, 4)
binary.BigEndian.PutUint16(tmp, length)
binary.BigEndian.PutUint16(tmp[2:], uint16(flags))
xsum := checksum.Checksum(tmp, partialChecksum)
// Encode the passed-in fields.
b.encodeSubset(seqnum, acknum, flags, rcvwnd)
// Add the contributions of the passed-in fields to the checksum.
xsum = checksum.Checksum(b[TCPSeqNumOffset:TCPSeqNumOffset+8], xsum)
xsum = checksum.Checksum(b[TCPWinSizeOffset:TCPWinSizeOffset+2], xsum)
// Encode the checksum.
b.SetChecksum(^xsum)
}
// SetSourcePortWithChecksumUpdate implements ChecksummableTransport.
func (b TCP) SetSourcePortWithChecksumUpdate(new uint16) {
old := b.SourcePort()
b.SetSourcePort(new)
b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
}
// SetDestinationPortWithChecksumUpdate implements ChecksummableTransport.
func (b TCP) SetDestinationPortWithChecksumUpdate(new uint16) {
old := b.DestinationPort()
b.SetDestinationPort(new)
b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
}
// UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport.
func (b TCP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) {
xsum := b.Checksum()
if fullChecksum {
xsum = ^xsum
}
xsum = checksumUpdate2ByteAlignedAddress(xsum, old, new)
if fullChecksum {
xsum = ^xsum
}
b.SetChecksum(xsum)
}
// ParseSynOptions parses the options received in a SYN segment and returns the
// relevant ones. opts should point to the option part of the TCP header.
func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions {
limit := len(opts)
synOpts := TCPSynOptions{
// Per RFC 1122, page 85: "If an MSS option is not received at
// connection setup, TCP MUST assume a default send MSS of 536."
MSS: TCPDefaultMSS,
// If no window scale option is specified, WS in options is
// returned as -1; this is because the absence of the option
// indicates that the we cannot use window scaling on the
// receive end either.
WS: -1,
}
for i := 0; i < limit; {
switch opts[i] {
case TCPOptionEOL:
i = limit
case TCPOptionNOP:
i++
case TCPOptionMSS:
if i+4 > limit || opts[i+1] != 4 {
return synOpts
}
mss := uint16(opts[i+2])<<8 | uint16(opts[i+3])
if mss == 0 {
return synOpts
}
synOpts.MSS = mss
if mss < TCPMinimumSendMSS {
synOpts.MSS = TCPMinimumSendMSS
}
i += 4
case TCPOptionWS:
if i+3 > limit || opts[i+1] != 3 {
return synOpts
}
ws := int(opts[i+2])
if ws > MaxWndScale {
ws = MaxWndScale
}
synOpts.WS = ws
i += 3
case TCPOptionTS:
if i+10 > limit || opts[i+1] != 10 {
return synOpts
}
synOpts.TSVal = binary.BigEndian.Uint32(opts[i+2:])
if isAck {
// If the segment is a SYN-ACK then store the Timestamp Echo Reply
// in the segment.
synOpts.TSEcr = binary.BigEndian.Uint32(opts[i+6:])
}
synOpts.TS = true
i += 10
case TCPOptionSACKPermitted:
if i+2 > limit || opts[i+1] != 2 {
return synOpts
}
synOpts.SACKPermitted = true
i += 2
default:
// We don't recognize this option, just skip over it.
if i+2 > limit {
return synOpts
}
l := int(opts[i+1])
// If the length is incorrect or if l+i overflows the
// total options length then return false.
if l < 2 || i+l > limit {
return synOpts
}
i += l
}
}
return synOpts
}
// ParseTCPOptions extracts and stores all known options in the provided byte
// slice in a TCPOptions structure.
func ParseTCPOptions(b []byte) TCPOptions {
opts := TCPOptions{}
limit := len(b)
for i := 0; i < limit; {
switch b[i] {
case TCPOptionEOL:
i = limit
case TCPOptionNOP:
i++
case TCPOptionTS:
if i+10 > limit || (b[i+1] != 10) {
return opts
}
opts.TS = true
opts.TSVal = binary.BigEndian.Uint32(b[i+2:])
opts.TSEcr = binary.BigEndian.Uint32(b[i+6:])
i += 10
case TCPOptionSACK:
if i+2 > limit {
// Malformed SACK block, just return and stop parsing.
return opts
}
sackOptionLen := int(b[i+1])
if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 {
// Malformed SACK block, just return and stop parsing.
return opts
}
numBlocks := (sackOptionLen - 2) / 8
opts.SACKBlocks = []SACKBlock{}
for j := 0; j < numBlocks; j++ {
start := binary.BigEndian.Uint32(b[i+2+j*8:])
end := binary.BigEndian.Uint32(b[i+2+j*8+4:])
opts.SACKBlocks = append(opts.SACKBlocks, SACKBlock{
Start: seqnum.Value(start),
End: seqnum.Value(end),
})
}
i += sackOptionLen
default:
// We don't recognize this option, just skip over it.
if i+2 > limit {
return opts
}
l := int(b[i+1])
// If the length is incorrect or if l+i overflows the
// total options length then return false.
if l < 2 || i+l > limit {
return opts
}
i += l
}
}
return opts
}
// EncodeMSSOption encodes the MSS TCP option with the provided MSS values in
// the supplied buffer. If the provided buffer is not large enough then it just
// returns without encoding anything. It returns the number of bytes written to
// the provided buffer.
func EncodeMSSOption(mss uint32, b []byte) int {
if len(b) < TCPOptionMSSLength {
return 0
}
b[0], b[1], b[2], b[3] = TCPOptionMSS, TCPOptionMSSLength, byte(mss>>8), byte(mss)
return TCPOptionMSSLength
}
// EncodeWSOption encodes the WS TCP option with the WS value in the
// provided buffer. If the provided buffer is not large enough then it just
// returns without encoding anything. It returns the number of bytes written to
// the provided buffer.
func EncodeWSOption(ws int, b []byte) int {
if len(b) < TCPOptionWSLength {
return 0
}
b[0], b[1], b[2] = TCPOptionWS, TCPOptionWSLength, uint8(ws)
return int(b[1])
}
// EncodeTSOption encodes the provided tsVal and tsEcr values as a TCP timestamp
// option into the provided buffer. If the buffer is smaller than expected it
// just returns without encoding anything. It returns the number of bytes
// written to the provided buffer.
func EncodeTSOption(tsVal, tsEcr uint32, b []byte) int {
if len(b) < TCPOptionTSLength {
return 0
}
b[0], b[1] = TCPOptionTS, TCPOptionTSLength
binary.BigEndian.PutUint32(b[2:], tsVal)
binary.BigEndian.PutUint32(b[6:], tsEcr)
return int(b[1])
}
// EncodeSACKPermittedOption encodes a SACKPermitted option into the provided
// buffer. If the buffer is smaller than required it just returns without
// encoding anything. It returns the number of bytes written to the provided
// buffer.
func EncodeSACKPermittedOption(b []byte) int {
if len(b) < TCPOptionSackPermittedLength {
return 0
}
b[0], b[1] = TCPOptionSACKPermitted, TCPOptionSackPermittedLength
return int(b[1])
}
// EncodeSACKBlocks encodes the provided SACK blocks as a TCP SACK option block
// in the provided slice. It tries to fit in as many blocks as possible based on
// number of bytes available in the provided buffer. It returns the number of
// bytes written to the provided buffer.
func EncodeSACKBlocks(sackBlocks []SACKBlock, b []byte) int {
if len(sackBlocks) == 0 {
return 0
}
l := len(sackBlocks)
if l > TCPMaxSACKBlocks {
l = TCPMaxSACKBlocks
}
if ll := (len(b) - 2) / 8; ll < l {
l = ll
}
if l == 0 {
// There is not enough space in the provided buffer to add
// any SACK blocks.
return 0
}
b[0] = TCPOptionSACK
b[1] = byte(l*8 + 2)
for i := 0; i < l; i++ {
binary.BigEndian.PutUint32(b[i*8+2:], uint32(sackBlocks[i].Start))
binary.BigEndian.PutUint32(b[i*8+6:], uint32(sackBlocks[i].End))
}
return int(b[1])
}
// EncodeNOP adds an explicit NOP to the option list.
func EncodeNOP(b []byte) int {
if len(b) == 0 {
return 0
}
b[0] = TCPOptionNOP
return 1
}
// AddTCPOptionPadding adds the required number of TCPOptionNOP to quad align
// the option buffer. It adds padding bytes after the offset specified and
// returns the number of padding bytes added. The passed in options slice
// must have space for the padding bytes.
func AddTCPOptionPadding(options []byte, offset int) int {
paddingToAdd := -offset & 3
// Now add any padding bytes that might be required to quad align the
// options.
for i := offset; i < offset+paddingToAdd; i++ {
options[i] = TCPOptionNOP
}
return paddingToAdd
}
// Acceptable checks if a segment that starts at segSeq and has length segLen is
// "acceptable" for arriving in a receive window that starts at rcvNxt and ends
// before rcvAcc, according to the table on page 26 and 69 of RFC 793.
func Acceptable(segSeq seqnum.Value, segLen seqnum.Size, rcvNxt, rcvAcc seqnum.Value) bool {
if rcvNxt == rcvAcc {
return segLen == 0 && segSeq == rcvNxt
}
if segLen == 0 {
// rcvWnd is incremented by 1 because that is Linux's behavior despite the
// RFC.
return segSeq.InRange(rcvNxt, rcvAcc.Add(1))
}
// Page 70 of RFC 793 allows packets that can be made "acceptable" by trimming
// the payload, so we'll accept any payload that overlaps the receive window.
// segSeq < rcvAcc is more correct according to RFC, however, Linux does it
// differently, it uses segSeq <= rcvAcc, we'd want to keep the same behavior
// as Linux.
return rcvNxt.LessThan(segSeq.Add(segLen)) && segSeq.LessThanEq(rcvAcc)
}
// TCPValid returns true if the pkt has a valid TCP header. It checks whether:
// - The data offset is too small.
// - The data offset is too large.
// - The checksum is invalid.
//
// TCPValid corresponds to net/netfilter/nf_conntrack_proto_tcp.c:tcp_error.
func TCPValid(hdr TCP, payloadChecksum func() uint16, payloadSize uint16, srcAddr, dstAddr tcpip.Address, skipChecksumValidation bool) (csum uint16, csumValid, ok bool) {
if offset := int(hdr.DataOffset()); offset < TCPMinimumSize || offset > len(hdr) {
return
}
if skipChecksumValidation {
csumValid = true
} else {
csum = hdr.Checksum()
csumValid = hdr.IsChecksumValid(srcAddr, dstAddr, payloadChecksum(), payloadSize)
}
return csum, csumValid, true
}

View File

@@ -1,195 +0,0 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"encoding/binary"
"math"
"github.com/sagernet/sing-tun/internal/gtcpip"
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
)
const (
udpSrcPort = 0
udpDstPort = 2
udpLength = 4
udpChecksum = 6
)
const (
// UDPMaximumPacketSize is the largest possible UDP packet.
UDPMaximumPacketSize = 0xffff
)
// UDPFields contains the fields of a UDP packet. It is used to describe the
// fields of a packet that needs to be encoded.
type UDPFields struct {
// SrcPort is the "source port" field of a UDP packet.
SrcPort uint16
// DstPort is the "destination port" field of a UDP packet.
DstPort uint16
// Length is the "length" field of a UDP packet.
Length uint16
// Checksum is the "checksum" field of a UDP packet.
Checksum uint16
}
// UDP represents a UDP header stored in a byte array.
type UDP []byte
const (
// UDPMinimumSize is the minimum size of a valid UDP packet.
UDPMinimumSize = 8
// UDPMaximumSize is the maximum size of a valid UDP packet. The length field
// in the UDP header is 16 bits as per RFC 768.
UDPMaximumSize = math.MaxUint16
// UDPProtocolNumber is UDP's transport protocol number.
UDPProtocolNumber tcpip.TransportProtocolNumber = 17
)
// SourcePort returns the "source port" field of the UDP header.
func (b UDP) SourcePort() uint16 {
return binary.BigEndian.Uint16(b[udpSrcPort:])
}
// DestinationPort returns the "destination port" field of the UDP header.
func (b UDP) DestinationPort() uint16 {
return binary.BigEndian.Uint16(b[udpDstPort:])
}
// Length returns the "length" field of the UDP header.
func (b UDP) Length() uint16 {
return binary.BigEndian.Uint16(b[udpLength:])
}
// Payload returns the data contained in the UDP datagram.
func (b UDP) Payload() []byte {
return b[UDPMinimumSize:]
}
// Checksum returns the "checksum" field of the UDP header.
func (b UDP) Checksum() uint16 {
return binary.BigEndian.Uint16(b[udpChecksum:])
}
// SetSourcePort sets the "source port" field of the UDP header.
func (b UDP) SetSourcePort(port uint16) {
binary.BigEndian.PutUint16(b[udpSrcPort:], port)
}
// SetDestinationPort sets the "destination port" field of the UDP header.
func (b UDP) SetDestinationPort(port uint16) {
binary.BigEndian.PutUint16(b[udpDstPort:], port)
}
// SetChecksum sets the "checksum" field of the UDP header.
func (b UDP) SetChecksum(xsum uint16) {
checksum.Put(b[udpChecksum:], xsum)
}
// SetLength sets the "length" field of the UDP header.
func (b UDP) SetLength(length uint16) {
binary.BigEndian.PutUint16(b[udpLength:], length)
}
// CalculateChecksum calculates the checksum of the UDP packet, given the
// checksum of the network-layer pseudo-header and the checksum of the payload.
func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 {
// Calculate the rest of the checksum.
return checksum.Checksum(b[:UDPMinimumSize], partialChecksum)
}
// IsChecksumValid returns true iff the UDP header's checksum is valid.
func (b UDP) IsChecksumValid(src, dst tcpip.Address, payloadChecksum uint16) bool {
xsum := PseudoHeaderChecksum(UDPProtocolNumber, dst.AsSlice(), src.AsSlice(), b.Length())
xsum = checksum.Combine(xsum, payloadChecksum)
return b.CalculateChecksum(xsum) == 0xffff
}
// Encode encodes all the fields of the UDP header.
func (b UDP) Encode(u *UDPFields) {
b.SetSourcePort(u.SrcPort)
b.SetDestinationPort(u.DstPort)
b.SetLength(u.Length)
b.SetChecksum(u.Checksum)
}
// SetSourcePortWithChecksumUpdate implements ChecksummableTransport.
func (b UDP) SetSourcePortWithChecksumUpdate(new uint16) {
old := b.SourcePort()
b.SetSourcePort(new)
b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
}
// SetDestinationPortWithChecksumUpdate implements ChecksummableTransport.
func (b UDP) SetDestinationPortWithChecksumUpdate(new uint16) {
old := b.DestinationPort()
b.SetDestinationPort(new)
b.SetChecksum(^checksumUpdate2ByteAlignedUint16(^b.Checksum(), old, new))
}
// UpdateChecksumPseudoHeaderAddress implements ChecksummableTransport.
func (b UDP) UpdateChecksumPseudoHeaderAddress(old, new tcpip.Address, fullChecksum bool) {
xsum := b.Checksum()
if fullChecksum {
xsum = ^xsum
}
xsum = checksumUpdate2ByteAlignedAddress(xsum, old, new)
if fullChecksum {
xsum = ^xsum
}
b.SetChecksum(xsum)
}
// UDPValid returns true if the pkt has a valid UDP header. It checks whether:
// - The length field is too small.
// - The length field is too large.
// - The checksum is invalid.
//
// UDPValid corresponds to net/netfilter/nf_conntrack_proto_udp.c:udp_error.
func UDPValid(hdr UDP, payloadChecksum func() uint16, payloadSize uint16, netProto tcpip.NetworkProtocolNumber, srcAddr, dstAddr tcpip.Address, skipChecksumValidation bool) (lengthValid, csumValid bool) {
if length := hdr.Length(); length > payloadSize+UDPMinimumSize || length < UDPMinimumSize {
return false, false
}
if skipChecksumValidation {
return true, true
}
// On IPv4, UDP checksum is optional, and a zero value means the transmitter
// omitted the checksum generation, as per RFC 768:
//
// An all zero transmitted checksum value means that the transmitter
// generated no checksum (for debugging or for higher level protocols that
// don't care).
//
// On IPv6, UDP checksum is not optional, as per RFC 2460 Section 8.1:
//
// Unlike IPv4, when UDP packets are originated by an IPv6 node, the UDP
// checksum is not optional.
if netProto == IPv4ProtocolNumber && hdr.Checksum() == 0 {
return true, true
}
return true, hdr.IsChecksumValid(srcAddr, dstAddr, payloadChecksum())
}

View File

@@ -1,62 +0,0 @@
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package seqnum defines the types and methods for TCP sequence numbers such
// that they fit in 32-bit words and work properly when overflows occur.
package seqnum
// Value represents the value of a sequence number.
type Value uint32
// Size represents the size (length) of a sequence number window.
type Size uint32
// LessThan checks if v is before w, i.e., v < w.
func (v Value) LessThan(w Value) bool {
return int32(v-w) < 0
}
// LessThanEq returns true if v==w or v is before i.e., v < w.
func (v Value) LessThanEq(w Value) bool {
if v == w {
return true
}
return v.LessThan(w)
}
// InRange checks if v is in the range [a,b), i.e., a <= v < b.
func (v Value) InRange(a, b Value) bool {
return v-a < b-a
}
// InWindow checks if v is in the window that starts at 'first' and spans 'size'
// sequence numbers.
func (v Value) InWindow(first Value, size Size) bool {
return v.InRange(first, first.Add(size))
}
// Add calculates the sequence number following the [v, v+s) window.
func (v Value) Add(s Size) Value {
return v + Value(s)
}
// Size calculates the size of the window defined by [v, w).
func (v Value) Size(w Value) Size {
return Size(w - v)
}
// UpdateForward updates v such that it becomes v + s.
func (v *Value) UpdateForward(s Size) {
*v += Value(s)
}

View File

@@ -1,573 +0,0 @@
// Copyright 2024 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tcpip
import (
"crypto/rand"
"errors"
"fmt"
"math"
"math/bits"
"net"
"strconv"
"strings"
"time"
)
// Using the header package here would cause an import cycle.
const (
ipv4AddressSize = 4
ipv4ProtocolNumber = 0x0800
ipv6AddressSize = 16
ipv6ProtocolNumber = 0x86dd
)
const (
// LinkAddressSize is the size of a MAC address.
LinkAddressSize = 6
)
// Known IP address.
var (
IPv4Zero = []byte{0, 0, 0, 0}
IPv6Zero = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
)
// Errors related to Subnet
var (
errSubnetLengthMismatch = errors.New("subnet length of address and mask differ")
errSubnetAddressMasked = errors.New("subnet address has bits set outside the mask")
)
// TransportProtocolNumber is the number of a transport protocol.
type TransportProtocolNumber uint32
// NetworkProtocolNumber is the EtherType of a network protocol in an Ethernet
// frame.
//
// See: https://www.iana.org/assignments/ieee-802-numbers/ieee-802-numbers.xhtml
type NetworkProtocolNumber uint32
// MonotonicTime is a monotonic clock reading.
//
// +stateify savable
type MonotonicTime struct {
nanoseconds int64
}
// String implements Stringer.
func (mt MonotonicTime) String() string {
return strconv.FormatInt(mt.nanoseconds, 10)
}
// MonotonicTimeInfinite returns the monotonic timestamp as far away in the
// future as possible.
func MonotonicTimeInfinite() MonotonicTime {
return MonotonicTime{nanoseconds: math.MaxInt64}
}
// Before reports whether the monotonic clock reading mt is before u.
func (mt MonotonicTime) Before(u MonotonicTime) bool {
return mt.nanoseconds < u.nanoseconds
}
// After reports whether the monotonic clock reading mt is after u.
func (mt MonotonicTime) After(u MonotonicTime) bool {
return mt.nanoseconds > u.nanoseconds
}
// Add returns the monotonic clock reading mt+d.
func (mt MonotonicTime) Add(d time.Duration) MonotonicTime {
return MonotonicTime{
nanoseconds: time.Unix(0, mt.nanoseconds).Add(d).Sub(time.Unix(0, 0)).Nanoseconds(),
}
}
// Sub returns the duration mt-u. If the result exceeds the maximum (or minimum)
// value that can be stored in a Duration, the maximum (or minimum) duration
// will be returned. To compute t-d for a duration d, use t.Add(-d).
func (mt MonotonicTime) Sub(u MonotonicTime) time.Duration {
return time.Unix(0, mt.nanoseconds).Sub(time.Unix(0, u.nanoseconds))
}
// Milliseconds returns the time in milliseconds.
func (mt MonotonicTime) Milliseconds() int64 {
return mt.nanoseconds / 1e6
}
// A Clock provides the current time and schedules work for execution.
//
// Times returned by a Clock should always be used for application-visible
// time. Only monotonic times should be used for netstack internal timekeeping.
type Clock interface {
// Now returns the current local time.
Now() time.Time
// NowMonotonic returns the current monotonic clock reading.
NowMonotonic() MonotonicTime
// AfterFunc waits for the duration to elapse and then calls f in its own
// goroutine. It returns a Timer that can be used to cancel the call using
// its Stop method.
AfterFunc(d time.Duration, f func()) Timer
}
// Timer represents a single event. A Timer must be created with
// Clock.AfterFunc.
type Timer interface {
// Stop prevents the Timer from firing. It returns true if the call stops the
// timer, false if the timer has already expired or been stopped.
//
// If Stop returns false, then the timer has already expired and the function
// f of Clock.AfterFunc(d, f) has been started in its own goroutine; Stop
// does not wait for f to complete before returning. If the caller needs to
// know whether f is completed, it must coordinate with f explicitly.
Stop() bool
// Reset changes the timer to expire after duration d.
//
// Reset should be invoked only on stopped or expired timers. If the timer is
// known to have expired, Reset can be used directly. Otherwise, the caller
// must coordinate with the function f of Clock.AfterFunc(d, f).
Reset(d time.Duration)
}
// Address is a byte slice cast as a string that represents the address of a
// network node. Or, in the case of unix endpoints, it may represent a path.
//
// +stateify savable
type Address struct {
addr [16]byte
length int
}
// AddrFrom4 converts addr to an Address.
func AddrFrom4(addr [4]byte) Address {
ret := Address{
length: 4,
}
// It's guaranteed that copy will return 4.
copy(ret.addr[:], addr[:])
return ret
}
// AddrFrom4Slice converts addr to an Address. It panics if len(addr) != 4.
func AddrFrom4Slice(addr []byte) Address {
if len(addr) != 4 {
panic(fmt.Sprintf("bad address length for address %v", addr))
}
ret := Address{
length: 4,
}
// It's guaranteed that copy will return 4.
copy(ret.addr[:], addr)
return ret
}
// AddrFrom16 converts addr to an Address.
func AddrFrom16(addr [16]byte) Address {
ret := Address{
length: 16,
}
// It's guaranteed that copy will return 16.
copy(ret.addr[:], addr[:])
return ret
}
// AddrFrom16Slice converts addr to an Address. It panics if len(addr) != 16.
func AddrFrom16Slice(addr []byte) Address {
if len(addr) != 16 {
panic(fmt.Sprintf("bad address length for address %v", addr))
}
ret := Address{
length: 16,
}
// It's guaranteed that copy will return 16.
copy(ret.addr[:], addr)
return ret
}
// AddrFromSlice converts addr to an Address. It returns the Address zero value
// if len(addr) != 4 or 16.
func AddrFromSlice(addr []byte) Address {
switch len(addr) {
case ipv4AddressSize:
return AddrFrom4Slice(addr)
case ipv6AddressSize:
return AddrFrom16Slice(addr)
}
return Address{}
}
// As4 returns a as a 4 byte array. It panics if the address length is not 4.
func (a Address) As4() [4]byte {
if a.Len() != 4 {
panic(fmt.Sprintf("bad address length for address %v", a.addr))
}
return [4]byte(a.addr[:4])
}
// As16 returns a as a 16 byte array. It panics if the address length is not 16.
func (a Address) As16() [16]byte {
if a.Len() != 16 {
panic(fmt.Sprintf("bad address length for address %v", a.addr))
}
return [16]byte(a.addr[:16])
}
// AsSlice returns a as a byte slice. Callers should be careful as it can
// return a window into existing memory.
//
// +checkescape
func (a *Address) AsSlice() []byte {
return a.addr[:a.length]
}
// BitLen returns the length in bits of a.
func (a Address) BitLen() int {
return a.Len() * 8
}
// Len returns the length in bytes of a.
func (a Address) Len() int {
return a.length
}
// WithPrefix returns the address with a prefix that represents a point subnet.
func (a Address) WithPrefix() AddressWithPrefix {
return AddressWithPrefix{
Address: a,
PrefixLen: a.BitLen(),
}
}
// Unspecified returns true if the address is unspecified.
func (a Address) Unspecified() bool {
for _, b := range a.addr {
if b != 0 {
return false
}
}
return true
}
// Equal returns whether a and other are equal. It exists for use by the cmp
// library.
func (a Address) Equal(other Address) bool {
return a == other
}
// MatchingPrefix returns the matching prefix length in bits.
//
// Panics if b and a have different lengths.
func (a Address) MatchingPrefix(b Address) uint8 {
const bitsInAByte = 8
if a.Len() != b.Len() {
panic(fmt.Sprintf("addresses %s and %s do not have the same length", a, b))
}
var prefix uint8
for i := 0; i < a.length; i++ {
aByte := a.addr[i]
bByte := b.addr[i]
if aByte == bByte {
prefix += bitsInAByte
continue
}
// Count the remaining matching bits in the byte from MSbit to LSBbit.
mask := uint8(1) << (bitsInAByte - 1)
for {
if aByte&mask == bByte&mask {
prefix++
mask >>= 1
continue
}
break
}
break
}
return prefix
}
// AddressMask is a bitmask for an address.
//
// +stateify savable
type AddressMask struct {
mask [16]byte
length int
}
// MaskFrom returns a Mask based on str.
//
// MaskFrom may allocate, and so should not be in hot paths.
func MaskFrom(str string) AddressMask {
mask := AddressMask{length: len(str)}
copy(mask.mask[:], str)
return mask
}
// MaskFromBytes returns a Mask based on bs.
func MaskFromBytes(bs []byte) AddressMask {
mask := AddressMask{length: len(bs)}
copy(mask.mask[:], bs)
return mask
}
// String implements Stringer.
func (m AddressMask) String() string {
return fmt.Sprintf("%x", m.mask)
}
// AsSlice returns a as a byte slice. Callers should be careful as it can
// return a window into existing memory.
func (m *AddressMask) AsSlice() []byte {
return []byte(m.mask[:m.length])
}
// BitLen returns the length of the mask in bits.
func (m AddressMask) BitLen() int {
return m.length * 8
}
// Len returns the length of the mask in bytes.
func (m AddressMask) Len() int {
return m.length
}
// Prefix returns the number of bits before the first host bit.
func (m AddressMask) Prefix() int {
p := 0
for _, b := range m.mask[:m.length] {
p += bits.LeadingZeros8(^b)
}
return p
}
// Equal returns whether m and other are equal. It exists for use by the cmp
// library.
func (m AddressMask) Equal(other AddressMask) bool {
return m == other
}
// Subnet is a subnet defined by its address and mask.
//
// +stateify savable
type Subnet struct {
address Address
mask AddressMask
}
// NewSubnet creates a new Subnet, checking that the address and mask are the same length.
func NewSubnet(a Address, m AddressMask) (Subnet, error) {
if a.Len() != m.Len() {
return Subnet{}, errSubnetLengthMismatch
}
for i := 0; i < a.Len(); i++ {
if a.addr[i]&^m.mask[i] != 0 {
return Subnet{}, errSubnetAddressMasked
}
}
return Subnet{a, m}, nil
}
// String implements Stringer.
func (s Subnet) String() string {
return fmt.Sprintf("%s/%d", s.ID(), s.Prefix())
}
// Contains returns true iff the address is of the same length and matches the
// subnet address and mask.
func (s *Subnet) Contains(a Address) bool {
if a.Len() != s.address.Len() {
return false
}
for i := 0; i < a.Len(); i++ {
if a.addr[i]&s.mask.mask[i] != s.address.addr[i] {
return false
}
}
return true
}
// ID returns the subnet ID.
func (s *Subnet) ID() Address {
return s.address
}
// Bits returns the number of ones (network bits) and zeros (host bits) in the
// subnet mask.
func (s *Subnet) Bits() (ones int, zeros int) {
ones = s.mask.Prefix()
return ones, s.mask.BitLen() - ones
}
// Prefix returns the number of bits before the first host bit.
func (s *Subnet) Prefix() int {
return s.mask.Prefix()
}
// Mask returns the subnet mask.
func (s *Subnet) Mask() AddressMask {
return s.mask
}
// Broadcast returns the subnet's broadcast address.
func (s *Subnet) Broadcast() Address {
addrCopy := s.address
for i := 0; i < addrCopy.Len(); i++ {
addrCopy.addr[i] |= ^s.mask.mask[i]
}
return addrCopy
}
// IsBroadcast returns true if the address is considered a broadcast address.
func (s *Subnet) IsBroadcast(address Address) bool {
// Only IPv4 supports the notion of a broadcast address.
if address.Len() != ipv4AddressSize {
return false
}
// Normally, we would just compare address with the subnet's broadcast
// address but there is an exception where a simple comparison is not
// correct. This exception is for /31 and /32 IPv4 subnets where all
// addresses are considered valid host addresses.
//
// For /31 subnets, the case is easy. RFC 3021 Section 2.1 states that
// both addresses in a /31 subnet "MUST be interpreted as host addresses."
//
// For /32, the case is a bit more vague. RFC 3021 makes no mention of /32
// subnets. However, the same reasoning applies - if an exception is not
// made, then there do not exist any host addresses in a /32 subnet. RFC
// 4632 Section 3.1 also vaguely implies this interpretation by referring
// to addresses in /32 subnets as "host routes."
return s.Prefix() <= 30 && s.Broadcast() == address
}
// Equal returns true if this Subnet is equal to the given Subnet.
func (s Subnet) Equal(o Subnet) bool {
// If this changes, update Route.Equal accordingly.
return s == o
}
// LinkAddress is a byte slice cast as a string that represents a link address.
// It is typically a 6-byte MAC address.
type LinkAddress string
// String implements the fmt.Stringer interface.
func (a LinkAddress) String() string {
switch len(a) {
case 6:
return fmt.Sprintf("%02x:%02x:%02x:%02x:%02x:%02x", a[0], a[1], a[2], a[3], a[4], a[5])
default:
return fmt.Sprintf("%x", []byte(a))
}
}
// ParseMACAddress parses an IEEE 802 address.
//
// It must be in the format aa:bb:cc:dd:ee:ff or aa-bb-cc-dd-ee-ff.
func ParseMACAddress(s string) (LinkAddress, error) {
parts := strings.FieldsFunc(s, func(c rune) bool {
return c == ':' || c == '-'
})
if len(parts) != LinkAddressSize {
return "", fmt.Errorf("inconsistent parts: %s", s)
}
addr := make([]byte, 0, len(parts))
for _, part := range parts {
u, err := strconv.ParseUint(part, 16, 8)
if err != nil {
return "", fmt.Errorf("invalid hex digits: %s", s)
}
addr = append(addr, byte(u))
}
return LinkAddress(addr), nil
}
// GetRandMacAddr returns a mac address that can be used for local virtual devices.
func GetRandMacAddr() LinkAddress {
mac := make(net.HardwareAddr, LinkAddressSize)
rand.Read(mac) // Fill with random data.
mac[0] &^= 0x1 // Clear multicast bit.
mac[0] |= 0x2 // Set local assignment bit (IEEE802).
return LinkAddress(mac)
}
// AddressWithPrefix is an address with its subnet prefix length.
//
// +stateify savable
type AddressWithPrefix struct {
// Address is a network address.
Address Address
// PrefixLen is the subnet prefix length.
PrefixLen int
}
// String implements the fmt.Stringer interface.
func (a AddressWithPrefix) String() string {
return fmt.Sprintf("%s/%d", a.Address, a.PrefixLen)
}
// Subnet converts the address and prefix into a Subnet value and returns it.
func (a AddressWithPrefix) Subnet() Subnet {
addrLen := a.Address.length
if a.PrefixLen <= 0 {
return Subnet{
address: Address{length: addrLen},
mask: AddressMask{length: addrLen},
}
}
if a.PrefixLen >= addrLen*8 {
sub := Subnet{
address: a.Address,
mask: AddressMask{length: addrLen},
}
for i := 0; i < addrLen; i++ {
sub.mask.mask[i] = 0xff
}
return sub
}
sa := Address{length: addrLen}
sm := AddressMask{length: addrLen}
n := uint(a.PrefixLen)
for i := 0; i < addrLen; i++ {
if n >= 8 {
sa.addr[i] = a.Address.addr[i]
sm.mask[i] = 0xff
n -= 8
continue
}
sm.mask[i] = ^byte(0xff >> n)
sa.addr[i] = a.Address.addr[i] & sm.mask[i]
n = 0
}
// For extra caution, call NewSubnet rather than directly creating the Subnet
// value. If that fails it indicates a serious bug in this code, so panic is
// in order.
s, err := NewSubnet(sa, sm)
if err != nil {
panic("invalid subnet: " + err.Error())
}
return s
}

View File

@@ -1,712 +0,0 @@
package tschecksum
import (
"encoding/binary"
"math/bits"
"strconv"
"golang.org/x/sys/cpu"
)
// checksumGeneric64 is a reference implementation of checksum using 64 bit
// arithmetic for use in testing or when an architecture-specific implementation
// is not available.
func checksumGeneric64(b []byte, initial uint16) uint16 {
var ac uint64
var carry uint64
if cpu.IsBigEndian {
ac = uint64(initial)
} else {
ac = uint64(bits.ReverseBytes16(initial))
}
for len(b) >= 128 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[32:40]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[40:48]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[48:56]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[56:64]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[64:72]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[72:80]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[80:88]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[88:96]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[96:104]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[104:112]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[112:120]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[120:128]), carry)
} else {
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[32:40]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[40:48]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[48:56]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[56:64]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[64:72]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[72:80]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[80:88]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[88:96]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[96:104]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[104:112]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[112:120]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[120:128]), carry)
}
b = b[128:]
}
if len(b) >= 64 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[32:40]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[40:48]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[48:56]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[56:64]), carry)
} else {
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[32:40]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[40:48]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[48:56]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[56:64]), carry)
}
b = b[64:]
}
if len(b) >= 32 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[24:32]), carry)
} else {
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[16:24]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[24:32]), carry)
}
b = b[32:]
}
if len(b) >= 16 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b[8:16]), carry)
} else {
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[:8]), carry)
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b[8:16]), carry)
}
b = b[16:]
}
if len(b) >= 8 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, binary.BigEndian.Uint64(b), carry)
} else {
ac, carry = bits.Add64(ac, binary.LittleEndian.Uint64(b), carry)
}
b = b[8:]
}
if len(b) >= 4 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, uint64(binary.BigEndian.Uint32(b)), carry)
} else {
ac, carry = bits.Add64(ac, uint64(binary.LittleEndian.Uint32(b)), carry)
}
b = b[4:]
}
if len(b) >= 2 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, uint64(binary.BigEndian.Uint16(b)), carry)
} else {
ac, carry = bits.Add64(ac, uint64(binary.LittleEndian.Uint16(b)), carry)
}
b = b[2:]
}
if len(b) >= 1 {
if cpu.IsBigEndian {
ac, carry = bits.Add64(ac, uint64(b[0])<<8, carry)
} else {
ac, carry = bits.Add64(ac, uint64(b[0]), carry)
}
}
folded := ipChecksumFold64(ac, carry)
if !cpu.IsBigEndian {
folded = bits.ReverseBytes16(folded)
}
return folded
}
// checksumGeneric32 is a reference implementation of checksum using 32 bit
// arithmetic for use in testing or when an architecture-specific implementation
// is not available.
func checksumGeneric32(b []byte, initial uint16) uint16 {
var ac uint32
var carry uint32
if cpu.IsBigEndian {
ac = uint32(initial)
} else {
ac = uint32(bits.ReverseBytes16(initial))
}
for len(b) >= 64 {
if cpu.IsBigEndian {
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:8]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[16:20]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[20:24]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[24:28]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[28:32]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[32:36]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[36:40]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[40:44]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[44:48]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[48:52]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[52:56]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[56:60]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[60:64]), carry)
} else {
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:8]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[16:20]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[20:24]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[24:28]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[28:32]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[32:36]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[36:40]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[40:44]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[44:48]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[48:52]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[52:56]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[56:60]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[60:64]), carry)
}
b = b[64:]
}
if len(b) >= 32 {
if cpu.IsBigEndian {
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[16:20]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[20:24]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[24:28]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[28:32]), carry)
} else {
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[16:20]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[20:24]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[24:28]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[28:32]), carry)
}
b = b[32:]
}
if len(b) >= 16 {
if cpu.IsBigEndian {
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[8:12]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[12:16]), carry)
} else {
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[8:12]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[12:16]), carry)
}
b = b[16:]
}
if len(b) >= 8 {
if cpu.IsBigEndian {
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[:4]), carry)
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b[4:8]), carry)
} else {
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[:4]), carry)
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b[4:8]), carry)
}
b = b[8:]
}
if len(b) >= 4 {
if cpu.IsBigEndian {
ac, carry = bits.Add32(ac, binary.BigEndian.Uint32(b), carry)
} else {
ac, carry = bits.Add32(ac, binary.LittleEndian.Uint32(b), carry)
}
b = b[4:]
}
if len(b) >= 2 {
if cpu.IsBigEndian {
ac, carry = bits.Add32(ac, uint32(binary.BigEndian.Uint16(b)), carry)
} else {
ac, carry = bits.Add32(ac, uint32(binary.LittleEndian.Uint16(b)), carry)
}
b = b[2:]
}
if len(b) >= 1 {
if cpu.IsBigEndian {
ac, carry = bits.Add32(ac, uint32(b[0])<<8, carry)
} else {
ac, carry = bits.Add32(ac, uint32(b[0]), carry)
}
}
folded := ipChecksumFold32(ac, carry)
if !cpu.IsBigEndian {
folded = bits.ReverseBytes16(folded)
}
return folded
}
// checksumGeneric32Alternate is an alternate reference implementation of
// checksum using 32 bit arithmetic for use in testing or when an
// architecture-specific implementation is not available.
func checksumGeneric32Alternate(b []byte, initial uint16) uint16 {
var ac uint32
if cpu.IsBigEndian {
ac = uint32(initial)
} else {
ac = uint32(bits.ReverseBytes16(initial))
}
for len(b) >= 64 {
if cpu.IsBigEndian {
ac += uint32(binary.BigEndian.Uint16(b[:2]))
ac += uint32(binary.BigEndian.Uint16(b[2:4]))
ac += uint32(binary.BigEndian.Uint16(b[4:6]))
ac += uint32(binary.BigEndian.Uint16(b[6:8]))
ac += uint32(binary.BigEndian.Uint16(b[8:10]))
ac += uint32(binary.BigEndian.Uint16(b[10:12]))
ac += uint32(binary.BigEndian.Uint16(b[12:14]))
ac += uint32(binary.BigEndian.Uint16(b[14:16]))
ac += uint32(binary.BigEndian.Uint16(b[16:18]))
ac += uint32(binary.BigEndian.Uint16(b[18:20]))
ac += uint32(binary.BigEndian.Uint16(b[20:22]))
ac += uint32(binary.BigEndian.Uint16(b[22:24]))
ac += uint32(binary.BigEndian.Uint16(b[24:26]))
ac += uint32(binary.BigEndian.Uint16(b[26:28]))
ac += uint32(binary.BigEndian.Uint16(b[28:30]))
ac += uint32(binary.BigEndian.Uint16(b[30:32]))
ac += uint32(binary.BigEndian.Uint16(b[32:34]))
ac += uint32(binary.BigEndian.Uint16(b[34:36]))
ac += uint32(binary.BigEndian.Uint16(b[36:38]))
ac += uint32(binary.BigEndian.Uint16(b[38:40]))
ac += uint32(binary.BigEndian.Uint16(b[40:42]))
ac += uint32(binary.BigEndian.Uint16(b[42:44]))
ac += uint32(binary.BigEndian.Uint16(b[44:46]))
ac += uint32(binary.BigEndian.Uint16(b[46:48]))
ac += uint32(binary.BigEndian.Uint16(b[48:50]))
ac += uint32(binary.BigEndian.Uint16(b[50:52]))
ac += uint32(binary.BigEndian.Uint16(b[52:54]))
ac += uint32(binary.BigEndian.Uint16(b[54:56]))
ac += uint32(binary.BigEndian.Uint16(b[56:58]))
ac += uint32(binary.BigEndian.Uint16(b[58:60]))
ac += uint32(binary.BigEndian.Uint16(b[60:62]))
ac += uint32(binary.BigEndian.Uint16(b[62:64]))
} else {
ac += uint32(binary.LittleEndian.Uint16(b[:2]))
ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
ac += uint32(binary.LittleEndian.Uint16(b[8:10]))
ac += uint32(binary.LittleEndian.Uint16(b[10:12]))
ac += uint32(binary.LittleEndian.Uint16(b[12:14]))
ac += uint32(binary.LittleEndian.Uint16(b[14:16]))
ac += uint32(binary.LittleEndian.Uint16(b[16:18]))
ac += uint32(binary.LittleEndian.Uint16(b[18:20]))
ac += uint32(binary.LittleEndian.Uint16(b[20:22]))
ac += uint32(binary.LittleEndian.Uint16(b[22:24]))
ac += uint32(binary.LittleEndian.Uint16(b[24:26]))
ac += uint32(binary.LittleEndian.Uint16(b[26:28]))
ac += uint32(binary.LittleEndian.Uint16(b[28:30]))
ac += uint32(binary.LittleEndian.Uint16(b[30:32]))
ac += uint32(binary.LittleEndian.Uint16(b[32:34]))
ac += uint32(binary.LittleEndian.Uint16(b[34:36]))
ac += uint32(binary.LittleEndian.Uint16(b[36:38]))
ac += uint32(binary.LittleEndian.Uint16(b[38:40]))
ac += uint32(binary.LittleEndian.Uint16(b[40:42]))
ac += uint32(binary.LittleEndian.Uint16(b[42:44]))
ac += uint32(binary.LittleEndian.Uint16(b[44:46]))
ac += uint32(binary.LittleEndian.Uint16(b[46:48]))
ac += uint32(binary.LittleEndian.Uint16(b[48:50]))
ac += uint32(binary.LittleEndian.Uint16(b[50:52]))
ac += uint32(binary.LittleEndian.Uint16(b[52:54]))
ac += uint32(binary.LittleEndian.Uint16(b[54:56]))
ac += uint32(binary.LittleEndian.Uint16(b[56:58]))
ac += uint32(binary.LittleEndian.Uint16(b[58:60]))
ac += uint32(binary.LittleEndian.Uint16(b[60:62]))
ac += uint32(binary.LittleEndian.Uint16(b[62:64]))
}
b = b[64:]
}
if len(b) >= 32 {
if cpu.IsBigEndian {
ac += uint32(binary.BigEndian.Uint16(b[:2]))
ac += uint32(binary.BigEndian.Uint16(b[2:4]))
ac += uint32(binary.BigEndian.Uint16(b[4:6]))
ac += uint32(binary.BigEndian.Uint16(b[6:8]))
ac += uint32(binary.BigEndian.Uint16(b[8:10]))
ac += uint32(binary.BigEndian.Uint16(b[10:12]))
ac += uint32(binary.BigEndian.Uint16(b[12:14]))
ac += uint32(binary.BigEndian.Uint16(b[14:16]))
ac += uint32(binary.BigEndian.Uint16(b[16:18]))
ac += uint32(binary.BigEndian.Uint16(b[18:20]))
ac += uint32(binary.BigEndian.Uint16(b[20:22]))
ac += uint32(binary.BigEndian.Uint16(b[22:24]))
ac += uint32(binary.BigEndian.Uint16(b[24:26]))
ac += uint32(binary.BigEndian.Uint16(b[26:28]))
ac += uint32(binary.BigEndian.Uint16(b[28:30]))
ac += uint32(binary.BigEndian.Uint16(b[30:32]))
} else {
ac += uint32(binary.LittleEndian.Uint16(b[:2]))
ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
ac += uint32(binary.LittleEndian.Uint16(b[8:10]))
ac += uint32(binary.LittleEndian.Uint16(b[10:12]))
ac += uint32(binary.LittleEndian.Uint16(b[12:14]))
ac += uint32(binary.LittleEndian.Uint16(b[14:16]))
ac += uint32(binary.LittleEndian.Uint16(b[16:18]))
ac += uint32(binary.LittleEndian.Uint16(b[18:20]))
ac += uint32(binary.LittleEndian.Uint16(b[20:22]))
ac += uint32(binary.LittleEndian.Uint16(b[22:24]))
ac += uint32(binary.LittleEndian.Uint16(b[24:26]))
ac += uint32(binary.LittleEndian.Uint16(b[26:28]))
ac += uint32(binary.LittleEndian.Uint16(b[28:30]))
ac += uint32(binary.LittleEndian.Uint16(b[30:32]))
}
b = b[32:]
}
if len(b) >= 16 {
if cpu.IsBigEndian {
ac += uint32(binary.BigEndian.Uint16(b[:2]))
ac += uint32(binary.BigEndian.Uint16(b[2:4]))
ac += uint32(binary.BigEndian.Uint16(b[4:6]))
ac += uint32(binary.BigEndian.Uint16(b[6:8]))
ac += uint32(binary.BigEndian.Uint16(b[8:10]))
ac += uint32(binary.BigEndian.Uint16(b[10:12]))
ac += uint32(binary.BigEndian.Uint16(b[12:14]))
ac += uint32(binary.BigEndian.Uint16(b[14:16]))
} else {
ac += uint32(binary.LittleEndian.Uint16(b[:2]))
ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
ac += uint32(binary.LittleEndian.Uint16(b[8:10]))
ac += uint32(binary.LittleEndian.Uint16(b[10:12]))
ac += uint32(binary.LittleEndian.Uint16(b[12:14]))
ac += uint32(binary.LittleEndian.Uint16(b[14:16]))
}
b = b[16:]
}
if len(b) >= 8 {
if cpu.IsBigEndian {
ac += uint32(binary.BigEndian.Uint16(b[:2]))
ac += uint32(binary.BigEndian.Uint16(b[2:4]))
ac += uint32(binary.BigEndian.Uint16(b[4:6]))
ac += uint32(binary.BigEndian.Uint16(b[6:8]))
} else {
ac += uint32(binary.LittleEndian.Uint16(b[:2]))
ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
ac += uint32(binary.LittleEndian.Uint16(b[4:6]))
ac += uint32(binary.LittleEndian.Uint16(b[6:8]))
}
b = b[8:]
}
if len(b) >= 4 {
if cpu.IsBigEndian {
ac += uint32(binary.BigEndian.Uint16(b[:2]))
ac += uint32(binary.BigEndian.Uint16(b[2:4]))
} else {
ac += uint32(binary.LittleEndian.Uint16(b[:2]))
ac += uint32(binary.LittleEndian.Uint16(b[2:4]))
}
b = b[4:]
}
if len(b) >= 2 {
if cpu.IsBigEndian {
ac += uint32(binary.BigEndian.Uint16(b))
} else {
ac += uint32(binary.LittleEndian.Uint16(b))
}
b = b[2:]
}
if len(b) >= 1 {
if cpu.IsBigEndian {
ac += uint32(b[0]) << 8
} else {
ac += uint32(b[0])
}
}
folded := ipChecksumFold32(ac, 0)
if !cpu.IsBigEndian {
folded = bits.ReverseBytes16(folded)
}
return folded
}
// checksumGeneric64Alternate is an alternate reference implementation of
// checksum using 64 bit arithmetic for use in testing or when an
// architecture-specific implementation is not available.
func checksumGeneric64Alternate(b []byte, initial uint16) uint16 {
var ac uint64
if cpu.IsBigEndian {
ac = uint64(initial)
} else {
ac = uint64(bits.ReverseBytes16(initial))
}
for len(b) >= 64 {
if cpu.IsBigEndian {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
ac += uint64(binary.BigEndian.Uint32(b[32:36]))
ac += uint64(binary.BigEndian.Uint32(b[36:40]))
ac += uint64(binary.BigEndian.Uint32(b[40:44]))
ac += uint64(binary.BigEndian.Uint32(b[44:48]))
ac += uint64(binary.BigEndian.Uint32(b[48:52]))
ac += uint64(binary.BigEndian.Uint32(b[52:56]))
ac += uint64(binary.BigEndian.Uint32(b[56:60]))
ac += uint64(binary.BigEndian.Uint32(b[60:64]))
} else {
ac += uint64(binary.LittleEndian.Uint32(b[:4]))
ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
ac += uint64(binary.LittleEndian.Uint32(b[8:12]))
ac += uint64(binary.LittleEndian.Uint32(b[12:16]))
ac += uint64(binary.LittleEndian.Uint32(b[16:20]))
ac += uint64(binary.LittleEndian.Uint32(b[20:24]))
ac += uint64(binary.LittleEndian.Uint32(b[24:28]))
ac += uint64(binary.LittleEndian.Uint32(b[28:32]))
ac += uint64(binary.LittleEndian.Uint32(b[32:36]))
ac += uint64(binary.LittleEndian.Uint32(b[36:40]))
ac += uint64(binary.LittleEndian.Uint32(b[40:44]))
ac += uint64(binary.LittleEndian.Uint32(b[44:48]))
ac += uint64(binary.LittleEndian.Uint32(b[48:52]))
ac += uint64(binary.LittleEndian.Uint32(b[52:56]))
ac += uint64(binary.LittleEndian.Uint32(b[56:60]))
ac += uint64(binary.LittleEndian.Uint32(b[60:64]))
}
b = b[64:]
}
if len(b) >= 32 {
if cpu.IsBigEndian {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
ac += uint64(binary.BigEndian.Uint32(b[16:20]))
ac += uint64(binary.BigEndian.Uint32(b[20:24]))
ac += uint64(binary.BigEndian.Uint32(b[24:28]))
ac += uint64(binary.BigEndian.Uint32(b[28:32]))
} else {
ac += uint64(binary.LittleEndian.Uint32(b[:4]))
ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
ac += uint64(binary.LittleEndian.Uint32(b[8:12]))
ac += uint64(binary.LittleEndian.Uint32(b[12:16]))
ac += uint64(binary.LittleEndian.Uint32(b[16:20]))
ac += uint64(binary.LittleEndian.Uint32(b[20:24]))
ac += uint64(binary.LittleEndian.Uint32(b[24:28]))
ac += uint64(binary.LittleEndian.Uint32(b[28:32]))
}
b = b[32:]
}
if len(b) >= 16 {
if cpu.IsBigEndian {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
ac += uint64(binary.BigEndian.Uint32(b[8:12]))
ac += uint64(binary.BigEndian.Uint32(b[12:16]))
} else {
ac += uint64(binary.LittleEndian.Uint32(b[:4]))
ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
ac += uint64(binary.LittleEndian.Uint32(b[8:12]))
ac += uint64(binary.LittleEndian.Uint32(b[12:16]))
}
b = b[16:]
}
if len(b) >= 8 {
if cpu.IsBigEndian {
ac += uint64(binary.BigEndian.Uint32(b[:4]))
ac += uint64(binary.BigEndian.Uint32(b[4:8]))
} else {
ac += uint64(binary.LittleEndian.Uint32(b[:4]))
ac += uint64(binary.LittleEndian.Uint32(b[4:8]))
}
b = b[8:]
}
if len(b) >= 4 {
if cpu.IsBigEndian {
ac += uint64(binary.BigEndian.Uint32(b))
} else {
ac += uint64(binary.LittleEndian.Uint32(b))
}
b = b[4:]
}
if len(b) >= 2 {
if cpu.IsBigEndian {
ac += uint64(binary.BigEndian.Uint16(b))
} else {
ac += uint64(binary.LittleEndian.Uint16(b))
}
b = b[2:]
}
if len(b) >= 1 {
if cpu.IsBigEndian {
ac += uint64(b[0]) << 8
} else {
ac += uint64(b[0])
}
}
folded := ipChecksumFold64(ac, 0)
if !cpu.IsBigEndian {
folded = bits.ReverseBytes16(folded)
}
return folded
}
func ipChecksumFold64(unfolded uint64, initialCarry uint64) uint16 {
sum, carry := bits.Add32(uint32(unfolded>>32), uint32(unfolded&0xffff_ffff), uint32(initialCarry))
// if carry != 0, sum <= 0xffff_fffe, otherwise sum <= 0xffff_ffff
// therefore (sum >> 16) + (sum & 0xffff) + carry <= 0x1_fffe; so there is
// no need to save the carry flag
sum = (sum >> 16) + (sum & 0xffff) + carry
// sum <= 0x1_fffe therefore this is the last fold needed:
// if (sum >> 16) > 0 then
// (sum >> 16) == 1 && (sum & 0xffff) <= 0xfffe and therefore
// the addition will not overflow
// otherwise (sum >> 16) == 0 and sum will be unchanged
sum = (sum >> 16) + (sum & 0xffff)
return uint16(sum)
}
func ipChecksumFold32(unfolded uint32, initialCarry uint32) uint16 {
sum := (unfolded >> 16) + (unfolded & 0xffff) + initialCarry
// sum <= 0x1_ffff:
// 0xffff + 0xffff = 0x1_fffe
// initialCarry is 0 or 1, for a combined maximum of 0x1_ffff
sum = (sum >> 16) + (sum & 0xffff)
// sum <= 0x1_0000 therefore this is the last fold needed:
// if (sum >> 16) > 0 then
// (sum >> 16) == 1 && (sum & 0xffff) == 0 and therefore
// the addition will not overflow
// otherwise (sum >> 16) == 0 and sum will be unchanged
sum = (sum >> 16) + (sum & 0xffff)
return uint16(sum)
}
func addrPartialChecksum64(addr []byte, initial, carryIn uint64) (sum, carry uint64) {
sum, carry = initial, carryIn
switch len(addr) {
case 4: // IPv4
if cpu.IsBigEndian {
sum, carry = bits.Add64(sum, uint64(binary.BigEndian.Uint32(addr)), carry)
} else {
sum, carry = bits.Add64(sum, uint64(binary.LittleEndian.Uint32(addr)), carry)
}
case 16: // IPv6
if cpu.IsBigEndian {
sum, carry = bits.Add64(sum, binary.BigEndian.Uint64(addr), carry)
sum, carry = bits.Add64(sum, binary.BigEndian.Uint64(addr[8:]), carry)
} else {
sum, carry = bits.Add64(sum, binary.LittleEndian.Uint64(addr), carry)
sum, carry = bits.Add64(sum, binary.LittleEndian.Uint64(addr[8:]), carry)
}
default:
panic("bad addr length")
}
return sum, carry
}
func addrPartialChecksum32(addr []byte, initial, carryIn uint32) (sum, carry uint32) {
sum, carry = initial, carryIn
switch len(addr) {
case 4: // IPv4
if cpu.IsBigEndian {
sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr), carry)
} else {
sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr), carry)
}
case 16: // IPv6
if cpu.IsBigEndian {
sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr), carry)
sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[4:8]), carry)
sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[8:12]), carry)
sum, carry = bits.Add32(sum, binary.BigEndian.Uint32(addr[12:16]), carry)
} else {
sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr), carry)
sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[4:8]), carry)
sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[8:12]), carry)
sum, carry = bits.Add32(sum, binary.LittleEndian.Uint32(addr[12:16]), carry)
}
default:
panic("bad addr length")
}
return sum, carry
}
func pseudoHeaderChecksum64(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
var sum uint64
if cpu.IsBigEndian {
sum = uint64(totalLen) + uint64(protocol)
} else {
sum = uint64(bits.ReverseBytes16(totalLen)) + uint64(protocol)<<8
}
sum, carry := addrPartialChecksum64(srcAddr, sum, 0)
sum, carry = addrPartialChecksum64(dstAddr, sum, carry)
foldedSum := ipChecksumFold64(sum, carry)
if !cpu.IsBigEndian {
foldedSum = bits.ReverseBytes16(foldedSum)
}
return foldedSum
}
func pseudoHeaderChecksum32(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
var sum uint32
if cpu.IsBigEndian {
sum = uint32(totalLen) + uint32(protocol)
} else {
sum = uint32(bits.ReverseBytes16(totalLen)) + uint32(protocol)<<8
}
sum, carry := addrPartialChecksum32(srcAddr, sum, 0)
sum, carry = addrPartialChecksum32(dstAddr, sum, carry)
foldedSum := ipChecksumFold32(sum, carry)
if !cpu.IsBigEndian {
foldedSum = bits.ReverseBytes16(foldedSum)
}
return foldedSum
}
// PseudoHeaderChecksum computes an IP pseudo-header checksum. srcAddr and
// dstAddr must be 4 or 16 bytes in length.
func PseudoHeaderChecksum(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint16 {
if strconv.IntSize < 64 {
return pseudoHeaderChecksum32(protocol, srcAddr, dstAddr, totalLen)
}
return pseudoHeaderChecksum64(protocol, srcAddr, dstAddr, totalLen)
}

View File

@@ -1,23 +0,0 @@
package tschecksum
import "golang.org/x/sys/cpu"
var checksum = checksumAMD64
// Checksum computes an IP checksum starting with the provided initial value.
// The length of data should be at least 128 bytes for best performance. Smaller
// buffers will still compute a correct result.
func Checksum(data []byte, initial uint16) uint16 {
return checksum(data, initial)
}
func init() {
if cpu.X86.HasAVX && cpu.X86.HasAVX2 && cpu.X86.HasBMI2 {
checksum = checksumAVX2
return
}
if cpu.X86.HasSSE2 {
checksum = checksumSSE2
return
}
}

View File

@@ -1,18 +0,0 @@
// Code generated by command: go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go. DO NOT EDIT.
package tschecksum
// checksumAVX2 computes an IP checksum using amd64 v3 instructions (AVX2, BMI2)
//
//go:noescape
func checksumAVX2(b []byte, initial uint16) uint16
// checksumSSE2 computes an IP checksum using amd64 baseline instructions (SSE2)
//
//go:noescape
func checksumSSE2(b []byte, initial uint16) uint16
// checksumAMD64 computes an IP checksum using amd64 baseline instructions
//
//go:noescape
func checksumAMD64(b []byte, initial uint16) uint16

View File

@@ -1,851 +0,0 @@
// Code generated by command: go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go. DO NOT EDIT.
#include "textflag.h"
DATA xmmLoadMasks<>+0(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff"
DATA xmmLoadMasks<>+16(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff"
DATA xmmLoadMasks<>+32(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff"
DATA xmmLoadMasks<>+48(SB)/16, $"\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff"
DATA xmmLoadMasks<>+64(SB)/16, $"\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
DATA xmmLoadMasks<>+80(SB)/16, $"\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
DATA xmmLoadMasks<>+96(SB)/16, $"\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff"
GLOBL xmmLoadMasks<>(SB), RODATA|NOPTR, $112
// func checksumAVX2(b []byte, initial uint16) uint16
// Requires: AVX, AVX2, BMI2
TEXT ·checksumAVX2(SB), NOSPLIT|NOFRAME, $0-34
MOVWQZX initial+24(FP), AX
XCHGB AH, AL
MOVQ b_base+0(FP), DX
MOVQ b_len+8(FP), BX
// handle odd length buffers; they are difficult to handle in general
TESTQ $0x00000001, BX
JZ lengthIsEven
MOVBQZX -1(DX)(BX*1), CX
DECQ BX
ADDQ CX, AX
lengthIsEven:
// handle tiny buffers (<=31 bytes) specially
CMPQ BX, $0x1f
JGT bufferIsNotTiny
XORQ CX, CX
XORQ SI, SI
XORQ DI, DI
// shift twice to start because length is guaranteed to be even
// n = n >> 2; CF = originalN & 2
SHRQ $0x02, BX
JNC handleTiny4
// tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:]
MOVWQZX (DX), CX
ADDQ $0x02, DX
handleTiny4:
// n = n >> 1; CF = originalN & 4
SHRQ $0x01, BX
JNC handleTiny8
// tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:]
MOVLQZX (DX), SI
ADDQ $0x04, DX
handleTiny8:
// n = n >> 1; CF = originalN & 8
SHRQ $0x01, BX
JNC handleTiny16
// tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:]
MOVQ (DX), DI
ADDQ $0x08, DX
handleTiny16:
// n = n >> 1; CF = originalN & 16
// n == 0 now, otherwise we would have branched after comparing with tinyBufferSize
SHRQ $0x01, BX
JNC handleTinyFinish
ADDQ (DX), AX
ADCQ 8(DX), AX
handleTinyFinish:
// CF should be included from the previous add, so we use ADCQ.
// If we arrived via the JNC above, then CF=0 due to the branch condition,
// so ADCQ will still produce the correct result.
ADCQ CX, AX
ADCQ SI, AX
ADCQ DI, AX
JMP foldAndReturn
bufferIsNotTiny:
// skip all SIMD for small buffers
CMPQ BX, $0x00000100
JGE startSIMD
// Accumulate carries in this register. It is never expected to overflow.
XORQ SI, SI
// We will perform an overlapped read for buffers with length not a multiple of 8.
// Overlapped in this context means some memory will be read twice, but a shift will
// eliminate the duplicated data. This extra read is performed at the end of the buffer to
// preserve any alignment that may exist for the start of the buffer.
MOVQ BX, CX
SHRQ $0x03, BX
ANDQ $0x07, CX
JZ handleRemaining8
LEAQ (DX)(BX*8), DI
MOVQ -8(DI)(CX*1), DI
// Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8)
SHLQ $0x03, CX
NEGQ CX
ADDQ $0x40, CX
SHRQ CL, DI
ADDQ DI, AX
ADCQ $0x00, SI
handleRemaining8:
SHRQ $0x01, BX
JNC handleRemaining16
ADDQ (DX), AX
ADCQ $0x00, SI
ADDQ $0x08, DX
handleRemaining16:
SHRQ $0x01, BX
JNC handleRemaining32
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ $0x00, SI
ADDQ $0x10, DX
handleRemaining32:
SHRQ $0x01, BX
JNC handleRemaining64
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ $0x00, SI
ADDQ $0x20, DX
handleRemaining64:
SHRQ $0x01, BX
JNC handleRemaining128
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ 32(DX), AX
ADCQ 40(DX), AX
ADCQ 48(DX), AX
ADCQ 56(DX), AX
ADCQ $0x00, SI
ADDQ $0x40, DX
handleRemaining128:
SHRQ $0x01, BX
JNC handleRemainingComplete
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ 32(DX), AX
ADCQ 40(DX), AX
ADCQ 48(DX), AX
ADCQ 56(DX), AX
ADCQ 64(DX), AX
ADCQ 72(DX), AX
ADCQ 80(DX), AX
ADCQ 88(DX), AX
ADCQ 96(DX), AX
ADCQ 104(DX), AX
ADCQ 112(DX), AX
ADCQ 120(DX), AX
ADCQ $0x00, SI
ADDQ $0x80, DX
handleRemainingComplete:
ADDQ SI, AX
JMP foldAndReturn
startSIMD:
VPXOR Y0, Y0, Y0
VPXOR Y1, Y1, Y1
VPXOR Y2, Y2, Y2
VPXOR Y3, Y3, Y3
MOVQ BX, CX
// Update number of bytes remaining after the loop completes
ANDQ $0xff, BX
// Number of 256 byte iterations
SHRQ $0x08, CX
JZ smallLoop
bigLoop:
VPMOVZXWD (DX), Y4
VPADDD Y4, Y0, Y0
VPMOVZXWD 16(DX), Y4
VPADDD Y4, Y1, Y1
VPMOVZXWD 32(DX), Y4
VPADDD Y4, Y2, Y2
VPMOVZXWD 48(DX), Y4
VPADDD Y4, Y3, Y3
VPMOVZXWD 64(DX), Y4
VPADDD Y4, Y0, Y0
VPMOVZXWD 80(DX), Y4
VPADDD Y4, Y1, Y1
VPMOVZXWD 96(DX), Y4
VPADDD Y4, Y2, Y2
VPMOVZXWD 112(DX), Y4
VPADDD Y4, Y3, Y3
VPMOVZXWD 128(DX), Y4
VPADDD Y4, Y0, Y0
VPMOVZXWD 144(DX), Y4
VPADDD Y4, Y1, Y1
VPMOVZXWD 160(DX), Y4
VPADDD Y4, Y2, Y2
VPMOVZXWD 176(DX), Y4
VPADDD Y4, Y3, Y3
VPMOVZXWD 192(DX), Y4
VPADDD Y4, Y0, Y0
VPMOVZXWD 208(DX), Y4
VPADDD Y4, Y1, Y1
VPMOVZXWD 224(DX), Y4
VPADDD Y4, Y2, Y2
VPMOVZXWD 240(DX), Y4
VPADDD Y4, Y3, Y3
ADDQ $0x00000100, DX
DECQ CX
JNZ bigLoop
CMPQ BX, $0x10
JLT doneSmallLoop
// now read a single 16 byte unit of data at a time
smallLoop:
VPMOVZXWD (DX), Y4
VPADDD Y4, Y0, Y0
ADDQ $0x10, DX
SUBQ $0x10, BX
CMPQ BX, $0x10
JGE smallLoop
doneSmallLoop:
CMPQ BX, $0x00
JE doneSIMD
// There are between 1 and 15 bytes remaining. Perform an overlapped read.
LEAQ xmmLoadMasks<>+0(SB), CX
VMOVDQU -16(DX)(BX*1), X4
VPAND -16(CX)(BX*8), X4, X4
VPMOVZXWD X4, Y4
VPADDD Y4, Y0, Y0
doneSIMD:
// Multi-chain loop is done, combine the accumulators
VPADDD Y1, Y0, Y0
VPADDD Y2, Y0, Y0
VPADDD Y3, Y0, Y0
// extract the YMM into a pair of XMM and sum them
VEXTRACTI128 $0x01, Y0, X1
VPADDD X0, X1, X0
// extract the XMM into GP64
VPEXTRQ $0x00, X0, CX
VPEXTRQ $0x01, X0, DX
// no more AVX code, clear upper registers to avoid SSE slowdowns
VZEROUPPER
ADDQ CX, AX
ADCQ DX, AX
foldAndReturn:
// add CF and fold
RORXQ $0x20, AX, CX
ADCL CX, AX
RORXL $0x10, AX, CX
ADCW CX, AX
ADCW $0x00, AX
XCHGB AH, AL
MOVW AX, ret+32(FP)
RET
// func checksumSSE2(b []byte, initial uint16) uint16
// Requires: SSE2
TEXT ·checksumSSE2(SB), NOSPLIT|NOFRAME, $0-34
MOVWQZX initial+24(FP), AX
XCHGB AH, AL
MOVQ b_base+0(FP), DX
MOVQ b_len+8(FP), BX
// handle odd length buffers; they are difficult to handle in general
TESTQ $0x00000001, BX
JZ lengthIsEven
MOVBQZX -1(DX)(BX*1), CX
DECQ BX
ADDQ CX, AX
lengthIsEven:
// handle tiny buffers (<=31 bytes) specially
CMPQ BX, $0x1f
JGT bufferIsNotTiny
XORQ CX, CX
XORQ SI, SI
XORQ DI, DI
// shift twice to start because length is guaranteed to be even
// n = n >> 2; CF = originalN & 2
SHRQ $0x02, BX
JNC handleTiny4
// tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:]
MOVWQZX (DX), CX
ADDQ $0x02, DX
handleTiny4:
// n = n >> 1; CF = originalN & 4
SHRQ $0x01, BX
JNC handleTiny8
// tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:]
MOVLQZX (DX), SI
ADDQ $0x04, DX
handleTiny8:
// n = n >> 1; CF = originalN & 8
SHRQ $0x01, BX
JNC handleTiny16
// tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:]
MOVQ (DX), DI
ADDQ $0x08, DX
handleTiny16:
// n = n >> 1; CF = originalN & 16
// n == 0 now, otherwise we would have branched after comparing with tinyBufferSize
SHRQ $0x01, BX
JNC handleTinyFinish
ADDQ (DX), AX
ADCQ 8(DX), AX
handleTinyFinish:
// CF should be included from the previous add, so we use ADCQ.
// If we arrived via the JNC above, then CF=0 due to the branch condition,
// so ADCQ will still produce the correct result.
ADCQ CX, AX
ADCQ SI, AX
ADCQ DI, AX
JMP foldAndReturn
bufferIsNotTiny:
// skip all SIMD for small buffers
CMPQ BX, $0x00000100
JGE startSIMD
// Accumulate carries in this register. It is never expected to overflow.
XORQ SI, SI
// We will perform an overlapped read for buffers with length not a multiple of 8.
// Overlapped in this context means some memory will be read twice, but a shift will
// eliminate the duplicated data. This extra read is performed at the end of the buffer to
// preserve any alignment that may exist for the start of the buffer.
MOVQ BX, CX
SHRQ $0x03, BX
ANDQ $0x07, CX
JZ handleRemaining8
LEAQ (DX)(BX*8), DI
MOVQ -8(DI)(CX*1), DI
// Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8)
SHLQ $0x03, CX
NEGQ CX
ADDQ $0x40, CX
SHRQ CL, DI
ADDQ DI, AX
ADCQ $0x00, SI
handleRemaining8:
SHRQ $0x01, BX
JNC handleRemaining16
ADDQ (DX), AX
ADCQ $0x00, SI
ADDQ $0x08, DX
handleRemaining16:
SHRQ $0x01, BX
JNC handleRemaining32
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ $0x00, SI
ADDQ $0x10, DX
handleRemaining32:
SHRQ $0x01, BX
JNC handleRemaining64
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ $0x00, SI
ADDQ $0x20, DX
handleRemaining64:
SHRQ $0x01, BX
JNC handleRemaining128
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ 32(DX), AX
ADCQ 40(DX), AX
ADCQ 48(DX), AX
ADCQ 56(DX), AX
ADCQ $0x00, SI
ADDQ $0x40, DX
handleRemaining128:
SHRQ $0x01, BX
JNC handleRemainingComplete
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ 32(DX), AX
ADCQ 40(DX), AX
ADCQ 48(DX), AX
ADCQ 56(DX), AX
ADCQ 64(DX), AX
ADCQ 72(DX), AX
ADCQ 80(DX), AX
ADCQ 88(DX), AX
ADCQ 96(DX), AX
ADCQ 104(DX), AX
ADCQ 112(DX), AX
ADCQ 120(DX), AX
ADCQ $0x00, SI
ADDQ $0x80, DX
handleRemainingComplete:
ADDQ SI, AX
JMP foldAndReturn
startSIMD:
PXOR X0, X0
PXOR X1, X1
PXOR X2, X2
PXOR X3, X3
PXOR X4, X4
MOVQ BX, CX
// Update number of bytes remaining after the loop completes
ANDQ $0xff, BX
// Number of 256 byte iterations
SHRQ $0x08, CX
JZ smallLoop
bigLoop:
MOVOU (DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X0
PADDD X6, X2
MOVOU 16(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X1
PADDD X6, X3
MOVOU 32(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X2
PADDD X6, X0
MOVOU 48(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X3
PADDD X6, X1
MOVOU 64(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X0
PADDD X6, X2
MOVOU 80(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X1
PADDD X6, X3
MOVOU 96(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X2
PADDD X6, X0
MOVOU 112(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X3
PADDD X6, X1
MOVOU 128(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X0
PADDD X6, X2
MOVOU 144(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X1
PADDD X6, X3
MOVOU 160(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X2
PADDD X6, X0
MOVOU 176(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X3
PADDD X6, X1
MOVOU 192(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X0
PADDD X6, X2
MOVOU 208(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X1
PADDD X6, X3
MOVOU 224(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X2
PADDD X6, X0
MOVOU 240(DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X3
PADDD X6, X1
ADDQ $0x00000100, DX
DECQ CX
JNZ bigLoop
CMPQ BX, $0x10
JLT doneSmallLoop
// now read a single 16 byte unit of data at a time
smallLoop:
MOVOU (DX), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X0
PADDD X6, X1
ADDQ $0x10, DX
SUBQ $0x10, BX
CMPQ BX, $0x10
JGE smallLoop
doneSmallLoop:
CMPQ BX, $0x00
JE doneSIMD
// There are between 1 and 15 bytes remaining. Perform an overlapped read.
LEAQ xmmLoadMasks<>+0(SB), CX
MOVOU -16(DX)(BX*1), X5
PAND -16(CX)(BX*8), X5
MOVOA X5, X6
PUNPCKHWL X4, X5
PUNPCKLWL X4, X6
PADDD X5, X0
PADDD X6, X1
doneSIMD:
// Multi-chain loop is done, combine the accumulators
PADDD X1, X0
PADDD X2, X0
PADDD X3, X0
// extract the XMM into GP64
MOVQ X0, CX
PSRLDQ $0x08, X0
MOVQ X0, DX
ADDQ CX, AX
ADCQ DX, AX
foldAndReturn:
// add CF and fold
MOVL AX, CX
ADCQ $0x00, CX
SHRQ $0x20, AX
ADDQ CX, AX
MOVWQZX AX, CX
SHRQ $0x10, AX
ADDQ CX, AX
MOVW AX, CX
SHRQ $0x10, AX
ADDW CX, AX
ADCW $0x00, AX
XCHGB AH, AL
MOVW AX, ret+32(FP)
RET
// func checksumAMD64(b []byte, initial uint16) uint16
TEXT ·checksumAMD64(SB), NOSPLIT|NOFRAME, $0-34
MOVWQZX initial+24(FP), AX
XCHGB AH, AL
MOVQ b_base+0(FP), DX
MOVQ b_len+8(FP), BX
// handle odd length buffers; they are difficult to handle in general
TESTQ $0x00000001, BX
JZ lengthIsEven
MOVBQZX -1(DX)(BX*1), CX
DECQ BX
ADDQ CX, AX
lengthIsEven:
// handle tiny buffers (<=31 bytes) specially
CMPQ BX, $0x1f
JGT bufferIsNotTiny
XORQ CX, CX
XORQ SI, SI
XORQ DI, DI
// shift twice to start because length is guaranteed to be even
// n = n >> 2; CF = originalN & 2
SHRQ $0x02, BX
JNC handleTiny4
// tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:]
MOVWQZX (DX), CX
ADDQ $0x02, DX
handleTiny4:
// n = n >> 1; CF = originalN & 4
SHRQ $0x01, BX
JNC handleTiny8
// tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:]
MOVLQZX (DX), SI
ADDQ $0x04, DX
handleTiny8:
// n = n >> 1; CF = originalN & 8
SHRQ $0x01, BX
JNC handleTiny16
// tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:]
MOVQ (DX), DI
ADDQ $0x08, DX
handleTiny16:
// n = n >> 1; CF = originalN & 16
// n == 0 now, otherwise we would have branched after comparing with tinyBufferSize
SHRQ $0x01, BX
JNC handleTinyFinish
ADDQ (DX), AX
ADCQ 8(DX), AX
handleTinyFinish:
// CF should be included from the previous add, so we use ADCQ.
// If we arrived via the JNC above, then CF=0 due to the branch condition,
// so ADCQ will still produce the correct result.
ADCQ CX, AX
ADCQ SI, AX
ADCQ DI, AX
JMP foldAndReturn
bufferIsNotTiny:
// Number of 256 byte iterations into loop counter
MOVQ BX, CX
// Update number of bytes remaining after the loop completes
ANDQ $0xff, BX
SHRQ $0x08, CX
JZ startCleanup
CLC
XORQ SI, SI
XORQ DI, DI
XORQ R8, R8
XORQ R9, R9
XORQ R10, R10
XORQ R11, R11
XORQ R12, R12
bigLoop:
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ $0x00, SI
ADDQ 32(DX), DI
ADCQ 40(DX), DI
ADCQ 48(DX), DI
ADCQ 56(DX), DI
ADCQ $0x00, R8
ADDQ 64(DX), R9
ADCQ 72(DX), R9
ADCQ 80(DX), R9
ADCQ 88(DX), R9
ADCQ $0x00, R10
ADDQ 96(DX), R11
ADCQ 104(DX), R11
ADCQ 112(DX), R11
ADCQ 120(DX), R11
ADCQ $0x00, R12
ADDQ 128(DX), AX
ADCQ 136(DX), AX
ADCQ 144(DX), AX
ADCQ 152(DX), AX
ADCQ $0x00, SI
ADDQ 160(DX), DI
ADCQ 168(DX), DI
ADCQ 176(DX), DI
ADCQ 184(DX), DI
ADCQ $0x00, R8
ADDQ 192(DX), R9
ADCQ 200(DX), R9
ADCQ 208(DX), R9
ADCQ 216(DX), R9
ADCQ $0x00, R10
ADDQ 224(DX), R11
ADCQ 232(DX), R11
ADCQ 240(DX), R11
ADCQ 248(DX), R11
ADCQ $0x00, R12
ADDQ $0x00000100, DX
SUBQ $0x01, CX
JNZ bigLoop
ADDQ SI, AX
ADCQ DI, AX
ADCQ R8, AX
ADCQ R9, AX
ADCQ R10, AX
ADCQ R11, AX
ADCQ R12, AX
// accumulate CF (twice, in case the first time overflows)
ADCQ $0x00, AX
ADCQ $0x00, AX
startCleanup:
// Accumulate carries in this register. It is never expected to overflow.
XORQ SI, SI
// We will perform an overlapped read for buffers with length not a multiple of 8.
// Overlapped in this context means some memory will be read twice, but a shift will
// eliminate the duplicated data. This extra read is performed at the end of the buffer to
// preserve any alignment that may exist for the start of the buffer.
MOVQ BX, CX
SHRQ $0x03, BX
ANDQ $0x07, CX
JZ handleRemaining8
LEAQ (DX)(BX*8), DI
MOVQ -8(DI)(CX*1), DI
// Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8)
SHLQ $0x03, CX
NEGQ CX
ADDQ $0x40, CX
SHRQ CL, DI
ADDQ DI, AX
ADCQ $0x00, SI
handleRemaining8:
SHRQ $0x01, BX
JNC handleRemaining16
ADDQ (DX), AX
ADCQ $0x00, SI
ADDQ $0x08, DX
handleRemaining16:
SHRQ $0x01, BX
JNC handleRemaining32
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ $0x00, SI
ADDQ $0x10, DX
handleRemaining32:
SHRQ $0x01, BX
JNC handleRemaining64
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ $0x00, SI
ADDQ $0x20, DX
handleRemaining64:
SHRQ $0x01, BX
JNC handleRemaining128
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ 32(DX), AX
ADCQ 40(DX), AX
ADCQ 48(DX), AX
ADCQ 56(DX), AX
ADCQ $0x00, SI
ADDQ $0x40, DX
handleRemaining128:
SHRQ $0x01, BX
JNC handleRemainingComplete
ADDQ (DX), AX
ADCQ 8(DX), AX
ADCQ 16(DX), AX
ADCQ 24(DX), AX
ADCQ 32(DX), AX
ADCQ 40(DX), AX
ADCQ 48(DX), AX
ADCQ 56(DX), AX
ADCQ 64(DX), AX
ADCQ 72(DX), AX
ADCQ 80(DX), AX
ADCQ 88(DX), AX
ADCQ 96(DX), AX
ADCQ 104(DX), AX
ADCQ 112(DX), AX
ADCQ 120(DX), AX
ADCQ $0x00, SI
ADDQ $0x80, DX
handleRemainingComplete:
ADDQ SI, AX
foldAndReturn:
// add CF and fold
MOVL AX, CX
ADCQ $0x00, CX
SHRQ $0x20, AX
ADDQ CX, AX
MOVWQZX AX, CX
SHRQ $0x10, AX
ADDQ CX, AX
MOVW AX, CX
SHRQ $0x10, AX
ADDW CX, AX
ADCW $0x00, AX
XCHGB AH, AL
MOVW AX, ret+32(FP)
RET

View File

@@ -1,15 +0,0 @@
// This file contains IP checksum algorithms that are not specific to any
// architecture and don't use hardware acceleration.
//go:build !amd64
package tschecksum
import "strconv"
func Checksum(data []byte, initial uint16) uint16 {
if strconv.IntSize < 64 {
return checksumGeneric32(data, initial)
}
return checksumGeneric64(data, initial)
}

View File

@@ -1,578 +0,0 @@
//go:build ignore
//go:generate go run generate_amd64.go -out checksum_generated_amd64.s -stubs checksum_generated_amd64.go
package main
import (
"fmt"
"math"
"math/bits"
"github.com/mmcloughlin/avo/operand"
"github.com/mmcloughlin/avo/reg"
)
const checksumSignature = "func(b []byte, initial uint16) uint16"
func loadParams() (accum, buf, n reg.GPVirtual) {
accum, buf, n = GP64(), GP64(), GP64()
Load(Param("initial"), accum)
XCHGB(accum.As8H(), accum.As8L())
Load(Param("b").Base(), buf)
Load(Param("b").Len(), n)
return
}
type simdStrategy int
const (
sse2 = iota
avx2
)
const tinyBufferSize = 31 // A buffer is tiny if it has at most 31 bytes.
func generateSIMDChecksum(name, doc string, minSIMDSize, chains int, strategy simdStrategy) {
TEXT(name, NOSPLIT|NOFRAME, checksumSignature)
Pragma("noescape")
Doc(doc)
accum64, buf, n := loadParams()
handleOddLength(n, buf, accum64)
// no chance of overflow because accum64 was initialized by a uint16 and
// handleOddLength adds at most a uint8
handleTinyBuffers(n, buf, accum64, operand.LabelRef("foldAndReturn"), operand.LabelRef("bufferIsNotTiny"))
Label("bufferIsNotTiny")
const simdReadSize = 16
if minSIMDSize > tinyBufferSize {
Comment("skip all SIMD for small buffers")
if minSIMDSize <= math.MaxUint8 {
CMPQ(n, operand.U8(minSIMDSize))
} else {
CMPQ(n, operand.U32(minSIMDSize))
}
JGE(operand.LabelRef("startSIMD"))
handleRemaining(n, buf, accum64, minSIMDSize-1)
JMP(operand.LabelRef("foldAndReturn"))
}
Label("startSIMD")
// chains is the number of accumulators to use. This improves speed via
// reduced data dependency. We combine the accumulators once when the big
// loop is complete.
simdAccumulate := make([]reg.VecVirtual, chains)
for i := range simdAccumulate {
switch strategy {
case sse2:
simdAccumulate[i] = XMM()
PXOR(simdAccumulate[i], simdAccumulate[i])
case avx2:
simdAccumulate[i] = YMM()
VPXOR(simdAccumulate[i], simdAccumulate[i], simdAccumulate[i])
}
}
var zero reg.VecVirtual
if strategy == sse2 {
zero = XMM()
PXOR(zero, zero)
}
// Number of loads per big loop
const unroll = 16
// Number of bytes
loopSize := uint64(simdReadSize * unroll)
if bits.Len64(loopSize) != bits.Len64(loopSize-1)+1 {
panic("loopSize is not a power of 2")
}
loopCount := GP64()
MOVQ(n, loopCount)
Comment("Update number of bytes remaining after the loop completes")
ANDQ(operand.Imm(loopSize-1), n)
Comment(fmt.Sprintf("Number of %d byte iterations", loopSize))
SHRQ(operand.Imm(uint64(bits.Len64(loopSize-1))), loopCount)
JZ(operand.LabelRef("smallLoop"))
Label("bigLoop")
for i := 0; i < unroll; i++ {
chain := i % chains
switch strategy {
case sse2:
sse2AccumulateStep(i*simdReadSize, buf, zero, simdAccumulate[chain], simdAccumulate[(chain+chains/2)%chains])
case avx2:
avx2AccumulateStep(i*simdReadSize, buf, simdAccumulate[chain])
}
}
ADDQ(operand.U32(loopSize), buf)
DECQ(loopCount)
JNZ(operand.LabelRef("bigLoop"))
Label("bigCleanup")
CMPQ(n, operand.Imm(uint64(simdReadSize)))
JLT(operand.LabelRef("doneSmallLoop"))
Commentf("now read a single %d byte unit of data at a time", simdReadSize)
Label("smallLoop")
switch strategy {
case sse2:
sse2AccumulateStep(0, buf, zero, simdAccumulate[0], simdAccumulate[1])
case avx2:
avx2AccumulateStep(0, buf, simdAccumulate[0])
}
ADDQ(operand.Imm(uint64(simdReadSize)), buf)
SUBQ(operand.Imm(uint64(simdReadSize)), n)
CMPQ(n, operand.Imm(uint64(simdReadSize)))
JGE(operand.LabelRef("smallLoop"))
Label("doneSmallLoop")
CMPQ(n, operand.Imm(0))
JE(operand.LabelRef("doneSIMD"))
Commentf("There are between 1 and %d bytes remaining. Perform an overlapped read.", simdReadSize-1)
maskDataPtr := GP64()
LEAQ(operand.NewDataAddr(operand.NewStaticSymbol("xmmLoadMasks"), 0), maskDataPtr)
dataAddr := operand.Mem{Index: n, Scale: 1, Base: buf, Disp: -simdReadSize}
// scale 8 is only correct here because n is guaranteed to be even and we
// do not generate masks for odd lengths
maskAddr := operand.Mem{Base: maskDataPtr, Index: n, Scale: 8, Disp: -16}
remainder := XMM()
switch strategy {
case sse2:
MOVOU(dataAddr, remainder)
PAND(maskAddr, remainder)
low := XMM()
MOVOA(remainder, low)
PUNPCKHWL(zero, remainder)
PUNPCKLWL(zero, low)
PADDD(remainder, simdAccumulate[0])
PADDD(low, simdAccumulate[1])
case avx2:
// Note: this is very similar to the sse2 path but MOVOU has a massive
// performance hit if used here, presumably due to switching between SSE
// and AVX2 modes.
VMOVDQU(dataAddr, remainder)
VPAND(maskAddr, remainder, remainder)
temp := YMM()
VPMOVZXWD(remainder, temp)
VPADDD(temp, simdAccumulate[0], simdAccumulate[0])
}
Label("doneSIMD")
Comment("Multi-chain loop is done, combine the accumulators")
for i := range simdAccumulate {
if i == 0 {
continue
}
switch strategy {
case sse2:
PADDD(simdAccumulate[i], simdAccumulate[0])
case avx2:
VPADDD(simdAccumulate[i], simdAccumulate[0], simdAccumulate[0])
}
}
if strategy == avx2 {
Comment("extract the YMM into a pair of XMM and sum them")
tmp := YMM()
VEXTRACTI128(operand.Imm(1), simdAccumulate[0], tmp.AsX())
xAccumulate := XMM()
VPADDD(simdAccumulate[0].AsX(), tmp.AsX(), xAccumulate)
simdAccumulate = []reg.VecVirtual{xAccumulate}
}
Comment("extract the XMM into GP64")
low, high := GP64(), GP64()
switch strategy {
case sse2:
MOVQ(simdAccumulate[0], low)
PSRLDQ(operand.Imm(8), simdAccumulate[0])
MOVQ(simdAccumulate[0], high)
case avx2:
VPEXTRQ(operand.Imm(0), simdAccumulate[0], low)
VPEXTRQ(operand.Imm(1), simdAccumulate[0], high)
Comment("no more AVX code, clear upper registers to avoid SSE slowdowns")
VZEROUPPER()
}
ADDQ(low, accum64)
ADCQ(high, accum64)
Label("foldAndReturn")
foldWithCF(accum64, strategy == avx2)
XCHGB(accum64.As8H(), accum64.As8L())
Store(accum64.As16(), ReturnIndex(0))
RET()
}
// handleOddLength generates instructions to incorporate the last byte into
// accum64 if the length is odd. CF may be set if accum64 overflows; be sure to
// handle that if overflow is possible.
func handleOddLength(n, buf, accum64 reg.GPVirtual) {
Comment("handle odd length buffers; they are difficult to handle in general")
TESTQ(operand.U32(1), n)
JZ(operand.LabelRef("lengthIsEven"))
tmp := GP64()
MOVBQZX(operand.Mem{Base: buf, Index: n, Scale: 1, Disp: -1}, tmp)
DECQ(n)
ADDQ(tmp, accum64)
Label("lengthIsEven")
}
func sse2AccumulateStep(offset int, buf reg.GPVirtual, zero, accumulate1, accumulate2 reg.VecVirtual) {
high, low := XMM(), XMM()
MOVOU(operand.Mem{Disp: offset, Base: buf}, high)
MOVOA(high, low)
PUNPCKHWL(zero, high)
PUNPCKLWL(zero, low)
PADDD(high, accumulate1)
PADDD(low, accumulate2)
}
func avx2AccumulateStep(offset int, buf reg.GPVirtual, accumulate reg.VecVirtual) {
tmp := YMM()
VPMOVZXWD(operand.Mem{Disp: offset, Base: buf}, tmp)
VPADDD(tmp, accumulate, accumulate)
}
func generateAMD64Checksum(name, doc string) {
TEXT(name, NOSPLIT|NOFRAME, checksumSignature)
Pragma("noescape")
Doc(doc)
accum64, buf, n := loadParams()
handleOddLength(n, buf, accum64)
// no chance of overflow because accum64 was initialized by a uint16 and
// handleOddLength adds at most a uint8
handleTinyBuffers(n, buf, accum64, operand.LabelRef("foldAndReturn"), operand.LabelRef("bufferIsNotTiny"))
Label("bufferIsNotTiny")
const (
// numChains is the number of accumulators and carry counters to use.
// This improves speed via reduced data dependency. We combine the
// accumulators and carry counters once when the loop is complete.
numChains = 4
unroll = 32 // The number of 64-bit reads to perform per iteration of the loop.
loopSize = 8 * unroll // The number of bytes read per iteration of the loop.
)
if bits.Len(loopSize) != bits.Len(loopSize-1)+1 {
panic("loopSize is not a power of 2")
}
loopCount := GP64()
Comment(fmt.Sprintf("Number of %d byte iterations into loop counter", loopSize))
MOVQ(n, loopCount)
Comment("Update number of bytes remaining after the loop completes")
ANDQ(operand.Imm(loopSize-1), n)
SHRQ(operand.Imm(uint64(bits.Len(loopSize-1))), loopCount)
JZ(operand.LabelRef("startCleanup"))
CLC()
chains := make([]struct {
accum reg.GPVirtual
carries reg.GPVirtual
}, numChains)
for i := range chains {
if i == 0 {
chains[i].accum = accum64
} else {
chains[i].accum = GP64()
XORQ(chains[i].accum, chains[i].accum)
}
chains[i].carries = GP64()
XORQ(chains[i].carries, chains[i].carries)
}
Label("bigLoop")
var curChain int
for i := 0; i < unroll; i++ {
// It is significantly faster to use a ADCX/ADOX pair instead of plain
// ADC, which results in two dependency chains, however those require
// ADX support, which was added after AVX2. If AVX2 is available, that's
// even better than ADCX/ADOX.
//
// However, multiple dependency chains using multiple accumulators and
// occasionally storing CF into temporary counters seems to work almost
// as well.
addr := operand.Mem{Disp: i * 8, Base: buf}
if i%4 == 0 {
if i > 0 {
ADCQ(operand.Imm(0), chains[curChain].carries)
curChain = (curChain + 1) % len(chains)
}
ADDQ(addr, chains[curChain].accum)
} else {
ADCQ(addr, chains[curChain].accum)
}
}
ADCQ(operand.Imm(0), chains[curChain].carries)
ADDQ(operand.U32(loopSize), buf)
SUBQ(operand.Imm(1), loopCount)
JNZ(operand.LabelRef("bigLoop"))
for i := range chains {
if i == 0 {
ADDQ(chains[i].carries, accum64)
continue
}
ADCQ(chains[i].accum, accum64)
ADCQ(chains[i].carries, accum64)
}
accumulateCF(accum64)
Label("startCleanup")
handleRemaining(n, buf, accum64, loopSize-1)
Label("foldAndReturn")
foldWithCF(accum64, false)
XCHGB(accum64.As8H(), accum64.As8L())
Store(accum64.As16(), ReturnIndex(0))
RET()
}
// handleTinyBuffers computes checksums if the buffer length (the n parameter)
// is less than 32. After computing the checksum, a jump to returnLabel will
// be executed. Otherwise, if the buffer length is at least 32, nothing will be
// modified; a jump to continueLabel will be executed instead.
//
// When jumping to returnLabel, CF may be set and must be accommodated e.g.
// using foldWithCF or accumulateCF.
//
// Anecdotally, this appears to be faster than attempting to coordinate an
// overlapped read (which would also require special handling for buffers
// smaller than 8).
func handleTinyBuffers(n, buf, accum reg.GPVirtual, returnLabel, continueLabel operand.LabelRef) {
Comment("handle tiny buffers (<=31 bytes) specially")
CMPQ(n, operand.Imm(tinyBufferSize))
JGT(continueLabel)
tmp2, tmp4, tmp8 := GP64(), GP64(), GP64()
XORQ(tmp2, tmp2)
XORQ(tmp4, tmp4)
XORQ(tmp8, tmp8)
Comment("shift twice to start because length is guaranteed to be even",
"n = n >> 2; CF = originalN & 2")
SHRQ(operand.Imm(2), n)
JNC(operand.LabelRef("handleTiny4"))
Comment("tmp2 = binary.LittleEndian.Uint16(buf[:2]); buf = buf[2:]")
MOVWQZX(operand.Mem{Base: buf}, tmp2)
ADDQ(operand.Imm(2), buf)
Label("handleTiny4")
Comment("n = n >> 1; CF = originalN & 4")
SHRQ(operand.Imm(1), n)
JNC(operand.LabelRef("handleTiny8"))
Comment("tmp4 = binary.LittleEndian.Uint32(buf[:4]); buf = buf[4:]")
MOVLQZX(operand.Mem{Base: buf}, tmp4)
ADDQ(operand.Imm(4), buf)
Label("handleTiny8")
Comment("n = n >> 1; CF = originalN & 8")
SHRQ(operand.Imm(1), n)
JNC(operand.LabelRef("handleTiny16"))
Comment("tmp8 = binary.LittleEndian.Uint64(buf[:8]); buf = buf[8:]")
MOVQ(operand.Mem{Base: buf}, tmp8)
ADDQ(operand.Imm(8), buf)
Label("handleTiny16")
Comment("n = n >> 1; CF = originalN & 16",
"n == 0 now, otherwise we would have branched after comparing with tinyBufferSize")
SHRQ(operand.Imm(1), n)
JNC(operand.LabelRef("handleTinyFinish"))
ADDQ(operand.Mem{Base: buf}, accum)
ADCQ(operand.Mem{Base: buf, Disp: 8}, accum)
Label("handleTinyFinish")
Comment("CF should be included from the previous add, so we use ADCQ.",
"If we arrived via the JNC above, then CF=0 due to the branch condition,",
"so ADCQ will still produce the correct result.")
ADCQ(tmp2, accum)
ADCQ(tmp4, accum)
ADCQ(tmp8, accum)
JMP(returnLabel)
}
// handleRemaining generates a series of conditional unrolled additions,
// starting with 8 bytes long and doubling each time until the length reaches
// max. This is the reverse order of what may be intuitive, but makes the branch
// conditions convenient to compute: perform one right shift each time and test
// against CF.
//
// When done, CF may be set and must be accommodated e.g., using foldWithCF or
// accumulateCF.
//
// If n is not a multiple of 8, an extra 64 bit read at the end of the buffer
// will be performed, overlapping with data that will be read later. The
// duplicate data will be shifted off.
//
// The original buffer length must have been at least 8 bytes long, even if
// n < 8, otherwise this will access memory before the start of the buffer,
// which may be unsafe.
func handleRemaining(n, buf, accum64 reg.GPVirtual, max int) {
Comment("Accumulate carries in this register. It is never expected to overflow.")
carries := GP64()
XORQ(carries, carries)
Comment("We will perform an overlapped read for buffers with length not a multiple of 8.",
"Overlapped in this context means some memory will be read twice, but a shift will",
"eliminate the duplicated data. This extra read is performed at the end of the buffer to",
"preserve any alignment that may exist for the start of the buffer.")
leftover := reg.RCX
MOVQ(n, leftover)
SHRQ(operand.Imm(3), n) // n is now the number of 64 bit reads remaining
ANDQ(operand.Imm(0x7), leftover) // leftover is now the number of bytes to read from the end
JZ(operand.LabelRef("handleRemaining8"))
endBuf := GP64()
// endBuf is the position near the end of the buffer that is just past the
// last multiple of 8: (buf + len(buf)) & ^0x7
LEAQ(operand.Mem{Base: buf, Index: n, Scale: 8}, endBuf)
overlapRead := GP64()
// equivalent to overlapRead = binary.LittleEndian.Uint64(buf[len(buf)-8:len(buf)])
MOVQ(operand.Mem{Base: endBuf, Index: leftover, Scale: 1, Disp: -8}, overlapRead)
Comment("Shift out the duplicated data: overlapRead = overlapRead >> (64 - leftoverBytes*8)")
SHLQ(operand.Imm(3), leftover) // leftover = leftover * 8
NEGQ(leftover) // leftover = -leftover; this completes the (-leftoverBytes*8) part of the expression
ADDQ(operand.Imm(64), leftover) // now we have (64 - leftoverBytes*8)
SHRQ(reg.CL, overlapRead) // shift right by (64 - leftoverBytes*8); CL is the low 8 bits of leftover (set to RCX above) and variable shift only accepts CL
ADDQ(overlapRead, accum64)
ADCQ(operand.Imm(0), carries)
for curBytes := 8; curBytes <= max; curBytes *= 2 {
Label(fmt.Sprintf("handleRemaining%d", curBytes))
SHRQ(operand.Imm(1), n)
if curBytes*2 <= max {
JNC(operand.LabelRef(fmt.Sprintf("handleRemaining%d", curBytes*2)))
} else {
JNC(operand.LabelRef("handleRemainingComplete"))
}
numLoads := curBytes / 8
for i := 0; i < numLoads; i++ {
addr := operand.Mem{Base: buf, Disp: i * 8}
// It is possible to add the multiple dependency chains trick here
// that generateAMD64Checksum uses but anecdotally it does not
// appear to outweigh the cost.
if i == 0 {
ADDQ(addr, accum64)
continue
}
ADCQ(addr, accum64)
}
ADCQ(operand.Imm(0), carries)
if curBytes > math.MaxUint8 {
ADDQ(operand.U32(uint64(curBytes)), buf)
} else {
ADDQ(operand.U8(uint64(curBytes)), buf)
}
if curBytes*2 >= max {
continue
}
JMP(operand.LabelRef(fmt.Sprintf("handleRemaining%d", curBytes*2)))
}
Label("handleRemainingComplete")
ADDQ(carries, accum64)
}
func accumulateCF(accum64 reg.GPVirtual) {
Comment("accumulate CF (twice, in case the first time overflows)")
// accum64 += CF
ADCQ(operand.Imm(0), accum64)
// accum64 += CF again if the previous add overflowed. The previous add was
// 0 or 1. If it overflowed, then accum64 == 0, so adding another 1 can
// never overflow.
ADCQ(operand.Imm(0), accum64)
}
// foldWithCF generates instructions to fold accum (a GP64) into a 16-bit value
// according to ones-complement arithmetic. BMI2 instructions will be used if
// allowBMI2 is true (requires fewer instructions).
func foldWithCF(accum reg.GPVirtual, allowBMI2 bool) {
Comment("add CF and fold")
// CF|accum max value starts as 0x1_ffff_ffff_ffff_ffff
tmp := GP64()
if allowBMI2 {
// effectively, tmp = accum >> 32 (technically, this is a rotate)
RORXQ(operand.Imm(32), accum, tmp)
// accum as uint32 = uint32(accum) + uint32(tmp64) + CF; max value 0xffff_ffff + CF set
ADCL(tmp.As32(), accum.As32())
// effectively, tmp64 as uint32 = uint32(accum) >> 16 (also a rotate)
RORXL(operand.Imm(16), accum.As32(), tmp.As32())
// accum as uint16 = uint16(accum) + uint16(tmp) + CF; max value 0xffff + CF unset or 0xfffe + CF set
ADCW(tmp.As16(), accum.As16())
} else {
// tmp = uint32(accum); max value 0xffff_ffff
// MOVL clears the upper 32 bits of a GP64 so this is equivalent to the
// non-existent MOVLQZX.
MOVL(accum.As32(), tmp.As32())
// tmp += CF; max value 0x1_0000_0000, CF unset
ADCQ(operand.Imm(0), tmp)
// accum = accum >> 32; max value 0xffff_ffff
SHRQ(operand.Imm(32), accum)
// accum = accum + tmp; max value 0x1_ffff_ffff + CF unset
ADDQ(tmp, accum)
// tmp = uint16(accum); max value 0xffff
MOVWQZX(accum.As16(), tmp)
// accum = accum >> 16; max value 0x1_ffff
SHRQ(operand.Imm(16), accum)
// accum = accum + tmp; max value 0x2_fffe + CF unset
ADDQ(tmp, accum)
// tmp as uint16 = uint16(accum); max value 0xffff
MOVW(accum.As16(), tmp.As16())
// accum = accum >> 16; max value 0x2
SHRQ(operand.Imm(16), accum)
// accum as uint16 = uint16(accum) + uint16(tmp); max value 0xffff + CF unset or 0x2 + CF set
ADDW(tmp.As16(), accum.As16())
}
// accum as uint16 += CF; will not overflow: either CF was 0 or accum <= 0xfffe
ADCW(operand.Imm(0), accum.As16())
}
func generateLoadMasks() {
var offset int
// xmmLoadMasks is a table of masks that can be used with PAND to zero all but the last N bytes in an XMM, N=2,4,6,8,10,12,14
GLOBL("xmmLoadMasks", RODATA|NOPTR)
for n := 2; n < 16; n += 2 {
var pattern [16]byte
for i := 0; i < len(pattern); i++ {
if i < len(pattern)-n {
pattern[i] = 0
continue
}
pattern[i] = 0xff
}
DATA(offset, operand.String(pattern[:]))
offset += len(pattern)
}
}
func main() {
generateLoadMasks()
generateSIMDChecksum("checksumAVX2", "checksumAVX2 computes an IP checksum using amd64 v3 instructions (AVX2, BMI2)", 256, 4, avx2)
generateSIMDChecksum("checksumSSE2", "checksumSSE2 computes an IP checksum using amd64 baseline instructions (SSE2)", 256, 4, sse2)
generateAMD64Checksum("checksumAMD64", "checksumAMD64 computes an IP checksum using amd64 baseline instructions")
Generate()
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/go-ole/go-ole"
"github.com/go-ole/go-ole/oleutil"
"github.com/scjalliance/comshim"
)
// Firewall related API constants.
@@ -249,10 +250,7 @@ func FirewallRuleExistsByName(rules *ole.IDispatch, name string) (bool, error) {
// then:
// dispatch firewallAPIRelease(u, fwp)
func firewallAPIInit() (*ole.IUnknown, *ole.IDispatch, error) {
err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED)
if err != nil {
return nil, nil, fmt.Errorf("Failed to initialize COM: %s", err)
}
comshim.Add(1)
unknown, err := oleutil.CreateObject("HNetCfg.FwPolicy2")
if err != nil {
@@ -272,5 +270,5 @@ func firewallAPIInit() (*ole.IUnknown, *ole.IDispatch, error) {
func firewallAPIRelease(u *ole.IUnknown, fwp *ole.IDispatch) {
fwp.Release()
u.Release()
ole.CoUninitialize()
comshim.Done()
}

View File

@@ -385,25 +385,3 @@ func (luid LUID) SetDNS(family AddressFamily, servers []netip.Addr, domains []st
func (luid LUID) FlushDNS(family AddressFamily) error {
return luid.SetDNS(family, nil, nil)
}
func (luid LUID) DisableDNSRegistration() error {
guid, err := luid.GUID()
if err != nil {
return err
}
dnsInterfaceSettings := &DnsInterfaceSettings{
Version: DnsInterfaceSettingsVersion1,
Flags: DnsInterfaceSettingsFlagRegistrationEnabled,
RegistrationEnabled: 0,
}
// For >= Windows 10 1809
err = SetInterfaceDnsSettings(*guid, dnsInterfaceSettings)
if err == nil || !errors.Is(err, windows.ERROR_PROC_NOT_FOUND) {
return err
}
// For < Windows 10 1809
return luid.fallbackDisableDNSRegistration()
}

View File

@@ -51,11 +51,10 @@ func runNetsh(cmds []string) error {
}
const (
netshCmdTemplateFlush4 = "interface ipv4 set dnsservers name=%d source=static address=none validate=no"
netshCmdTemplateFlush6 = "interface ipv6 set dnsservers name=%d source=static address=none validate=no"
netshCmdTemplateAdd4 = "interface ipv4 add dnsservers name=%d address=%s validate=no"
netshCmdTemplateAdd6 = "interface ipv6 add dnsservers name=%d address=%s validate=no"
netshCmdTemplateDisableRegistration = "interface ipv6 set dnsservers name=%d register=none"
netshCmdTemplateFlush4 = "interface ipv4 set dnsservers name=%d source=static address=none validate=no register=both"
netshCmdTemplateFlush6 = "interface ipv6 set dnsservers name=%d source=static address=none validate=no register=both"
netshCmdTemplateAdd4 = "interface ipv4 add dnsservers name=%d address=%s validate=no"
netshCmdTemplateAdd6 = "interface ipv6 add dnsservers name=%d address=%s validate=no"
)
func (luid LUID) fallbackSetDNSForFamily(family AddressFamily, dnses []netip.Addr) error {
@@ -107,13 +106,3 @@ func (luid LUID) fallbackSetDNSDomain(domain string) error {
key.Close()
return err
}
func (luid LUID) fallbackDisableDNSRegistration() error {
// the DNS registration setting is shared for both IPv4 and IPv6
ipif, err := luid.IPInterface(windows.AF_INET)
if err != nil {
return err
}
cmd := fmt.Sprintf(netshCmdTemplateDisableRegistration, ipif.InterfaceIndex)
return runNetsh([]string{cmd})
}

149
lwip.go Normal file
View File

@@ -0,0 +1,149 @@
//go:build with_lwip
package tun
import (
"context"
"net"
"net/netip"
"os"
lwip "github.com/sagernet/go-tun2socks/core"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/udpnat"
)
type LWIP struct {
ctx context.Context
tun Tun
tunMtu uint32
udpTimeout int64
handler Handler
stack lwip.LWIPStack
udpNat *udpnat.Service[netip.AddrPort]
}
func NewLWIP(
options StackOptions,
) (Stack, error) {
return &LWIP{
ctx: options.Context,
tun: options.Tun,
tunMtu: options.MTU,
handler: options.Handler,
stack: lwip.NewLWIPStack(),
udpNat: udpnat.New[netip.AddrPort](options.UDPTimeout, options.Handler),
}, nil
}
func (l *LWIP) Start() error {
lwip.RegisterTCPConnHandler(l)
lwip.RegisterUDPConnHandler(l)
lwip.RegisterOutputFn(l.tun.Write)
go l.loopIn()
return nil
}
func (l *LWIP) loopIn() {
if winTun, isWintun := l.tun.(WinTun); isWintun {
l.loopInWintun(winTun)
return
}
mtu := int(l.tunMtu) + PacketOffset
_buffer := buf.StackNewSize(mtu)
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
data := buffer.FreeBytes()
for {
n, err := l.tun.Read(data)
if err != nil {
return
}
_, err = l.stack.Write(data[PacketOffset:n])
if err != nil {
if err.Error() == "stack closed" {
return
}
l.handler.NewError(context.Background(), err)
}
}
}
func (l *LWIP) loopInWintun(tun WinTun) {
for {
packet, release, err := tun.ReadPacket()
if err != nil {
return
}
_, err = l.stack.Write(packet)
release()
if err != nil {
if err.Error() == "stack closed" {
return
}
l.handler.NewError(context.Background(), err)
}
}
}
func (l *LWIP) Close() error {
lwip.RegisterTCPConnHandler(nil)
lwip.RegisterUDPConnHandler(nil)
lwip.RegisterOutputFn(func(bytes []byte) (int, error) {
return 0, os.ErrClosed
})
return l.stack.Close()
}
func (l *LWIP) Handle(conn net.Conn) error {
lAddr := conn.LocalAddr()
rAddr := conn.RemoteAddr()
if lAddr == nil || rAddr == nil {
conn.Close()
return nil
}
go func() {
var metadata M.Metadata
metadata.Source = M.SocksaddrFromNet(lAddr)
metadata.Destination = M.SocksaddrFromNet(rAddr)
hErr := l.handler.NewConnection(l.ctx, conn, metadata)
if hErr != nil {
conn.(lwip.TCPConn).Abort()
}
}()
return nil
}
func (l *LWIP) ReceiveTo(conn lwip.UDPConn, data []byte, addr M.Socksaddr) error {
var upstreamMetadata M.Metadata
upstreamMetadata.Source = conn.LocalAddr()
upstreamMetadata.Destination = addr
l.udpNat.NewPacket(
l.ctx,
upstreamMetadata.Source.AddrPort(),
buf.As(data).ToOwned(),
upstreamMetadata,
func(natConn N.PacketConn) N.PacketWriter {
return &LWIPUDPBackWriter{conn}
},
)
return nil
}
type LWIPUDPBackWriter struct {
conn lwip.UDPConn
}
func (w *LWIPUDPBackWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
defer buffer.Release()
return common.Error(w.conn.WriteFrom(buffer.Bytes(), destination))
}
func (w *LWIPUDPBackWriter) Close() error {
return w.conn.Close()
}

11
lwip_stub.go Normal file
View File

@@ -0,0 +1,11 @@
//go:build !with_lwip
package tun
import E "github.com/sagernet/sing/common/exceptions"
func NewLWIP(
options StackOptions,
) (Stack, error) {
return nil, E.New(`LWIP is not included in this build, rebuild with -tags with_lwip`)
}

View File

@@ -1,7 +1,8 @@
package tun
import (
"github.com/sagernet/sing/common/control"
"net/netip"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/x/list"
)
@@ -9,23 +10,28 @@ import (
var ErrNoRoute = E.New("no route to internet")
type (
NetworkUpdateCallback = func()
DefaultInterfaceUpdateCallback = func(defaultInterface *control.Interface, flags int)
NetworkUpdateCallback = func() error
DefaultInterfaceUpdateCallback = func(event int) error
)
const FlagAndroidVPNUpdate = 1 << iota
const (
EventInterfaceUpdate = 1
EventAndroidVPNUpdate = 2
)
type NetworkUpdateMonitor interface {
Start() error
Close() error
RegisterCallback(callback NetworkUpdateCallback) *list.Element[NetworkUpdateCallback]
UnregisterCallback(element *list.Element[NetworkUpdateCallback])
E.Handler
}
type DefaultInterfaceMonitor interface {
Start() error
Close() error
DefaultInterface() *control.Interface
DefaultInterfaceName(destination netip.Addr) string
DefaultInterfaceIndex(destination netip.Addr) int
OverrideAndroidVPN() bool
AndroidVPNEnabled() bool
RegisterCallback(callback DefaultInterfaceUpdateCallback) *list.Element[DefaultInterfaceUpdateCallback]
@@ -33,7 +39,6 @@ type DefaultInterfaceMonitor interface {
}
type DefaultInterfaceMonitorOptions struct {
InterfaceFinder control.InterfaceFinder
OverrideAndroidVPN bool
UnderNetworkExtension bool
}

View File

@@ -20,7 +20,7 @@ func (m *defaultInterfaceMonitor) checkUpdate() error {
continue
}
vpnEnabled = true
if m.overrideAndroidVPN {
if m.options.OverrideAndroidVPN {
defaultTableIndex = rule.Table
break
}
@@ -51,18 +51,22 @@ func (m *defaultInterfaceMonitor) checkUpdate() error {
return err
}
newInterface, err := m.interfaceFinder.ByIndex(link.Attrs().Index)
if err != nil {
return E.Cause(err, "find updated interface: ", link.Attrs().Name)
oldInterface := m.defaultInterfaceName
oldIndex := m.defaultInterfaceIndex
m.defaultInterfaceName = link.Attrs().Name
m.defaultInterfaceIndex = link.Attrs().Index
var event int
if oldInterface != m.defaultInterfaceName || oldIndex != m.defaultInterfaceIndex {
event |= EventInterfaceUpdate
}
oldInterface := m.defaultInterface.Swap(newInterface)
if oldInterface != nil && oldInterface.Equals(*newInterface) && oldVPNEnabled == m.androidVPNEnabled {
return nil
}
var flags int
if oldVPNEnabled != m.androidVPNEnabled {
flags = FlagAndroidVPNUpdate
event |= EventAndroidVPNUpdate
}
m.emit(newInterface, flags)
if event != 0 {
m.emit(event)
}
return nil
}

View File

@@ -1,16 +1,16 @@
package tun
import (
"context"
"net"
"net/netip"
"os"
"sync"
"syscall"
"time"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/control"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/common/x/list"
"golang.org/x/net/route"
@@ -18,166 +18,132 @@ import (
)
type networkUpdateMonitor struct {
access sync.Mutex
callbacks list.List[NetworkUpdateCallback]
routeSocketFile *os.File
closeOnce sync.Once
done chan struct{}
logger logger.Logger
errorHandler E.Handler
access sync.Mutex
callbacks list.List[NetworkUpdateCallback]
routeSocket *os.File
}
func NewNetworkUpdateMonitor(logger logger.Logger) (NetworkUpdateMonitor, error) {
func NewNetworkUpdateMonitor(errorHandler E.Handler) (NetworkUpdateMonitor, error) {
return &networkUpdateMonitor{
logger: logger,
done: make(chan struct{}),
errorHandler: errorHandler,
}, nil
}
func (m *networkUpdateMonitor) Start() error {
go m.loopUpdate()
return nil
}
func (m *networkUpdateMonitor) loopUpdate() {
for {
select {
case <-m.done:
return
default:
}
err := m.loopUpdate0()
if err != nil {
m.logger.Error("listen network update: ", err)
return
}
}
}
func (m *networkUpdateMonitor) loopUpdate0() error {
routeSocket, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, 0)
if err != nil {
return err
}
err = unix.SetNonblock(routeSocket, true)
if err != nil {
unix.Close(routeSocket)
return err
}
routeSocketFile := os.NewFile(uintptr(routeSocket), "route")
defer routeSocketFile.Close()
m.routeSocketFile = routeSocketFile
m.loopUpdate1(routeSocketFile)
m.routeSocket = os.NewFile(uintptr(routeSocket), "route")
go m.loopUpdate()
return nil
}
func (m *networkUpdateMonitor) loopUpdate1(routeSocketFile *os.File) {
buffer := buf.NewPacket()
defer buffer.Release()
done := make(chan struct{})
go func() {
select {
case <-m.done:
routeSocketFile.Close()
case <-done:
}
}()
n, err := routeSocketFile.Read(buffer.FreeBytes())
close(done)
func (m *networkUpdateMonitor) loopUpdate() {
rawConn, err := m.routeSocket.SyscallConn()
if err != nil {
m.errorHandler.NewError(context.Background(), E.Cause(err, "create raw route connection"))
return
}
buffer.Truncate(n)
messages, err := route.ParseRIB(route.RIBTypeRoute, buffer.Bytes())
if err != nil {
return
}
for _, message := range messages {
if _, isRouteMessage := message.(*route.RouteMessage); isRouteMessage {
m.emit()
return
for {
var innerErr error
err = rawConn.Read(func(fd uintptr) (done bool) {
var msg [2048]byte
_, innerErr = unix.Read(int(fd), msg[:])
return innerErr != unix.EWOULDBLOCK
})
if innerErr != nil {
err = innerErr
}
if err != nil {
break
}
m.emit()
}
if err != syscall.EAGAIN {
m.errorHandler.NewError(context.Background(), E.Cause(err, "read route message"))
}
}
func (m *networkUpdateMonitor) Close() error {
m.closeOnce.Do(func() {
close(m.done)
})
return nil
return common.Close(common.PtrOrNil(m.routeSocket))
}
func (m *defaultInterfaceMonitor) checkUpdate() error {
var (
defaultInterface *control.Interface
err error
)
if m.underNetworkExtension {
defaultInterface, err = m.getDefaultInterfaceBySocket()
ribMessage, err := route.FetchRIB(unix.AF_UNSPEC, route.RIBTypeRoute, 0)
if err != nil {
return err
}
routeMessages, err := route.ParseRIB(route.RIBTypeRoute, ribMessage)
if err != nil {
return err
}
var defaultInterface *net.Interface
for _, rawRouteMessage := range routeMessages {
routeMessage := rawRouteMessage.(*route.RouteMessage)
if len(routeMessage.Addrs) <= unix.RTAX_NETMASK {
continue
}
destination, isIPv4Destination := routeMessage.Addrs[unix.RTAX_DST].(*route.Inet4Addr)
if !isIPv4Destination {
continue
}
if destination.IP != netip.IPv4Unspecified().As4() {
continue
}
mask, isIPv4Mask := routeMessage.Addrs[unix.RTAX_NETMASK].(*route.Inet4Addr)
if !isIPv4Mask {
continue
}
ones, _ := net.IPMask(mask.IP[:]).Size()
if ones != 0 {
continue
}
routeInterface, err := net.InterfaceByIndex(routeMessage.Index)
if err != nil {
return err
}
} else {
ribMessage, err := route.FetchRIB(unix.AF_UNSPEC, route.RIBTypeRoute, 0)
if err != nil {
return err
if routeMessage.Flags&unix.RTF_UP == 0 {
continue
}
routeMessages, err := route.ParseRIB(route.RIBTypeRoute, ribMessage)
if err != nil {
return err
if routeMessage.Flags&unix.RTF_GATEWAY == 0 {
continue
}
for _, rawRouteMessage := range routeMessages {
routeMessage := rawRouteMessage.(*route.RouteMessage)
if len(routeMessage.Addrs) <= unix.RTAX_NETMASK {
continue
}
destination, isIPv4Destination := routeMessage.Addrs[unix.RTAX_DST].(*route.Inet4Addr)
if !isIPv4Destination {
continue
}
if destination.IP != netip.IPv4Unspecified().As4() {
continue
}
mask, isIPv4Mask := routeMessage.Addrs[unix.RTAX_NETMASK].(*route.Inet4Addr)
if !isIPv4Mask {
continue
}
ones, _ := net.IPMask(mask.IP[:]).Size()
if ones != 0 {
continue
}
routeInterface, err := m.interfaceFinder.ByIndex(routeMessage.Index)
if routeMessage.Flags&unix.RTF_IFSCOPE != 0 {
continue
}
defaultInterface = routeInterface
break
}
if defaultInterface == nil {
if m.options.UnderNetworkExtension {
defaultInterface, err = getDefaultInterfaceBySocket()
if err != nil {
return err
}
if routeMessage.Flags&unix.RTF_UP == 0 {
continue
}
if routeMessage.Flags&unix.RTF_GATEWAY == 0 {
continue
}
// if routeMessage.Flags&unix.RTF_IFSCOPE != 0 {
//continue
//}
defaultInterface = routeInterface
break
}
}
if defaultInterface == nil {
return ErrNoRoute
}
newInterface, err := m.interfaceFinder.ByIndex(defaultInterface.Index)
if err != nil {
return E.Cause(err, "find updated interface: ", defaultInterface.Name)
}
oldInterface := m.defaultInterface.Swap(newInterface)
if oldInterface != nil && oldInterface.Equals(*newInterface) {
oldInterface := m.defaultInterfaceName
oldIndex := m.defaultInterfaceIndex
m.defaultInterfaceIndex = defaultInterface.Index
m.defaultInterfaceName = defaultInterface.Name
if oldInterface == m.defaultInterfaceName && oldIndex == m.defaultInterfaceIndex {
return nil
}
m.emit(newInterface, 0)
m.emit(EventInterfaceUpdate)
return nil
}
func (m *defaultInterfaceMonitor) getDefaultInterfaceBySocket() (*control.Interface, error) {
func getDefaultInterfaceBySocket() (*net.Interface, error) {
socketFd, err := unix.Socket(unix.AF_INET, unix.SOCK_STREAM, 0)
if err != nil {
return nil, E.Cause(err, "create file descriptor")
@@ -188,8 +154,6 @@ func (m *defaultInterfaceMonitor) getDefaultInterfaceBySocket() (*control.Interf
Port: 80,
})
result := make(chan netip.Addr, 1)
done := make(chan struct{})
defer close(done)
go func() {
for {
sockname, sockErr := unix.Getsockname(socketFd)
@@ -202,13 +166,8 @@ func (m *defaultInterfaceMonitor) getDefaultInterfaceBySocket() (*control.Interf
}
addr := netip.AddrFrom4(sockaddr.Addr)
if addr.IsUnspecified() {
select {
case <-done:
break
default:
time.Sleep(10 * time.Millisecond)
continue
}
time.Sleep(time.Millisecond)
continue
}
result <- addr
break
@@ -218,7 +177,26 @@ func (m *defaultInterfaceMonitor) getDefaultInterfaceBySocket() (*control.Interf
select {
case selectedAddr = <-result:
case <-time.After(time.Second):
return nil, nil
return nil, os.ErrDeadlineExceeded
}
return m.interfaceFinder.ByAddr(selectedAddr)
interfaces, err := net.Interfaces()
if err != nil {
return nil, E.Cause(err, "net.Interfaces")
}
for _, netInterface := range interfaces {
interfaceAddrs, err := netInterface.Addrs()
if err != nil {
return nil, E.Cause(err, "net.Interfaces.Addrs")
}
for _, interfaceAddr := range interfaceAddrs {
ipNet, isIPNet := interfaceAddr.(*net.IPNet)
if !isIPNet {
continue
}
if ipNet.Contains(selectedAddr.AsSlice()) {
return &netInterface, nil
}
}
}
return nil, E.New("no interface found for address ", selectedAddr)
}

View File

@@ -2,56 +2,30 @@ package tun
import (
"os"
"runtime"
"sync"
"time"
"github.com/sagernet/netlink"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/common/x/list"
"golang.org/x/sys/unix"
)
type networkUpdateMonitor struct {
routeUpdate chan netlink.RouteUpdate
linkUpdate chan netlink.LinkUpdate
close chan struct{}
routeUpdate chan netlink.RouteUpdate
linkUpdate chan netlink.LinkUpdate
close chan struct{}
errorHandler E.Handler
access sync.Mutex
callbacks list.List[NetworkUpdateCallback]
logger logger.Logger
}
var ErrNetlinkBanned = E.New(
"netlink socket in Android is banned by Google, " +
"use the root or system (ADB) user to run sing-box, " +
"or switch to the sing-box Android graphical interface client",
)
func NewNetworkUpdateMonitor(logger logger.Logger) (NetworkUpdateMonitor, error) {
monitor := &networkUpdateMonitor{
routeUpdate: make(chan netlink.RouteUpdate, 2),
linkUpdate: make(chan netlink.LinkUpdate, 2),
close: make(chan struct{}),
logger: logger,
}
// check is netlink banned by google
if runtime.GOOS == "android" {
netlinkSocket, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_DGRAM, unix.NETLINK_ROUTE)
if err != nil {
return nil, ErrNetlinkBanned
}
err = unix.Bind(netlinkSocket, &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
})
unix.Close(netlinkSocket)
if err != nil {
return nil, ErrNetlinkBanned
}
}
return monitor, nil
func NewNetworkUpdateMonitor(errorHandler E.Handler) (NetworkUpdateMonitor, error) {
return &networkUpdateMonitor{
routeUpdate: make(chan netlink.RouteUpdate, 2),
linkUpdate: make(chan netlink.LinkUpdate, 2),
close: make(chan struct{}),
errorHandler: errorHandler,
}, nil
}
func (m *networkUpdateMonitor) Start() error {
@@ -68,9 +42,6 @@ func (m *networkUpdateMonitor) Start() error {
}
func (m *networkUpdateMonitor) loopUpdate() {
const minDuration = time.Second
timer := time.NewTimer(minDuration)
defer timer.Stop()
for {
select {
case <-m.close:
@@ -79,12 +50,6 @@ func (m *networkUpdateMonitor) loopUpdate() {
case <-m.linkUpdate:
}
m.emit()
select {
case <-m.close:
return
case <-timer.C:
timer.Reset(minDuration)
}
}
}

View File

@@ -25,16 +25,17 @@ func (m *defaultInterfaceMonitor) checkUpdate() error {
return err
}
newInterface, err := m.interfaceFinder.ByIndex(link.Attrs().Index)
if err != nil {
return E.Cause(err, "find updated interface: ", link.Attrs().Name)
}
oldInterface := m.defaultInterface.Swap(newInterface)
if oldInterface != nil && oldInterface.Equals(*newInterface) {
oldInterface := m.defaultInterfaceName
oldIndex := m.defaultInterfaceIndex
m.defaultInterfaceName = link.Attrs().Name
m.defaultInterfaceIndex = link.Attrs().Index
if oldInterface == m.defaultInterfaceName && oldIndex == m.defaultInterfaceIndex {
return nil
}
m.emit(newInterface, 0)
m.emit(EventInterfaceUpdate)
return nil
}
return ErrNoRoute
return E.New("no route to internet")
}

View File

@@ -5,13 +5,13 @@ package tun
import (
"os"
"github.com/sagernet/sing/common/logger"
E "github.com/sagernet/sing/common/exceptions"
)
func NewNetworkUpdateMonitor(logger logger.Logger) (NetworkUpdateMonitor, error) {
func NewNetworkUpdateMonitor(errorHandler E.Handler) (NetworkUpdateMonitor, error) {
return nil, os.ErrInvalid
}
func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor, logger logger.Logger, options DefaultInterfaceMonitorOptions) (DefaultInterfaceMonitor, error) {
func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor, options DefaultInterfaceMonitorOptions) (DefaultInterfaceMonitor, error) {
return nil, os.ErrInvalid
}

View File

@@ -3,13 +3,15 @@
package tun
import (
"errors"
"context"
"net"
"net/netip"
"sync"
"time"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/control"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/x/list"
)
@@ -30,67 +32,87 @@ func (m *networkUpdateMonitor) emit() {
callbacks := m.callbacks.Array()
m.access.Unlock()
for _, callback := range callbacks {
callback()
err := callback()
if err != nil {
m.NewError(context.Background(), err)
}
}
}
func (m *networkUpdateMonitor) NewError(ctx context.Context, err error) {
m.errorHandler.NewError(ctx, err)
}
type defaultInterfaceMonitor struct {
interfaceFinder control.InterfaceFinder
overrideAndroidVPN bool
underNetworkExtension bool
defaultInterface atomic.Pointer[control.Interface]
options DefaultInterfaceMonitorOptions
networkAddresses []networkAddress
defaultInterfaceName string
defaultInterfaceIndex int
androidVPNEnabled bool
noRoute bool
networkMonitor NetworkUpdateMonitor
checkUpdateTimer *time.Timer
element *list.Element[NetworkUpdateCallback]
access sync.Mutex
callbacks list.List[DefaultInterfaceUpdateCallback]
logger logger.Logger
}
func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor, logger logger.Logger, options DefaultInterfaceMonitorOptions) (DefaultInterfaceMonitor, error) {
type networkAddress struct {
interfaceName string
interfaceIndex int
addresses []netip.Prefix
}
func NewDefaultInterfaceMonitor(networkMonitor NetworkUpdateMonitor, options DefaultInterfaceMonitorOptions) (DefaultInterfaceMonitor, error) {
return &defaultInterfaceMonitor{
interfaceFinder: options.InterfaceFinder,
overrideAndroidVPN: options.OverrideAndroidVPN,
underNetworkExtension: options.UnderNetworkExtension,
options: options,
networkMonitor: networkMonitor,
logger: logger,
defaultInterfaceIndex: -1,
}, nil
}
func (m *defaultInterfaceMonitor) Start() error {
m.postCheckUpdate()
err := m.checkUpdate()
if err != nil {
m.networkMonitor.NewError(context.Background(), err)
}
m.element = m.networkMonitor.RegisterCallback(m.delayCheckUpdate)
return nil
}
func (m *defaultInterfaceMonitor) delayCheckUpdate() {
if m.checkUpdateTimer == nil {
m.checkUpdateTimer = time.AfterFunc(time.Second, m.postCheckUpdate)
} else {
m.checkUpdateTimer.Reset(time.Second)
func (m *defaultInterfaceMonitor) delayCheckUpdate() error {
time.Sleep(time.Second)
err := m.updateInterfaces()
if err != nil {
m.networkMonitor.NewError(context.Background(), E.Cause(err, "update interfaces"))
}
return m.checkUpdate()
}
func (m *defaultInterfaceMonitor) postCheckUpdate() {
err := m.interfaceFinder.Update()
func (m *defaultInterfaceMonitor) updateInterfaces() error {
interfaces, err := net.Interfaces()
if err != nil {
m.logger.Error("update interface: ", err)
return
return err
}
err = m.checkUpdate()
if errors.Is(err, ErrNoRoute) {
if !m.noRoute {
m.noRoute = true
m.defaultInterface.Store(nil)
m.emit(nil, 0)
var addresses []networkAddress
for _, iif := range interfaces {
var netAddresses []net.Addr
netAddresses, err = iif.Addrs()
if err != nil {
return err
}
} else if err != nil {
m.logger.Error("check interface: ", err)
} else {
m.noRoute = false
var address networkAddress
address.interfaceName = iif.Name
address.interfaceIndex = iif.Index
address.addresses = common.Map(common.FilterIsInstance(netAddresses, func(it net.Addr) (*net.IPNet, bool) {
value, loaded := it.(*net.IPNet)
return value, loaded
}), func(it *net.IPNet) netip.Prefix {
bits, _ := it.Mask.Size()
return netip.PrefixFrom(M.AddrFromIP(it.IP), bits)
})
addresses = append(addresses, address)
}
m.networkAddresses = addresses
return nil
}
func (m *defaultInterfaceMonitor) Close() error {
@@ -100,12 +122,36 @@ func (m *defaultInterfaceMonitor) Close() error {
return nil
}
func (m *defaultInterfaceMonitor) DefaultInterface() *control.Interface {
return m.defaultInterface.Load()
func (m *defaultInterfaceMonitor) DefaultInterfaceName(destination netip.Addr) string {
for _, address := range m.networkAddresses {
for _, prefix := range address.addresses {
if prefix.Contains(destination) {
return address.interfaceName
}
}
}
if m.defaultInterfaceIndex == -1 {
m.checkUpdate()
}
return m.defaultInterfaceName
}
func (m *defaultInterfaceMonitor) DefaultInterfaceIndex(destination netip.Addr) int {
for _, address := range m.networkAddresses {
for _, prefix := range address.addresses {
if prefix.Contains(destination) {
return address.interfaceIndex
}
}
}
if m.defaultInterfaceIndex == -1 {
m.checkUpdate()
}
return m.defaultInterfaceIndex
}
func (m *defaultInterfaceMonitor) OverrideAndroidVPN() bool {
return m.overrideAndroidVPN
return m.options.OverrideAndroidVPN
}
func (m *defaultInterfaceMonitor) AndroidVPNEnabled() bool {
@@ -124,11 +170,14 @@ func (m *defaultInterfaceMonitor) UnregisterCallback(element *list.Element[Defau
m.callbacks.Remove(element)
}
func (m *defaultInterfaceMonitor) emit(defaultInterface *control.Interface, flags int) {
func (m *defaultInterfaceMonitor) emit(event int) {
m.access.Lock()
callbacks := m.callbacks.Array()
m.access.Unlock()
for _, callback := range callbacks {
callback(defaultInterface, flags)
err := callback(event)
if err != nil {
m.networkMonitor.NewError(context.Background(), err)
}
}
}

View File

@@ -5,7 +5,6 @@ import (
"github.com/sagernet/sing-tun/internal/winipcfg"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/common/x/list"
"golang.org/x/sys/windows"
@@ -14,15 +13,15 @@ import (
type networkUpdateMonitor struct {
routeListener *winipcfg.RouteChangeCallback
interfaceListener *winipcfg.InterfaceChangeCallback
errorHandler E.Handler
access sync.Mutex
callbacks list.List[NetworkUpdateCallback]
logger logger.Logger
}
func NewNetworkUpdateMonitor(logger logger.Logger) (NetworkUpdateMonitor, error) {
func NewNetworkUpdateMonitor(errorHandler E.Handler) (NetworkUpdateMonitor, error) {
return &networkUpdateMonitor{
logger: logger,
errorHandler: errorHandler,
}, nil
}
@@ -77,16 +76,12 @@ func (m *defaultInterfaceMonitor) checkUpdate() error {
continue
}
if ifrow.Type == winipcfg.IfTypePropVirtual || ifrow.Type == winipcfg.IfTypeSoftwareLoopback {
continue
}
iface, err := row.InterfaceLUID.IPInterface(windows.AF_INET)
if err != nil {
continue
}
if !iface.Connected {
if ifrow.Type == winipcfg.IfTypePropVirtual || ifrow.Type == winipcfg.IfTypeSoftwareLoopback {
continue
}
@@ -102,14 +97,16 @@ func (m *defaultInterfaceMonitor) checkUpdate() error {
return ErrNoRoute
}
newInterface, err := m.interfaceFinder.ByIndex(index)
if err != nil {
return E.Cause(err, "find updated interface: ", alias)
}
oldInterface := m.defaultInterface.Swap(newInterface)
if oldInterface != nil && oldInterface.Equals(*newInterface) {
oldInterface := m.defaultInterfaceName
oldIndex := m.defaultInterfaceIndex
m.defaultInterfaceName = alias
m.defaultInterfaceIndex = index
if oldInterface == m.defaultInterfaceName && oldIndex == m.defaultInterfaceIndex {
return nil
}
m.emit(newInterface, 0)
m.emit(EventInterfaceUpdate)
return nil
}

View File

@@ -3,21 +3,20 @@ package tun
import (
"strconv"
"github.com/sagernet/sing-tun/internal/gtcpip"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
"github.com/sagernet/sing-tun/internal/clashtcpip"
F "github.com/sagernet/sing/common/format"
N "github.com/sagernet/sing/common/network"
)
func NetworkName(network uint8) string {
switch tcpip.TransportProtocolNumber(network) {
case header.TCPProtocolNumber:
switch network {
case clashtcpip.TCP:
return N.NetworkTCP
case header.UDPProtocolNumber:
case clashtcpip.UDP:
return N.NetworkUDP
case header.ICMPv4ProtocolNumber:
case clashtcpip.ICMP:
return N.NetworkICMPv4
case header.ICMPv6ProtocolNumber:
case clashtcpip.ICMPv6:
return N.NetworkICMPv6
}
return F.ToString(network)
@@ -26,13 +25,13 @@ func NetworkName(network uint8) string {
func NetworkFromName(name string) uint8 {
switch name {
case N.NetworkTCP:
return uint8(header.TCPProtocolNumber)
return clashtcpip.TCP
case N.NetworkUDP:
return uint8(header.UDPProtocolNumber)
return clashtcpip.UDP
case N.NetworkICMPv4:
return uint8(header.ICMPv4ProtocolNumber)
return clashtcpip.ICMP
case N.NetworkICMPv6:
return uint8(header.ICMPv6ProtocolNumber)
return clashtcpip.ICMPv6
}
parseNetwork, err := strconv.ParseUint(name, 10, 8)
if err != nil {

View File

@@ -1,6 +1,6 @@
package tun
import "github.com/sagernet/sing/common/logger"
import E "github.com/sagernet/sing/common/exceptions"
type PackageManager interface {
Start() error
@@ -11,14 +11,7 @@ type PackageManager interface {
SharedPackageByID(id uint32) (string, bool)
}
type PackageManagerOptions struct {
Callback PackageManagerCallback
// Logger is the logger to log errors
// optional
Logger logger.Logger
}
type PackageManagerCallback interface {
OnPackagesUpdated(packages int, sharedUsers int)
E.Handler
}

View File

@@ -2,33 +2,30 @@ package tun
import (
"bytes"
"context"
"encoding/xml"
"io"
"os"
"strconv"
"github.com/sagernet/fswatch"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/abx"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/fsnotify/fsnotify"
)
type packageManager struct {
callback PackageManagerCallback
logger logger.Logger
watcher *fswatch.Watcher
watcher *fsnotify.Watcher
idByPackage map[string]uint32
sharedByPackage map[string]uint32
packageById map[uint32]string
sharedById map[uint32]string
}
func NewPackageManager(options PackageManagerOptions) (PackageManager, error) {
return &packageManager{
callback: options.Callback,
logger: options.Logger,
}, nil
func NewPackageManager(callback PackageManagerCallback) (PackageManager, error) {
return &packageManager{callback: callback}, nil
}
func (m *packageManager) Start() error {
@@ -38,33 +35,42 @@ func (m *packageManager) Start() error {
}
err = m.startWatcher()
if err != nil {
m.logger.Error(E.Cause(err, "create watcher for packages list"))
m.callback.NewError(context.Background(), E.Cause(err, "create fsnotify watcher"))
}
return nil
}
func (m *packageManager) startWatcher() error {
watcher, err := fswatch.NewWatcher(fswatch.Options{
Path: []string{"/data/system/packages.xml"},
Direct: true,
Callback: m.packagesUpdated,
Logger: m.logger,
})
watcher, err := fsnotify.NewWatcher()
if err != nil {
return err
}
err = watcher.Start()
err = watcher.Add("/data/system/packages.xml")
if err != nil {
return err
}
m.watcher = watcher
go m.loopUpdate()
return nil
}
func (m *packageManager) packagesUpdated(path string) {
err := m.updatePackages()
if err != nil {
m.logger.Error(E.Cause(err, "update packages"))
func (m *packageManager) loopUpdate() {
for {
select {
case _, ok := <-m.watcher.Events:
if !ok {
return
}
err := m.updatePackages()
if err != nil {
m.callback.NewError(context.Background(), E.Cause(err, "update packages"))
}
case err, ok := <-m.watcher.Errors:
if !ok {
return
}
m.callback.NewError(context.Background(), E.Cause(err, "fsnotify error"))
}
}
}

View File

@@ -4,6 +4,6 @@ package tun
import "os"
func NewPackageManager(options PackageManagerOptions) (PackageManager, error) {
func NewPackageManager(callback PackageManagerCallback) (PackageManager, error) {
return nil, os.ErrInvalid
}

View File

@@ -1,36 +0,0 @@
package tun
import (
"context"
"github.com/sagernet/sing/common/control"
"github.com/sagernet/sing/common/logger"
N "github.com/sagernet/sing/common/network"
"go4.org/netipx"
)
const (
DefaultAutoRedirectInputMark = 0x2023
DefaultAutoRedirectOutputMark = 0x2024
)
type AutoRedirect interface {
Start() error
Close() error
UpdateRouteAddressSet()
}
type AutoRedirectOptions struct {
TunOptions *Options
Context context.Context
Handler N.TCPConnectionHandlerEx
Logger logger.Logger
NetworkMonitor NetworkUpdateMonitor
InterfaceFinder control.InterfaceFinder
TableName string
DisableNFTables bool
CustomRedirectPort func() int
RouteAddressSet *[]*netipx.IPSet
RouteExcludeAddressSet *[]*netipx.IPSet
}

View File

@@ -1,81 +0,0 @@
//go:build linux
package tun
import (
"os/exec"
"strings"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
)
func (r *autoRedirect) setupIPTables() error {
if r.enableIPv4 {
err := r.setupIPTablesForFamily(r.iptablesPath)
if err != nil {
return err
}
}
if r.enableIPv6 {
err := r.setupIPTablesForFamily(r.ip6tablesPath)
if err != nil {
return err
}
}
return nil
}
func (r *autoRedirect) setupIPTablesForFamily(iptablesPath string) error {
tableNameOutput := r.tableName + "-output"
redirectPort := r.redirectPort()
// OUTPUT
err := r.runShell(iptablesPath, "-t nat -N", tableNameOutput)
if err != nil {
return err
}
err = r.runShell(iptablesPath, "-t nat -A", tableNameOutput,
"-p tcp -o", r.tunOptions.Name,
"-j REDIRECT --to-ports", redirectPort)
if err != nil {
return err
}
err = r.runShell(iptablesPath, "-t nat -I OUTPUT -j", tableNameOutput)
if err != nil {
return err
}
return nil
}
func (r *autoRedirect) cleanupIPTables() {
if r.enableIPv4 {
r.cleanupIPTablesForFamily(r.iptablesPath)
}
if r.enableIPv6 {
r.cleanupIPTablesForFamily(r.ip6tablesPath)
}
}
func (r *autoRedirect) cleanupIPTablesForFamily(iptablesPath string) {
tableNameOutput := r.tableName + "-output"
_ = r.runShell(iptablesPath, "-t nat -D OUTPUT -j", tableNameOutput)
_ = r.runShell(iptablesPath, "-t nat -F", tableNameOutput)
_ = r.runShell(iptablesPath, "-t nat -X", tableNameOutput)
}
func (r *autoRedirect) runShell(commands ...any) error {
commandStr := strings.Join(F.MapToString(commands), " ")
var command *exec.Cmd
if r.androidSu {
command = exec.Command(r.suPath, "-c", commandStr)
} else {
commandArray := strings.Split(commandStr, " ")
command = exec.Command(commandArray[0], commandArray[1:]...)
}
combinedOutput, err := command.CombinedOutput()
if err != nil {
return E.Extend(err, F.ToString(commandStr, ": ", string(combinedOutput)))
}
return nil
}

View File

@@ -1,182 +0,0 @@
package tun
import (
"context"
"net/netip"
"os"
"os/exec"
"runtime"
"github.com/sagernet/nftables"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/control"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list"
"go4.org/netipx"
)
type autoRedirect struct {
tunOptions *Options
ctx context.Context
handler N.TCPConnectionHandlerEx
logger logger.Logger
tableName string
networkMonitor NetworkUpdateMonitor
networkListener *list.Element[NetworkUpdateCallback]
interfaceFinder control.InterfaceFinder
localAddresses []netip.Prefix
customRedirectPortFunc func() int
customRedirectPort int
redirectServer *redirectServer
enableIPv4 bool
enableIPv6 bool
iptablesPath string
ip6tablesPath string
useNFTables bool
androidSu bool
suPath string
routeAddressSet *[]*netipx.IPSet
routeExcludeAddressSet *[]*netipx.IPSet
}
func NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) {
return &autoRedirect{
tunOptions: options.TunOptions,
ctx: options.Context,
handler: options.Handler,
logger: options.Logger,
networkMonitor: options.NetworkMonitor,
interfaceFinder: options.InterfaceFinder,
tableName: options.TableName,
useNFTables: runtime.GOOS != "android" && !options.DisableNFTables,
customRedirectPortFunc: options.CustomRedirectPort,
routeAddressSet: options.RouteAddressSet,
routeExcludeAddressSet: options.RouteExcludeAddressSet,
}, nil
}
func (r *autoRedirect) Start() error {
var err error
if runtime.GOOS == "android" {
r.enableIPv4 = true
r.iptablesPath = "/system/bin/iptables"
userId := os.Getuid()
if userId != 0 {
r.androidSu = true
for _, suPath := range []string{
"su",
"/system/bin/su",
} {
r.suPath, err = exec.LookPath(suPath)
if err == nil {
break
}
}
if err != nil {
return E.Extend(E.Cause(err, "root permission is required for auto redirect"), os.Getenv("PATH"))
}
}
} else {
if r.useNFTables {
err = r.initializeNFTables()
if err != nil {
return E.Cause(err, "missing nftables support")
}
}
if len(r.tunOptions.Inet4Address) > 0 {
r.enableIPv4 = true
if !r.useNFTables {
r.iptablesPath, err = exec.LookPath("iptables")
if err != nil {
return E.Cause(err, "iptables is required")
}
}
}
if len(r.tunOptions.Inet6Address) > 0 {
r.enableIPv6 = true
if !r.useNFTables {
r.ip6tablesPath, err = exec.LookPath("ip6tables")
if err != nil {
if !r.enableIPv4 {
return E.Cause(err, "ip6tables is required")
} else {
r.enableIPv6 = false
r.logger.Error("device has no ip6tables nat support: ", err)
}
}
}
}
}
if r.customRedirectPortFunc != nil {
r.customRedirectPort = r.customRedirectPortFunc()
}
if r.customRedirectPort == 0 {
var listenAddr netip.Addr
if runtime.GOOS == "android" {
listenAddr = netip.AddrFrom4([4]byte{127, 0, 0, 1})
} else if r.enableIPv6 {
listenAddr = netip.IPv6Unspecified()
} else {
listenAddr = netip.IPv4Unspecified()
}
server := newRedirectServer(r.ctx, r.handler, r.logger, listenAddr)
err := server.Start()
if err != nil {
return E.Cause(err, "start redirect server")
}
r.redirectServer = server
}
if r.useNFTables {
r.cleanupNFTables()
err = r.setupNFTables()
} else {
r.cleanupIPTables()
err = r.setupIPTables()
}
return err
}
func (r *autoRedirect) Close() error {
if r.useNFTables {
r.cleanupNFTables()
} else {
r.cleanupIPTables()
}
return common.Close(
common.PtrOrNil(r.redirectServer),
)
}
func (r *autoRedirect) UpdateRouteAddressSet() {
if r.useNFTables {
err := r.nftablesUpdateRouteAddressSet()
if err != nil {
r.logger.Error("update route address set: ", err)
}
}
}
func (r *autoRedirect) initializeNFTables() error {
nft, err := nftables.New()
if err != nil {
return err
}
defer nft.CloseLasting()
_, err = nft.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil {
return err
}
r.useNFTables = true
return nil
}
func (r *autoRedirect) redirectPort() uint16 {
if r.customRedirectPort > 0 {
return uint16(r.customRedirectPort)
}
return M.AddrPortFromNet(r.redirectServer.listener.Addr()).Port()
}

View File

@@ -1,278 +0,0 @@
//go:build linux
package tun
import (
"net/netip"
"github.com/sagernet/nftables"
"github.com/sagernet/nftables/binaryutil"
"github.com/sagernet/nftables/expr"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/control"
"golang.org/x/exp/slices"
"golang.org/x/sys/unix"
)
func (r *autoRedirect) setupNFTables() error {
nft, err := nftables.New()
if err != nil {
return err
}
defer nft.CloseLasting()
table := nft.AddTable(&nftables.Table{
Name: r.tableName,
Family: nftables.TableFamilyINet,
})
err = r.nftablesCreateAddressSets(nft, table, false)
if err != nil {
return err
}
err = r.interfaceFinder.Update()
if err != nil {
return err
}
r.localAddresses = common.FlatMap(r.interfaceFinder.Interfaces(), func(it control.Interface) []netip.Prefix {
return common.Filter(it.Addresses, func(prefix netip.Prefix) bool {
return it.Name == "lo" || prefix.Addr().IsGlobalUnicast()
})
})
err = r.nftablesCreateLocalAddressSets(nft, table, r.localAddresses, nil)
if err != nil {
return err
}
skipOutput := len(r.tunOptions.IncludeInterface) > 0 && !common.Contains(r.tunOptions.IncludeInterface, "lo") || common.Contains(r.tunOptions.ExcludeInterface, "lo")
if !skipOutput {
chainOutput := nft.AddChain(&nftables.Chain{
Name: "output",
Table: table,
Hooknum: nftables.ChainHookOutput,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeNAT,
})
if r.tunOptions.AutoRedirectMarkMode {
err = r.nftablesCreateExcludeRules(nft, table, chainOutput)
if err != nil {
return err
}
r.nftablesCreateUnreachable(nft, table, chainOutput)
r.nftablesCreateRedirect(nft, table, chainOutput)
chainOutputUDP := nft.AddChain(&nftables.Chain{
Name: "output_udp",
Table: table,
Hooknum: nftables.ChainHookOutput,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeRoute,
})
err = r.nftablesCreateExcludeRules(nft, table, chainOutputUDP)
if err != nil {
return err
}
r.nftablesCreateUnreachable(nft, table, chainOutputUDP)
r.nftablesCreateMark(nft, table, chainOutputUDP)
} else {
r.nftablesCreateRedirect(nft, table, chainOutput, &expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
}, &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: nftablesIfname(r.tunOptions.Name),
})
}
}
chainPreRouting := nft.AddChain(&nftables.Chain{
Name: "prerouting",
Table: table,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityRef(*nftables.ChainPriorityNATDest + 1),
Type: nftables.ChainTypeNAT,
})
err = r.nftablesCreateExcludeRules(nft, table, chainPreRouting)
if err != nil {
return err
}
r.nftablesCreateUnreachable(nft, table, chainPreRouting)
r.nftablesCreateRedirect(nft, table, chainPreRouting)
if r.tunOptions.AutoRedirectMarkMode {
r.nftablesCreateMark(nft, table, chainPreRouting)
}
if r.tunOptions.AutoRedirectMarkMode {
chainPreRoutingUDP := nft.AddChain(&nftables.Chain{
Name: "prerouting_udp",
Table: table,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityRef(*nftables.ChainPriorityNATDest + 2),
Type: nftables.ChainTypeFilter,
})
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chainPreRoutingUDP,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyL4PROTO,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{unix.IPPROTO_UDP},
},
&expr.Verdict{
Kind: expr.VerdictReturn,
},
},
})
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chainPreRoutingUDP,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: nftablesIfname(r.tunOptions.Name),
},
&expr.Ct{
Key: expr.CtKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectInputMark),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
&expr.Counter{},
},
})
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chainPreRoutingUDP,
Exprs: []expr.Any{
&expr.Ct{
Key: expr.CtKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectInputMark),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectOutputMark),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Ct{
Key: expr.CtKeyMARK,
Register: 1,
SourceRegister: true,
},
&expr.Counter{},
},
})
}
err = r.configureOpenWRTFirewall4(nft, false)
if err != nil {
return err
}
err = nft.Flush()
if err != nil {
return err
}
r.networkListener = r.networkMonitor.RegisterCallback(func() {
err = r.nftablesUpdateLocalAddressSet()
if err != nil {
r.logger.Error("update local address set: ", err)
}
})
return nil
}
// TODO; test is this works
func (r *autoRedirect) nftablesUpdateLocalAddressSet() error {
newLocalAddresses := common.FlatMap(r.interfaceFinder.Interfaces(), func(it control.Interface) []netip.Prefix {
return common.Filter(it.Addresses, func(prefix netip.Prefix) bool {
return it.Name == "lo" || prefix.Addr().IsGlobalUnicast()
})
})
if slices.Equal(newLocalAddresses, r.localAddresses) {
return nil
}
nft, err := nftables.New()
if err != nil {
return err
}
defer nft.CloseLasting()
table, err := nft.ListTableOfFamily(r.tableName, nftables.TableFamilyINet)
if err != nil {
return err
}
err = r.nftablesCreateLocalAddressSets(nft, table, newLocalAddresses, r.localAddresses)
if err != nil {
return err
}
r.localAddresses = newLocalAddresses
return nft.Flush()
}
func (r *autoRedirect) nftablesUpdateRouteAddressSet() error {
nft, err := nftables.New()
if err != nil {
return err
}
defer nft.CloseLasting()
table, err := nft.ListTableOfFamily(r.tableName, nftables.TableFamilyINet)
if err != nil {
return err
}
err = r.nftablesCreateAddressSets(nft, table, true)
if err != nil {
return err
}
return nft.Flush()
}
func (r *autoRedirect) cleanupNFTables() {
if r.networkListener != nil {
r.networkMonitor.UnregisterCallback(r.networkListener)
}
nft, err := nftables.New()
if err != nil {
return
}
nft.DelTable(&nftables.Table{
Name: r.tableName,
Family: nftables.TableFamilyINet,
})
common.Must(r.configureOpenWRTFirewall4(nft, true))
_ = nft.Flush()
_ = nft.CloseLasting()
}

View File

@@ -1,172 +0,0 @@
//go:build linux
package tun
import (
"net/netip"
"github.com/sagernet/nftables"
"github.com/sagernet/nftables/expr"
"go4.org/netipx"
)
func nftablesIfname(n string) []byte {
b := make([]byte, 16)
copy(b, n+"\x00")
return b
}
func nftablesCreateExcludeDestinationIPSet(
nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain,
id uint32, name string, family nftables.TableFamily, invert bool,
) {
exprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyNFPROTO,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{byte(family)},
},
}
if family == nftables.TableFamilyIPv4 {
exprs = append(exprs,
&expr.Payload{
OperationType: expr.PayloadLoad,
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
)
} else {
exprs = append(exprs,
&expr.Payload{
OperationType: expr.PayloadLoad,
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 24,
Len: 16,
},
)
}
exprs = append(exprs,
&expr.Lookup{
SourceRegister: 1,
SetID: id,
SetName: name,
Invert: invert,
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictReturn,
})
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: exprs,
})
}
func nftablesCreateIPSet(
nft *nftables.Conn, table *nftables.Table,
id uint32, name string, family nftables.TableFamily,
setList []*netipx.IPSet, prefixList []netip.Prefix, appendDefault bool, update bool,
) (*nftables.Set, error) {
var builder netipx.IPSetBuilder
for _, prefix := range prefixList {
builder.AddPrefix(prefix)
}
for _, set := range setList {
builder.AddSet(set)
}
ipSet, err := builder.IPSet()
if err != nil {
return nil, err
}
ipRanges := ipSet.Ranges()
setElements := make([]nftables.SetElement, 0, len(ipRanges))
for _, rr := range ipRanges {
if (family == nftables.TableFamilyIPv4) != rr.From().Is4() {
continue
}
endAddr := rr.To().Next()
if !endAddr.IsValid() {
endAddr = rr.From()
}
setElements = append(setElements, nftables.SetElement{
Key: rr.From().AsSlice(),
})
setElements = append(setElements, nftables.SetElement{
Key: endAddr.AsSlice(),
IntervalEnd: true,
})
}
if len(prefixList) == 0 && appendDefault {
if family == nftables.TableFamilyIPv4 {
setElements = append(setElements, nftables.SetElement{
Key: netip.IPv4Unspecified().AsSlice(),
}, nftables.SetElement{
Key: netip.IPv4Unspecified().AsSlice(),
IntervalEnd: true,
})
} else {
setElements = append(setElements, nftables.SetElement{
Key: netip.IPv6Unspecified().AsSlice(),
}, nftables.SetElement{
Key: netip.IPv6Unspecified().AsSlice(),
IntervalEnd: true,
})
}
}
var keyType nftables.SetDatatype
if family == nftables.TableFamilyIPv4 {
keyType = nftables.TypeIPAddr
} else {
keyType = nftables.TypeIP6Addr
}
mySet := &nftables.Set{
Table: table,
ID: id,
Name: name,
Interval: true,
KeyType: keyType,
}
if id == 0 {
mySet.Anonymous = true
mySet.Constant = true
}
if id == 0 {
err := nft.AddSet(mySet, setElements)
if err != nil {
return nil, err
}
return mySet, nil
} else if update {
nft.FlushSet(mySet)
} else {
err := nft.AddSet(mySet, nil)
if err != nil {
return nil, err
}
}
for len(setElements) > 0 {
toAdd := setElements
if len(toAdd) > 1000 {
toAdd = toAdd[:1000]
}
setElements = setElements[len(toAdd):]
err := nft.SetAddElements(mySet, toAdd)
if err != nil {
return nil, err
}
err = nft.Flush()
if err != nil {
return nil, err
}
}
return mySet, nil
}

View File

@@ -1,735 +0,0 @@
//go:build linux
package tun
import (
"net/netip"
_ "unsafe"
"github.com/sagernet/nftables"
"github.com/sagernet/nftables/binaryutil"
"github.com/sagernet/nftables/expr"
"github.com/sagernet/nftables/userdata"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/ranges"
"golang.org/x/exp/slices"
"golang.org/x/sys/unix"
)
//go:linkname allocSetID github.com/sagernet/nftables.allocSetID
var allocSetID uint32
func init() {
allocSetID = 6
}
func (r *autoRedirect) nftablesCreateAddressSets(
nft *nftables.Conn, table *nftables.Table,
update bool,
) error {
routeAddressSet := *r.routeAddressSet
routeExcludeAddressSet := *r.routeExcludeAddressSet
if len(routeAddressSet) == 0 && len(routeExcludeAddressSet) == 0 {
return nil
}
if len(routeAddressSet) > 0 {
if r.enableIPv4 {
_, err := nftablesCreateIPSet(nft, table, 1, "inet4_route_address_set", nftables.TableFamilyIPv4, routeAddressSet, nil, true, update)
if err != nil {
return err
}
}
if r.enableIPv6 {
_, err := nftablesCreateIPSet(nft, table, 2, "inet6_route_address_set", nftables.TableFamilyIPv6, routeAddressSet, nil, true, update)
if err != nil {
return err
}
}
}
if len(routeExcludeAddressSet) > 0 {
if r.enableIPv4 {
_, err := nftablesCreateIPSet(nft, table, 3, "inet4_route_exclude_address_set", nftables.TableFamilyIPv4, routeExcludeAddressSet, nil, false, update)
if err != nil {
return err
}
}
if r.enableIPv6 {
_, err := nftablesCreateIPSet(nft, table, 4, "inet6_route_exclude_address_set", nftables.TableFamilyIPv6, routeExcludeAddressSet, nil, false, update)
if err != nil {
return err
}
}
}
return nil
}
func (r *autoRedirect) nftablesCreateLocalAddressSets(
nft *nftables.Conn, table *nftables.Table,
localAddresses []netip.Prefix, lastAddresses []netip.Prefix,
) error {
if r.enableIPv4 {
localAddresses4 := common.Filter(localAddresses, func(it netip.Prefix) bool {
return it.Addr().Is4()
})
updateAddresses4 := common.Filter(localAddresses, func(it netip.Prefix) bool {
return it.Addr().Is4()
})
var update bool
if len(lastAddresses) != 0 {
if !slices.Equal(localAddresses4, updateAddresses4) {
update = true
}
}
if len(lastAddresses) == 0 || update {
_, err := nftablesCreateIPSet(nft, table, 5, "inet4_local_address_set", nftables.TableFamilyIPv4, nil, localAddresses4, false, update)
if err != nil {
return err
}
}
}
if r.enableIPv6 {
localAddresses6 := common.Filter(localAddresses, func(it netip.Prefix) bool {
return it.Addr().Is6()
})
updateAddresses6 := common.Filter(localAddresses, func(it netip.Prefix) bool {
return it.Addr().Is6()
})
var update bool
if len(lastAddresses) != 0 {
if !slices.Equal(localAddresses6, updateAddresses6) {
update = true
}
}
localAddresses6 = common.Filter(localAddresses6, func(it netip.Prefix) bool {
address := it.Addr()
return address.IsLoopback() || address.IsGlobalUnicast() && !address.IsPrivate()
})
if len(lastAddresses) == 0 || update {
_, err := nftablesCreateIPSet(nft, table, 6, "inet6_local_address_set", nftables.TableFamilyIPv6, nil, localAddresses6, false, update)
if err != nil {
return err
}
}
}
return nil
}
func (r *autoRedirect) nftablesCreateExcludeRules(nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain) error {
if r.tunOptions.AutoRedirectMarkMode && chain.Hooknum == nftables.ChainHookOutput {
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectOutputMark),
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictReturn,
},
},
})
if chain.Type == nftables.ChainTypeRoute {
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Ct{
Key: expr.CtKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectOutputMark),
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictReturn,
},
},
})
}
}
if chain.Hooknum == nftables.ChainHookPrerouting {
if len(r.tunOptions.IncludeInterface) > 0 {
if len(r.tunOptions.IncludeInterface) > 1 {
includeInterface := &nftables.Set{
Table: table,
Anonymous: true,
Constant: true,
KeyType: nftables.TypeIFName,
}
err := nft.AddSet(includeInterface, common.Map(r.tunOptions.IncludeInterface, func(it string) nftables.SetElement {
return nftables.SetElement{
Key: nftablesIfname(it),
}
}))
if err != nil {
return err
}
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Lookup{
SourceRegister: 1,
SetID: includeInterface.ID,
SetName: includeInterface.Name,
Invert: true,
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictReturn,
},
},
})
} else {
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: nftablesIfname(r.tunOptions.IncludeInterface[0]),
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictReturn,
},
},
})
}
}
if len(r.tunOptions.ExcludeInterface) > 0 {
if len(r.tunOptions.ExcludeInterface) > 1 {
excludeInterface := &nftables.Set{
Table: table,
Anonymous: true,
Constant: true,
KeyType: nftables.TypeIFName,
}
err := nft.AddSet(excludeInterface, common.Map(r.tunOptions.ExcludeInterface, func(it string) nftables.SetElement {
return nftables.SetElement{
Key: nftablesIfname(it),
}
}))
if err != nil {
return err
}
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Lookup{
SourceRegister: 1,
SetID: excludeInterface.ID,
SetName: excludeInterface.Name,
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictReturn,
},
},
})
} else {
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: nftablesIfname(r.tunOptions.ExcludeInterface[0]),
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictReturn,
},
},
})
}
}
} else {
if len(r.tunOptions.IncludeUID) > 0 {
if len(r.tunOptions.IncludeUID) > 1 || r.tunOptions.IncludeUID[0].Start != r.tunOptions.IncludeUID[0].End {
includeUID := &nftables.Set{
Table: table,
Anonymous: true,
Constant: true,
Interval: true,
KeyType: nftables.TypeUID,
}
err := nft.AddSet(includeUID, common.FlatMap(r.tunOptions.IncludeUID, func(it ranges.Range[uint32]) []nftables.SetElement {
return []nftables.SetElement{
{
Key: binaryutil.NativeEndian.PutUint32(it.Start),
},
{
Key: binaryutil.NativeEndian.PutUint32(it.End + 1),
IntervalEnd: true,
},
}
}))
if err != nil {
return err
}
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeySKUID, Register: 1},
&expr.Lookup{
SourceRegister: 1,
SetID: includeUID.ID,
SetName: includeUID.Name,
Invert: true,
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictReturn,
},
},
UserData: userdata.AppendString(nil, userdata.TypeComment, "not a bug :("),
})
} else {
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeySKUID, Register: 1},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: binaryutil.BigEndian.PutUint32(r.tunOptions.IncludeUID[0].Start),
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictReturn,
},
},
})
}
}
if len(r.tunOptions.ExcludeUID) > 0 {
if len(r.tunOptions.ExcludeUID) > 1 || r.tunOptions.ExcludeUID[0].Start != r.tunOptions.ExcludeUID[0].End {
excludeUID := &nftables.Set{
Table: table,
Anonymous: true,
Constant: true,
Interval: true,
KeyType: nftables.TypeUID,
}
err := nft.AddSet(excludeUID, common.FlatMap(r.tunOptions.ExcludeUID, func(it ranges.Range[uint32]) []nftables.SetElement {
return []nftables.SetElement{
{
Key: binaryutil.NativeEndian.PutUint32(it.Start),
},
{
Key: binaryutil.NativeEndian.PutUint32(it.End + 1),
IntervalEnd: true,
},
}
}))
if err != nil {
return err
}
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeySKUID, Register: 1},
&expr.Lookup{
SourceRegister: 1,
SetID: excludeUID.ID,
SetName: excludeUID.Name,
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictReturn,
},
},
UserData: userdata.AppendString(nil, userdata.TypeComment, "not a bug :("),
})
} else {
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeySKUID, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.ExcludeUID[0].Start),
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictReturn,
},
},
})
}
}
}
if len(r.tunOptions.Inet4RouteAddress) > 0 {
inet4RouteAddress, err := nftablesCreateIPSet(nft, table, 0, "", nftables.TableFamilyIPv4, nil, r.tunOptions.Inet4RouteAddress, false, false)
if err != nil {
return err
}
nftablesCreateExcludeDestinationIPSet(nft, table, chain, inet4RouteAddress.ID, inet4RouteAddress.Name, nftables.TableFamilyIPv4, true)
}
if len(r.tunOptions.Inet6RouteAddress) > 0 {
inet6RouteAddress, err := nftablesCreateIPSet(nft, table, 0, "", nftables.TableFamilyIPv6, nil, r.tunOptions.Inet6RouteAddress, false, false)
if err != nil {
return err
}
nftablesCreateExcludeDestinationIPSet(nft, table, chain, inet6RouteAddress.ID, inet6RouteAddress.Name, nftables.TableFamilyIPv6, true)
}
if len(r.tunOptions.Inet4RouteExcludeAddress) > 0 {
inet4RouteExcludeAddress, err := nftablesCreateIPSet(nft, table, 0, "", nftables.TableFamilyIPv4, nil, r.tunOptions.Inet4RouteExcludeAddress, false, false)
if err != nil {
return err
}
nftablesCreateExcludeDestinationIPSet(nft, table, chain, inet4RouteExcludeAddress.ID, inet4RouteExcludeAddress.Name, nftables.TableFamilyIPv4, false)
}
if len(r.tunOptions.Inet6RouteExcludeAddress) > 0 {
inet6RouteExcludeAddress, err := nftablesCreateIPSet(nft, table, 0, "", nftables.TableFamilyIPv6, nil, r.tunOptions.Inet6RouteExcludeAddress, false, false)
if err != nil {
return err
}
nftablesCreateExcludeDestinationIPSet(nft, table, chain, inet6RouteExcludeAddress.ID, inet6RouteExcludeAddress.Name, nftables.TableFamilyIPv6, false)
}
if !r.tunOptions.EXP_DisableDNSHijack && ((chain.Hooknum == nftables.ChainHookPrerouting && chain.Type == nftables.ChainTypeNAT) ||
(r.tunOptions.AutoRedirectMarkMode && chain.Hooknum == nftables.ChainHookOutput && chain.Type == nftables.ChainTypeNAT)) {
if r.enableIPv4 {
err := r.nftablesCreateDNSHijackRulesForFamily(nft, table, chain, nftables.TableFamilyIPv4, 5, "inet4_local_address_set")
if err != nil {
return err
}
}
if r.enableIPv6 {
err := r.nftablesCreateDNSHijackRulesForFamily(nft, table, chain, nftables.TableFamilyIPv6, 6, "inet6_local_address_set")
if err != nil {
return err
}
}
}
if r.tunOptions.AutoRedirectMarkMode &&
((chain.Hooknum == nftables.ChainHookOutput && chain.Type == nftables.ChainTypeRoute) ||
(chain.Hooknum == nftables.ChainHookPrerouting && chain.Type == nftables.ChainTypeFilter)) {
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyL4PROTO,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{unix.IPPROTO_UDP},
},
&expr.Verdict{
Kind: expr.VerdictReturn,
},
},
})
}
if r.enableIPv4 {
nftablesCreateExcludeDestinationIPSet(nft, table, chain, 5, "inet4_local_address_set", nftables.TableFamilyIPv4, false)
}
if r.enableIPv6 {
nftablesCreateExcludeDestinationIPSet(nft, table, chain, 6, "inet6_local_address_set", nftables.TableFamilyIPv6, false)
}
routeAddressSet := *r.routeAddressSet
routeExcludeAddressSet := *r.routeExcludeAddressSet
if r.enableIPv4 && len(routeAddressSet) > 0 {
nftablesCreateExcludeDestinationIPSet(nft, table, chain, 1, "inet4_route_address_set", nftables.TableFamilyIPv4, true)
}
if r.enableIPv6 && len(routeAddressSet) > 0 {
nftablesCreateExcludeDestinationIPSet(nft, table, chain, 2, "inet6_route_address_set", nftables.TableFamilyIPv6, true)
}
if r.enableIPv4 && len(routeExcludeAddressSet) > 0 {
nftablesCreateExcludeDestinationIPSet(nft, table, chain, 3, "inet4_route_exclude_address_set", nftables.TableFamilyIPv4, false)
}
if r.enableIPv6 && len(routeExcludeAddressSet) > 0 {
nftablesCreateExcludeDestinationIPSet(nft, table, chain, 4, "inet6_route_exclude_address_set", nftables.TableFamilyIPv6, false)
}
return nil
}
func (r *autoRedirect) nftablesCreateMark(nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain) {
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectInputMark),
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
}, // output meta mark set myMark ct mark set meta mark
&expr.Ct{
Key: expr.CtKeyMARK,
Register: 1,
SourceRegister: true,
},
&expr.Counter{},
},
})
}
func (r *autoRedirect) nftablesCreateRedirect(
nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain,
exprs ...expr.Any,
) {
if r.enableIPv4 && !r.enableIPv6 {
exprs = append(exprs,
&expr.Meta{
Key: expr.MetaKeyNFPROTO,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{uint8(nftables.TableFamilyIPv4)},
})
} else if !r.enableIPv4 && r.enableIPv6 {
exprs = append(exprs,
&expr.Meta{
Key: expr.MetaKeyNFPROTO,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{uint8(nftables.TableFamilyIPv6)},
})
}
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: append(exprs,
&expr.Meta{
Key: expr.MetaKeyL4PROTO,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{unix.IPPROTO_TCP},
},
&expr.Counter{},
&expr.Immediate{
Register: 1,
Data: binaryutil.BigEndian.PutUint16(r.redirectPort()),
},
&expr.Redir{
RegisterProtoMin: 1,
Flags: unix.NF_NAT_RANGE_PROTO_SPECIFIED,
},
&expr.Verdict{
Kind: expr.VerdictReturn,
},
),
})
}
func (r *autoRedirect) nftablesCreateDNSHijackRulesForFamily(
nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain,
family nftables.TableFamily, setID uint32, setName string,
) error {
ipProto := &nftables.Set{
Table: table,
Anonymous: true,
Constant: true,
KeyType: nftables.TypeInetProto,
}
err := nft.AddSet(ipProto, []nftables.SetElement{
{Key: []byte{unix.IPPROTO_TCP}},
{Key: []byte{unix.IPPROTO_UDP}},
})
if err != nil {
return err
}
dnsServer := common.Find(r.tunOptions.DNSServers, func(it netip.Addr) bool {
return it.Is4() == (family == nftables.TableFamilyIPv4)
})
if !dnsServer.IsValid() {
if family == nftables.TableFamilyIPv4 {
if HasNextAddress(r.tunOptions.Inet4Address[0], 1) {
dnsServer = r.tunOptions.Inet4Address[0].Addr().Next()
}
} else {
if HasNextAddress(r.tunOptions.Inet6Address[0], 1) {
dnsServer = r.tunOptions.Inet6Address[0].Addr().Next()
}
}
}
if !dnsServer.IsValid() {
return nil
}
exprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyNFPROTO,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{uint8(family)},
},
}
if chain.Hooknum == nftables.ChainHookOutput {
// It looks like we can't hijack DNS requests sent to loopback.
// https://serverfault.com/questions/363899/iptables-dnat-from-loopback
// and tproxy is not available in output
exprs = append(exprs,
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: nftablesIfname("lo"),
},
)
} else {
if family == nftables.TableFamilyIPv4 {
exprs = append(exprs,
&expr.Payload{
OperationType: expr.PayloadLoad,
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
)
} else {
exprs = append(exprs,
&expr.Payload{
OperationType: expr.PayloadLoad,
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 8,
Len: 16,
},
)
}
exprs = append(exprs, &expr.Lookup{
SourceRegister: 1,
SetID: setID,
SetName: setName,
})
}
exprs = append(exprs,
&expr.Meta{
Key: expr.MetaKeyL4PROTO,
Register: 1,
},
&expr.Lookup{
SourceRegister: 1,
SetID: ipProto.ID,
SetName: ipProto.Name,
},
&expr.Payload{
OperationType: expr.PayloadLoad,
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(53),
},
&expr.Immediate{
Register: 1,
Data: dnsServer.AsSlice(),
},
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(family),
RegAddrMin: 1,
},
&expr.Counter{},
)
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: exprs,
})
return nil
}
func (r *autoRedirect) nftablesCreateUnreachable(
nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain,
) {
if (r.enableIPv4 && r.enableIPv6) || !r.tunOptions.StrictRoute {
return
}
var nfProto nftables.TableFamily
if r.enableIPv4 {
nfProto = nftables.TableFamilyIPv6
} else {
nfProto = nftables.TableFamilyIPv4
}
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyNFPROTO,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{uint8(nfProto)},
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictDrop,
},
},
})
}

View File

@@ -1,103 +0,0 @@
//go:build linux
package tun
import (
"github.com/sagernet/nftables"
"github.com/sagernet/nftables/expr"
"golang.org/x/exp/slices"
)
func (r *autoRedirect) configureOpenWRTFirewall4(nft *nftables.Conn, cleanup bool) error {
tableFW4, err := nft.ListTableOfFamily("fw4", nftables.TableFamilyINet)
if err != nil {
return nil
}
if !cleanup {
ruleIif := &nftables.Rule{
Table: tableFW4,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: nftablesIfname(r.tunOptions.Name),
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
}
ruleOif := &nftables.Rule{
Table: tableFW4,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: nftablesIfname(r.tunOptions.Name),
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
}
chainForward := &nftables.Chain{
Name: "forward",
}
ruleIif.Chain = chainForward
ruleOif.Chain = chainForward
nft.InsertRule(ruleOif)
nft.InsertRule(ruleIif)
chainInput := &nftables.Chain{
Name: "input",
}
ruleIif.Chain = chainInput
ruleOif.Chain = chainInput
nft.InsertRule(ruleOif)
nft.InsertRule(ruleIif)
return nil
}
for _, chainName := range []string{"input", "forward"} {
var rules []*nftables.Rule
rules, err = nft.GetRules(tableFW4, &nftables.Chain{
Name: chainName,
})
if err != nil {
return err
}
for _, rule := range rules {
if len(rule.Exprs) != 4 {
continue
}
exprMeta, isMeta := rule.Exprs[0].(*expr.Meta)
if !isMeta {
continue
}
if exprMeta.Key != expr.MetaKeyIIFNAME && exprMeta.Key != expr.MetaKeyOIFNAME {
continue
}
exprCmp, isCmp := rule.Exprs[1].(*expr.Cmp)
if !isCmp {
continue
}
if !slices.Equal(exprCmp.Data, nftablesIfname(r.tunOptions.Name)) {
continue
}
err = nft.DelRule(rule)
if err != nil {
return err
}
}
}
return nil
}

View File

@@ -1,83 +0,0 @@
//go:build linux
package tun
import (
"context"
"errors"
"net"
"net/netip"
"time"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/control"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type redirectServer struct {
ctx context.Context
handler N.TCPConnectionHandlerEx
logger logger.Logger
listenAddr netip.Addr
listener *net.TCPListener
inShutdown atomic.Bool
}
func newRedirectServer(ctx context.Context, handler N.TCPConnectionHandlerEx, logger logger.Logger, listenAddr netip.Addr) *redirectServer {
return &redirectServer{
ctx: ctx,
handler: handler,
logger: logger,
listenAddr: listenAddr,
}
}
func (s *redirectServer) Start() error {
var listenConfig net.ListenConfig
// listenConfig.KeepAlive = C.TCPKeepAliveInitial
listenConfig.KeepAlive = 10 * time.Minute
listener, err := listenConfig.Listen(s.ctx, M.NetworkFromNetAddr("tcp", s.listenAddr), M.SocksaddrFrom(s.listenAddr, 0).String())
if err != nil {
return err
}
s.listener = listener.(*net.TCPListener)
go s.loopIn()
return nil
}
func (s *redirectServer) Close() error {
s.inShutdown.Store(true)
return s.listener.Close()
}
func (s *redirectServer) loopIn() {
for {
conn, err := s.listener.AcceptTCP()
if err != nil {
var netError net.Error
//nolint:staticcheck
if errors.As(err, &netError) && netError.Temporary() {
s.logger.Error(err)
continue
}
if s.inShutdown.Load() && E.IsClosed(err) {
return
}
s.listener.Close()
s.logger.Error("serve error: ", err)
continue
}
source := M.SocksaddrFromNet(conn.RemoteAddr()).Unwrap()
destination, err := control.GetOriginalDestination(conn)
if err != nil {
_ = conn.SetLinger(0)
_ = conn.Close()
s.logger.Error("process redirect connection from ", source, ": invalid connection: ", err)
continue
}
go s.handler.NewConnectionEx(s.ctx, conn, source, M.SocksaddrFromNetIP(destination).Unwrap(), nil)
}
}

View File

@@ -1,11 +0,0 @@
//go:build !linux
package tun
import (
"os"
)
func NewAutoRedirect(options AutoRedirectOptions) (AutoRedirect, error) {
return nil, os.ErrInvalid
}

View File

@@ -2,18 +2,13 @@ package tun
import (
"context"
"encoding/binary"
"net"
"net/netip"
"time"
"github.com/sagernet/sing/common/control"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
)
var ErrDrop = E.New("drop connections by rule")
type Stack interface {
Start() error
Close() error
@@ -22,12 +17,15 @@ type Stack interface {
type StackOptions struct {
Context context.Context
Tun Tun
TunOptions Options
UDPTimeout time.Duration
Name string
MTU uint32
Inet4Address []netip.Prefix
Inet6Address []netip.Prefix
EndpointIndependentNat bool
UDPTimeout int64
Handler Handler
Logger logger.Logger
ForwarderBindInterface bool
IncludeAllNetworks bool
InterfaceFinder control.InterfaceFinder
}
@@ -37,44 +35,14 @@ func NewStack(
) (Stack, error) {
switch stack {
case "":
if options.IncludeAllNetworks {
return NewGVisor(options)
} else if WithGVisor && !options.TunOptions.GSO {
return NewMixed(options)
} else {
return NewSystem(options)
}
return NewSystem(options)
case "gvisor":
return NewGVisor(options)
case "mixed":
if options.IncludeAllNetworks {
return nil, ErrIncludeAllNetworks
}
return NewMixed(options)
case "system":
if options.IncludeAllNetworks {
return nil, ErrIncludeAllNetworks
}
return NewSystem(options)
case "lwip":
return NewLWIP(options)
default:
return nil, E.New("unknown stack: ", stack)
}
}
func HasNextAddress(prefix netip.Prefix, count int) bool {
checkAddr := prefix.Addr()
for i := 0; i < count; i++ {
checkAddr = checkAddr.Next()
}
return prefix.Contains(checkAddr)
}
func BroadcastAddr(inet4Address []netip.Prefix) netip.Addr {
if len(inet4Address) == 0 {
return netip.Addr{}
}
prefix := inet4Address[0]
var broadcastAddr [4]byte
binary.BigEndian.PutUint32(broadcastAddr[:], binary.BigEndian.Uint32(prefix.Masked().Addr().AsSlice())|^binary.BigEndian.Uint32(net.CIDRMask(prefix.Bits(), 32)))
return netip.AddrFrom4(broadcastAddr)
}

View File

@@ -1,164 +0,0 @@
//go:build with_gvisor
package tun
import (
"context"
"net/netip"
"runtime"
"time"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/icmp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
)
const WithGVisor = true
const DefaultNIC tcpip.NICID = 1
type GVisor struct {
ctx context.Context
tun GVisorTun
udpTimeout time.Duration
broadcastAddr netip.Addr
handler Handler
logger logger.Logger
stack *stack.Stack
endpoint stack.LinkEndpoint
}
type GVisorTun interface {
Tun
NewEndpoint() (stack.LinkEndpoint, error)
}
func NewGVisor(
options StackOptions,
) (Stack, error) {
gTun, isGTun := options.Tun.(GVisorTun)
if !isGTun {
return nil, E.New("gVisor stack is unsupported on current platform")
}
gStack := &GVisor{
ctx: options.Context,
tun: gTun,
udpTimeout: options.UDPTimeout,
broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address),
handler: options.Handler,
logger: options.Logger,
}
return gStack, nil
}
func (t *GVisor) Start() error {
linkEndpoint, err := t.tun.NewEndpoint()
if err != nil {
return err
}
linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, t.tun}
ipStack, err := NewGVisorStack(linkEndpoint)
if err != nil {
return err
}
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, NewTCPForwarder(t.ctx, ipStack, t.handler).HandlePacket)
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket)
t.stack = ipStack
t.endpoint = linkEndpoint
return nil
}
func (t *GVisor) Close() error {
if t.stack == nil {
return nil
}
t.endpoint.Attach(nil)
t.stack.Close()
for _, endpoint := range t.stack.CleanupEndpoints() {
endpoint.Abort()
}
return nil
}
func AddressFromAddr(destination netip.Addr) tcpip.Address {
if destination.Is6() {
return tcpip.AddrFrom16(destination.As16())
} else {
return tcpip.AddrFrom4(destination.As4())
}
}
func AddrFromAddress(address tcpip.Address) netip.Addr {
if address.Len() == 16 {
return netip.AddrFrom16(address.As16())
} else {
return netip.AddrFrom4(address.As4())
}
}
func NewGVisorStack(ep stack.LinkEndpoint) (*stack.Stack, error) {
ipStack := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{
ipv4.NewProtocol,
ipv6.NewProtocol,
},
TransportProtocols: []stack.TransportProtocolFactory{
tcp.NewProtocol,
udp.NewProtocol,
icmp.NewProtocol4,
icmp.NewProtocol6,
},
})
err := ipStack.CreateNIC(DefaultNIC, ep)
if err != nil {
return nil, gonet.TranslateNetstackError(err)
}
ipStack.SetRouteTable([]tcpip.Route{
{Destination: header.IPv4EmptySubnet, NIC: DefaultNIC},
{Destination: header.IPv6EmptySubnet, NIC: DefaultNIC},
})
err = ipStack.SetSpoofing(DefaultNIC, true)
if err != nil {
return nil, gonet.TranslateNetstackError(err)
}
err = ipStack.SetPromiscuousMode(DefaultNIC, true)
if err != nil {
return nil, gonet.TranslateNetstackError(err)
}
sOpt := tcpip.TCPSACKEnabled(true)
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt)
mOpt := tcpip.TCPModerateReceiveBufferOption(true)
ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &mOpt)
if runtime.GOOS == "windows" {
tcpRecoveryOpt := tcpip.TCPRecovery(0)
err = ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpRecoveryOpt)
}
tcpRXBufOpt := tcpip.TCPReceiveBufferSizeRangeOption{
Min: tcpRXBufMinSize,
Default: tcpRXBufDefSize,
Max: tcpRXBufMaxSize,
}
err = ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpRXBufOpt)
if err != nil {
return nil, gonet.TranslateNetstackError(err)
}
tcpTXBufOpt := tcpip.TCPSendBufferSizeRangeOption{
Min: tcpTXBufMinSize,
Default: tcpTXBufDefSize,
Max: tcpTXBufMaxSize,
}
err = ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpTXBufOpt)
if err != nil {
return nil, gonet.TranslateNetstackError(err)
}
return ipStack, nil
}

View File

@@ -1,56 +0,0 @@
//go:build with_gvisor
package tun
import (
"net/netip"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/sing/common/bufio"
N "github.com/sagernet/sing/common/network"
)
var _ stack.LinkEndpoint = (*LinkEndpointFilter)(nil)
type LinkEndpointFilter struct {
stack.LinkEndpoint
BroadcastAddress netip.Addr
Writer N.VectorisedWriter
}
func (w *LinkEndpointFilter) Attach(dispatcher stack.NetworkDispatcher) {
w.LinkEndpoint.Attach(&networkDispatcherFilter{dispatcher, w.BroadcastAddress, w.Writer})
}
var _ stack.NetworkDispatcher = (*networkDispatcherFilter)(nil)
type networkDispatcherFilter struct {
stack.NetworkDispatcher
broadcastAddress netip.Addr
writer N.VectorisedWriter
}
func (w *networkDispatcherFilter) DeliverNetworkPacket(protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) {
var network header.Network
if protocol == header.IPv4ProtocolNumber {
if headerPackets, loaded := pkt.Data().PullUp(header.IPv4MinimumSize); loaded {
network = header.IPv4(headerPackets)
}
} else {
if headerPackets, loaded := pkt.Data().PullUp(header.IPv6MinimumSize); loaded {
network = header.IPv6(headerPackets)
}
}
if network == nil {
w.NetworkDispatcher.DeliverNetworkPacket(protocol, pkt)
return
}
destination := AddrFromAddress(network.DestinationAddress())
if destination == w.broadcastAddress || !destination.IsGlobalUnicast() {
_, _ = bufio.WriteVectorised(w.writer, pkt.AsSlices())
return
}
w.NetworkDispatcher.DeliverNetworkPacket(protocol, pkt)
}

View File

@@ -1,191 +0,0 @@
//go:build with_gvisor
package tun
import (
"context"
"net"
"os"
"time"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
"github.com/sagernet/gvisor/pkg/waiter"
"github.com/sagernet/sing/common"
)
type gLazyConn struct {
tcpConn *gonet.TCPConn
parentCtx context.Context
stack *stack.Stack
request *tcp.ForwarderRequest
localAddr net.Addr
remoteAddr net.Addr
handshakeDone bool
handshakeErr error
}
func (c *gLazyConn) HandshakeContext(ctx context.Context) error {
if c.handshakeDone {
return nil
}
defer func() {
c.handshakeDone = true
}()
var (
wq waiter.Queue
endpoint tcpip.Endpoint
)
handshakeCtx, cancel := context.WithCancel(ctx)
go func() {
select {
case <-c.parentCtx.Done():
wq.Notify(wq.Events())
case <-handshakeCtx.Done():
}
}()
endpoint, err := c.request.CreateEndpoint(&wq)
cancel()
if err != nil {
gErr := gonet.TranslateNetstackError(err)
c.handshakeErr = gErr
c.request.Complete(true)
return gErr
}
c.request.Complete(false)
endpoint.SocketOptions().SetKeepAlive(true)
endpoint.SetSockOpt(common.Ptr(tcpip.KeepaliveIdleOption(15 * time.Second)))
endpoint.SetSockOpt(common.Ptr(tcpip.KeepaliveIntervalOption(15 * time.Second)))
tcpConn := gonet.NewTCPConn(&wq, endpoint)
c.tcpConn = tcpConn
return nil
}
func (c *gLazyConn) HandshakeFailure(err error) error {
if c.handshakeDone {
return os.ErrInvalid
}
c.request.Complete(err != ErrDrop)
c.handshakeDone = true
c.handshakeErr = err
return nil
}
func (c *gLazyConn) HandshakeSuccess() error {
return c.HandshakeContext(context.Background())
}
func (c *gLazyConn) Read(b []byte) (n int, err error) {
if !c.handshakeDone {
err = c.HandshakeContext(context.Background())
if err != nil {
return
}
} else if c.handshakeErr != nil {
return 0, c.handshakeErr
}
return c.tcpConn.Read(b)
}
func (c *gLazyConn) Write(b []byte) (n int, err error) {
if !c.handshakeDone {
err = c.HandshakeContext(context.Background())
if err != nil {
return
}
} else if c.handshakeErr != nil {
return 0, c.handshakeErr
}
return c.tcpConn.Write(b)
}
func (c *gLazyConn) LocalAddr() net.Addr {
return c.localAddr
}
func (c *gLazyConn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func (c *gLazyConn) SetDeadline(t time.Time) error {
if !c.handshakeDone {
err := c.HandshakeContext(context.Background())
if err != nil {
return err
}
} else if c.handshakeErr != nil {
return c.handshakeErr
}
return c.tcpConn.SetDeadline(t)
}
func (c *gLazyConn) SetReadDeadline(t time.Time) error {
if !c.handshakeDone {
err := c.HandshakeContext(context.Background())
if err != nil {
return err
}
} else if c.handshakeErr != nil {
return c.handshakeErr
}
return c.tcpConn.SetReadDeadline(t)
}
func (c *gLazyConn) SetWriteDeadline(t time.Time) error {
if !c.handshakeDone {
err := c.HandshakeContext(context.Background())
if err != nil {
return err
}
} else if c.handshakeErr != nil {
return c.handshakeErr
}
return c.tcpConn.SetWriteDeadline(t)
}
func (c *gLazyConn) Close() error {
if !c.handshakeDone {
c.request.Complete(true)
c.handshakeErr = net.ErrClosed
return nil
} else if c.handshakeErr != nil {
return nil
}
return c.tcpConn.Close()
}
func (c *gLazyConn) CloseRead() error {
if !c.handshakeDone {
c.request.Complete(true)
c.handshakeErr = net.ErrClosed
return nil
} else if c.handshakeErr != nil {
return nil
}
return c.tcpConn.CloseRead()
}
func (c *gLazyConn) CloseWrite() error {
if !c.handshakeDone {
c.request.Complete(true)
c.handshakeErr = net.ErrClosed
return nil
} else if c.handshakeErr != nil {
return nil
}
return c.tcpConn.CloseRead()
}
func (c *gLazyConn) ReaderReplaceable() bool {
return c.handshakeDone && c.handshakeErr == nil
}
func (c *gLazyConn) WriterReplaceable() bool {
return c.handshakeDone && c.handshakeErr == nil
}
func (c *gLazyConn) Upstream() any {
return c.tcpConn
}

View File

@@ -1,51 +0,0 @@
//go:build with_gvisor
package tun
import (
"context"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type TCPForwarder struct {
ctx context.Context
stack *stack.Stack
handler Handler
forwarder *tcp.Forwarder
}
func NewTCPForwarder(ctx context.Context, stack *stack.Stack, handler Handler) *TCPForwarder {
forwarder := &TCPForwarder{
ctx: ctx,
stack: stack,
handler: handler,
}
forwarder.forwarder = tcp.NewForwarder(stack, 0, 1024, forwarder.Forward)
return forwarder
}
func (f *TCPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
return f.forwarder.HandlePacket(id, pkt)
}
func (f *TCPForwarder) Forward(r *tcp.ForwarderRequest) {
source := M.SocksaddrFrom(AddrFromAddress(r.ID().RemoteAddress), r.ID().RemotePort)
destination := M.SocksaddrFrom(AddrFromAddress(r.ID().LocalAddress), r.ID().LocalPort)
pErr := f.handler.PrepareConnection(N.NetworkTCP, source, destination)
if pErr != nil {
r.Complete(pErr != ErrDrop)
return
}
conn := &gLazyConn{
parentCtx: f.ctx,
stack: f.stack,
request: r,
localAddr: source.TCPAddr(),
remoteAddr: destination.TCPAddr(),
}
go f.handler.NewConnectionEx(f.ctx, conn, source, destination, nil)
}

View File

@@ -1,18 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build with_gvisor && !ios
package tun
import "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
const (
tcpRXBufMinSize = tcp.MinBufferSize
tcpRXBufDefSize = tcp.DefaultSendBufferSize
tcpRXBufMaxSize = 8 << 20 // 8MiB
tcpTXBufMinSize = tcp.MinBufferSize
tcpTXBufDefSize = tcp.DefaultReceiveBufferSize
tcpTXBufMaxSize = 6 << 20 // 6MiB
)

View File

@@ -1,21 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build with_gvisor
package tun
import "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
const (
// tcp{RX,TX}Buf{Min,Def,Max}Size mirror gVisor defaults. We leave these
// unchanged on iOS for now as to not increase pressure towards the
// NetworkExtension memory limit.
tcpRXBufMinSize = tcp.MinBufferSize
tcpRXBufDefSize = tcp.DefaultSendBufferSize
tcpRXBufMaxSize = tcp.MaxBufferSize
tcpTXBufMinSize = tcp.MinBufferSize
tcpTXBufDefSize = tcp.DefaultReceiveBufferSize
tcpTXBufMaxSize = tcp.MaxBufferSize
)

View File

@@ -1,183 +0,0 @@
//go:build with_gvisor
package tun
import (
"context"
"math"
"net/netip"
"os"
"sync"
"time"
_ "unsafe"
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
"github.com/sagernet/gvisor/pkg/tcpip/checksum"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/udpnat2"
)
type UDPForwarder struct {
ctx context.Context
stack *stack.Stack
handler Handler
udpNat *udpnat.Service
}
func NewUDPForwarder(ctx context.Context, stack *stack.Stack, handler Handler, timeout time.Duration) *UDPForwarder {
forwarder := &UDPForwarder{
ctx: ctx,
stack: stack,
handler: handler,
}
forwarder.udpNat = udpnat.New(handler, forwarder.PreparePacketConnection, timeout, true)
return forwarder
}
func (f *UDPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
source := M.SocksaddrFrom(AddrFromAddress(id.RemoteAddress), id.RemotePort)
destination := M.SocksaddrFrom(AddrFromAddress(id.LocalAddress), id.LocalPort)
bufferRange := pkt.Data().AsRange()
bufferSlices := make([][]byte, bufferRange.Size())
rangeIterate(bufferRange, func(view *buffer.View) {
bufferSlices = append(bufferSlices, view.AsSlice())
})
f.udpNat.NewPacket(bufferSlices, source, destination, pkt)
return true
}
//go:linkname rangeIterate github.com/sagernet/gvisor/pkg/tcpip/stack.Range.iterate
func rangeIterate(r stack.Range, fn func(*buffer.View))
func (f *UDPForwarder) PreparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) {
pErr := f.handler.PrepareConnection(N.NetworkUDP, source, destination)
if pErr != nil {
if pErr != ErrDrop {
gWriteUnreachable(f.stack, userData.(*stack.PacketBuffer))
}
return false, nil, nil, nil
}
var sourceNetwork tcpip.NetworkProtocolNumber
if source.Addr.Is4() {
sourceNetwork = header.IPv4ProtocolNumber
} else {
sourceNetwork = header.IPv6ProtocolNumber
}
writer := &UDPBackWriter{
stack: f.stack,
packet: userData.(*stack.PacketBuffer).IncRef(),
source: AddressFromAddr(source.Addr),
sourcePort: source.Port,
sourceNetwork: sourceNetwork,
}
return true, f.ctx, writer, nil
}
type UDPBackWriter struct {
access sync.Mutex
stack *stack.Stack
packet *stack.PacketBuffer
source tcpip.Address
sourcePort uint16
sourceNetwork tcpip.NetworkProtocolNumber
}
func (w *UDPBackWriter) HandshakeSuccess() error {
w.access.Lock()
defer w.access.Unlock()
if w.packet != nil {
w.packet.DecRef()
w.packet = nil
}
return nil
}
func (w *UDPBackWriter) HandshakeFailure(err error) error {
w.access.Lock()
defer w.access.Unlock()
if w.packet == nil {
return os.ErrInvalid
}
wErr := gWriteUnreachable(w.stack, w.packet)
w.packet.DecRef()
w.packet = nil
return wErr
}
func (w *UDPBackWriter) WritePacket(packetBuffer *buf.Buffer, destination M.Socksaddr) error {
if !destination.IsIP() {
return E.Cause(os.ErrInvalid, "invalid destination")
} else if destination.IsIPv4() && w.sourceNetwork == header.IPv6ProtocolNumber {
destination = M.SocksaddrFrom(netip.AddrFrom16(destination.Addr.As16()), destination.Port)
} else if destination.IsIPv6() && (w.sourceNetwork == header.IPv4ProtocolNumber) {
return E.New("send IPv6 packet to IPv4 connection")
}
defer packetBuffer.Release()
route, err := w.stack.FindRoute(
DefaultNIC,
AddressFromAddr(destination.Addr),
w.source,
w.sourceNetwork,
false,
)
if err != nil {
return gonet.TranslateNetstackError(err)
}
defer route.Release()
packet := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: header.UDPMinimumSize + int(route.MaxHeaderLength()),
Payload: buffer.MakeWithData(packetBuffer.Bytes()),
})
defer packet.DecRef()
packet.TransportProtocolNumber = header.UDPProtocolNumber
udpHdr := header.UDP(packet.TransportHeader().Push(header.UDPMinimumSize))
pLen := uint16(packet.Size())
udpHdr.Encode(&header.UDPFields{
SrcPort: destination.Port,
DstPort: w.sourcePort,
Length: pLen,
})
if route.RequiresTXTransportChecksum() && w.sourceNetwork == header.IPv6ProtocolNumber {
xsum := udpHdr.CalculateChecksum(checksum.Combine(
route.PseudoHeaderChecksum(header.UDPProtocolNumber, pLen),
packet.Data().Checksum(),
))
if xsum != math.MaxUint16 {
xsum = ^xsum
}
udpHdr.SetChecksum(xsum)
}
err = route.WritePacket(stack.NetworkHeaderParams{
Protocol: header.UDPProtocolNumber,
TTL: route.DefaultTTL(),
TOS: 0,
}, packet)
if err != nil {
route.Stats().UDP.PacketSendErrors.Increment()
return gonet.TranslateNetstackError(err)
}
route.Stats().UDP.PacketsSent.Increment()
return nil
}
func gWriteUnreachable(gStack *stack.Stack, packet *stack.PacketBuffer) error {
if packet.NetworkProtocolNumber == header.IPv4ProtocolNumber {
return gonet.TranslateNetstackError(gStack.NetworkProtocolInstance(header.IPv4ProtocolNumber).(stack.RejectIPv4WithHandler).SendRejectionError(packet, stack.RejectIPv4WithICMPPortUnreachable, true))
} else {
return gonet.TranslateNetstackError(gStack.NetworkProtocolInstance(header.IPv6ProtocolNumber).(stack.RejectIPv6WithHandler).SendRejectionError(packet, stack.RejectIPv6WithICMPPortUnreachable, true))
}
}

View File

@@ -1,236 +0,0 @@
//go:build with_gvisor
package tun
import (
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip"
gHdr "github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/link/channel"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
)
type Mixed struct {
*System
stack *stack.Stack
endpoint *channel.Endpoint
}
func NewMixed(
options StackOptions,
) (Stack, error) {
system, err := NewSystem(options)
if err != nil {
return nil, err
}
return &Mixed{
System: system.(*System),
}, nil
}
func (m *Mixed) Start() error {
err := m.System.start()
if err != nil {
return err
}
endpoint := channel.New(1024, uint32(m.mtu), "")
ipStack, err := NewGVisorStack(endpoint)
if err != nil {
return err
}
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(m.ctx, ipStack, m.handler, m.udpTimeout).HandlePacket)
m.stack = ipStack
m.endpoint = endpoint
go m.tunLoop()
go m.packetLoop()
return nil
}
func (m *Mixed) Close() error {
if m.stack == nil {
return nil
}
m.endpoint.Attach(nil)
m.stack.Close()
for _, endpoint := range m.stack.CleanupEndpoints() {
endpoint.Abort()
}
return m.System.Close()
}
func (m *Mixed) tunLoop() {
if winTun, isWinTun := m.tun.(WinTun); isWinTun {
m.wintunLoop(winTun)
return
}
if linuxTUN, isLinuxTUN := m.tun.(LinuxTUN); isLinuxTUN {
m.frontHeadroom = linuxTUN.FrontHeadroom()
m.txChecksumOffload = linuxTUN.TXChecksumOffload()
batchSize := linuxTUN.BatchSize()
if batchSize > 1 {
m.batchLoop(linuxTUN, batchSize)
return
}
}
packetBuffer := make([]byte, m.mtu+PacketOffset)
for {
n, err := m.tun.Read(packetBuffer)
if err != nil {
if E.IsClosed(err) {
return
}
m.logger.Error(E.Cause(err, "read packet"))
}
if n < header.IPv4MinimumSize {
continue
}
rawPacket := packetBuffer[:n]
packet := packetBuffer[PacketOffset:n]
if m.processPacket(packet) {
_, err = m.tun.Write(rawPacket)
if err != nil {
m.logger.Trace(E.Cause(err, "write packet"))
}
}
}
}
func (m *Mixed) wintunLoop(winTun WinTun) {
for {
packet, release, err := winTun.ReadPacket()
if err != nil {
return
}
if len(packet) < header.IPv4MinimumSize {
release()
continue
}
if m.processPacket(packet) {
_, err = winTun.Write(packet)
if err != nil {
m.logger.Trace(E.Cause(err, "write packet"))
}
}
release()
}
}
func (m *Mixed) batchLoop(linuxTUN LinuxTUN, batchSize int) {
packetBuffers := make([][]byte, batchSize)
writeBuffers := make([][]byte, batchSize)
packetSizes := make([]int, batchSize)
for i := range packetBuffers {
packetBuffers[i] = make([]byte, m.mtu+m.frontHeadroom)
}
for {
n, err := linuxTUN.BatchRead(packetBuffers, m.frontHeadroom, packetSizes)
if err != nil {
if E.IsClosed(err) {
return
}
m.logger.Error(E.Cause(err, "batch read packet"))
}
if n == 0 {
continue
}
for i := 0; i < n; i++ {
packetSize := packetSizes[i]
if packetSize < header.IPv4MinimumSize {
continue
}
packetBuffer := packetBuffers[i]
packet := packetBuffer[m.frontHeadroom : m.frontHeadroom+packetSize]
if m.processPacket(packet) {
writeBuffers = append(writeBuffers, packetBuffer[:m.frontHeadroom+packetSize])
}
}
if len(writeBuffers) > 0 {
_, err = linuxTUN.BatchWrite(writeBuffers, m.frontHeadroom)
if err != nil {
m.logger.Trace(E.Cause(err, "batch write packet"))
}
writeBuffers = writeBuffers[:0]
}
}
}
func (m *Mixed) processPacket(packet []byte) bool {
var (
writeBack bool
err error
)
switch ipVersion := header.IPVersion(packet); ipVersion {
case header.IPv4Version:
writeBack, err = m.processIPv4(packet)
case header.IPv6Version:
writeBack, err = m.processIPv6(packet)
default:
err = E.New("ip: unknown version: ", ipVersion)
}
if err != nil {
m.logger.Trace(err)
return false
}
return writeBack
}
func (m *Mixed) processIPv4(ipHdr header.IPv4) (writeBack bool, err error) {
writeBack = true
destination := ipHdr.DestinationAddr()
if destination == m.broadcastAddr || !destination.IsGlobalUnicast() {
return
}
switch ipHdr.TransportProtocol() {
case header.TCPProtocolNumber:
writeBack, err = m.processIPv4TCP(ipHdr, ipHdr.Payload())
case header.UDPProtocolNumber:
writeBack = false
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(ipHdr),
IsForwardedPacket: true,
})
m.endpoint.InjectInbound(gHdr.IPv4ProtocolNumber, pkt)
pkt.DecRef()
return
case header.ICMPv4ProtocolNumber:
err = m.processIPv4ICMP(ipHdr, ipHdr.Payload())
}
return
}
func (m *Mixed) processIPv6(ipHdr header.IPv6) (writeBack bool, err error) {
writeBack = true
if !ipHdr.DestinationAddr().IsGlobalUnicast() {
return
}
switch ipHdr.TransportProtocol() {
case header.TCPProtocolNumber:
writeBack, err = m.processIPv6TCP(ipHdr, ipHdr.Payload())
case header.UDPProtocolNumber:
writeBack = false
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(ipHdr),
IsForwardedPacket: true,
})
m.endpoint.InjectInbound(tcpip.NetworkProtocolNumber(header.IPv6ProtocolNumber), pkt)
pkt.DecRef()
case header.ICMPv6ProtocolNumber:
err = m.processIPv6ICMP(ipHdr, ipHdr.Payload())
}
return
}
func (m *Mixed) packetLoop() {
for {
packet := m.endpoint.ReadContext(m.ctx)
if packet == nil {
break
}
bufio.WriteVectorised(m.tun, packet.AsSlices())
packet.DecRef()
}
}

View File

@@ -1,769 +0,0 @@
package tun
import (
"context"
"net"
"net/netip"
"syscall"
"time"
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/control"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/udpnat2"
)
var ErrIncludeAllNetworks = E.New("`system` and `mixed` stack are not available when `includeAllNetworks` is enabled. See https://github.com/SagerNet/sing-tun/issues/25")
type System struct {
ctx context.Context
tun Tun
tunName string
mtu int
handler Handler
logger logger.Logger
inet4Prefixes []netip.Prefix
inet6Prefixes []netip.Prefix
inet4ServerAddress netip.Addr
inet4Address netip.Addr
inet6ServerAddress netip.Addr
inet6Address netip.Addr
broadcastAddr netip.Addr
udpTimeout time.Duration
tcpListener net.Listener
tcpListener6 net.Listener
tcpPort uint16
tcpPort6 uint16
tcpNat *TCPNat
udpNat *udpnat.Service
bindInterface bool
interfaceFinder control.InterfaceFinder
frontHeadroom int
txChecksumOffload bool
}
type Session struct {
SourceAddress netip.Addr
DestinationAddress netip.Addr
SourcePort uint16
DestinationPort uint16
}
func NewSystem(options StackOptions) (Stack, error) {
stack := &System{
ctx: options.Context,
tun: options.Tun,
tunName: options.TunOptions.Name,
mtu: int(options.TunOptions.MTU),
udpTimeout: options.UDPTimeout,
handler: options.Handler,
logger: options.Logger,
inet4Prefixes: options.TunOptions.Inet4Address,
inet6Prefixes: options.TunOptions.Inet6Address,
broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address),
bindInterface: options.ForwarderBindInterface,
interfaceFinder: options.InterfaceFinder,
}
if len(options.TunOptions.Inet4Address) > 0 {
if !HasNextAddress(options.TunOptions.Inet4Address[0], 1) {
return nil, E.New("need one more IPv4 address in first prefix for system stack")
}
stack.inet4ServerAddress = options.TunOptions.Inet4Address[0].Addr()
stack.inet4Address = stack.inet4ServerAddress.Next()
}
if len(options.TunOptions.Inet6Address) > 0 {
if !HasNextAddress(options.TunOptions.Inet6Address[0], 1) {
return nil, E.New("need one more IPv6 address in first prefix for system stack")
}
stack.inet6ServerAddress = options.TunOptions.Inet6Address[0].Addr()
stack.inet6Address = stack.inet6ServerAddress.Next()
}
if !stack.inet4Address.IsValid() && !stack.inet6Address.IsValid() {
return nil, E.New("missing interface address")
}
return stack, nil
}
func (s *System) Close() error {
return common.Close(
s.tcpListener,
s.tcpListener6,
)
}
func (s *System) Start() error {
err := s.start()
if err != nil {
return err
}
go s.tunLoop()
return nil
}
func (s *System) start() error {
_ = fixWindowsFirewall()
var listener net.ListenConfig
if s.bindInterface {
listener.Control = control.Append(listener.Control, func(network, address string, conn syscall.RawConn) error {
bindErr := control.BindToInterface0(s.interfaceFinder, conn, network, address, s.tunName, -1, true)
if bindErr != nil {
s.logger.Warn("bind forwarder to interface: ", bindErr)
}
return nil
})
}
var tcpListener net.Listener
var err error
if s.inet4Address.IsValid() {
for i := 0; i < 3; i++ {
tcpListener, err = listener.Listen(s.ctx, "tcp4", net.JoinHostPort(s.inet4ServerAddress.String(), "0"))
if !retryableListenError(err) {
break
}
time.Sleep(time.Second)
}
if err != nil {
return err
}
s.tcpListener = tcpListener
s.tcpPort = M.SocksaddrFromNet(tcpListener.Addr()).Port
go s.acceptLoop(tcpListener)
}
if s.inet6Address.IsValid() {
for i := 0; i < 3; i++ {
tcpListener, err = listener.Listen(s.ctx, "tcp6", net.JoinHostPort(s.inet6ServerAddress.String(), "0"))
if !retryableListenError(err) {
break
}
time.Sleep(time.Second)
}
if err != nil {
return err
}
s.tcpListener6 = tcpListener
s.tcpPort6 = M.SocksaddrFromNet(tcpListener.Addr()).Port
go s.acceptLoop(tcpListener)
}
s.tcpNat = NewNat(s.ctx, s.udpTimeout)
s.udpNat = udpnat.New(s.handler, s.preparePacketConnection, s.udpTimeout, false)
return nil
}
func (s *System) tunLoop() {
if winTun, isWinTun := s.tun.(WinTun); isWinTun {
s.wintunLoop(winTun)
return
}
if linuxTUN, isLinuxTUN := s.tun.(LinuxTUN); isLinuxTUN {
s.frontHeadroom = linuxTUN.FrontHeadroom()
s.txChecksumOffload = linuxTUN.TXChecksumOffload()
batchSize := linuxTUN.BatchSize()
if batchSize > 1 {
s.batchLoop(linuxTUN, batchSize)
return
}
}
packetBuffer := make([]byte, s.mtu+PacketOffset)
for {
n, err := s.tun.Read(packetBuffer)
if err != nil {
if E.IsClosed(err) {
return
}
s.logger.Error(E.Cause(err, "read packet"))
}
if n < header.IPv4MinimumSize {
continue
}
rawPacket := packetBuffer[:n]
packet := packetBuffer[PacketOffset:n]
if s.processPacket(packet) {
_, err = s.tun.Write(rawPacket)
if err != nil {
s.logger.Trace(E.Cause(err, "write packet"))
}
}
}
}
func (s *System) wintunLoop(winTun WinTun) {
for {
packet, release, err := winTun.ReadPacket()
if err != nil {
return
}
if len(packet) < header.IPv4MinimumSize {
release()
continue
}
if s.processPacket(packet) {
_, err = winTun.Write(packet)
if err != nil {
s.logger.Trace(E.Cause(err, "write packet"))
}
}
release()
}
}
func (s *System) batchLoop(linuxTUN LinuxTUN, batchSize int) {
packetBuffers := make([][]byte, batchSize)
writeBuffers := make([][]byte, batchSize)
packetSizes := make([]int, batchSize)
for i := range packetBuffers {
packetBuffers[i] = make([]byte, s.mtu+s.frontHeadroom)
}
for {
n, err := linuxTUN.BatchRead(packetBuffers, s.frontHeadroom, packetSizes)
if err != nil {
if E.IsClosed(err) {
return
}
s.logger.Error(E.Cause(err, "batch read packet"))
}
if n == 0 {
continue
}
for i := 0; i < n; i++ {
packetSize := packetSizes[i]
if packetSize < header.IPv4MinimumSize {
continue
}
packetBuffer := packetBuffers[i]
packet := packetBuffer[s.frontHeadroom : s.frontHeadroom+packetSize]
if s.processPacket(packet) {
writeBuffers = append(writeBuffers, packetBuffer[:s.frontHeadroom+packetSize])
}
}
if len(writeBuffers) > 0 {
_, err = linuxTUN.BatchWrite(writeBuffers, s.frontHeadroom)
if err != nil {
s.logger.Trace(E.Cause(err, "batch write packet"))
}
writeBuffers = writeBuffers[:0]
}
}
}
func (s *System) processPacket(packet []byte) bool {
var (
writeBack bool
err error
)
switch ipVersion := header.IPVersion(packet); ipVersion {
case header.IPv4Version:
writeBack, err = s.processIPv4(packet)
case header.IPv6Version:
writeBack, err = s.processIPv6(packet)
default:
err = E.New("ip: unknown version: ", ipVersion)
}
if err != nil {
s.logger.Trace(err)
return false
}
return writeBack
}
func (s *System) acceptLoop(listener net.Listener) {
for {
conn, err := listener.Accept()
if err != nil {
return
}
connPort := M.SocksaddrFromNet(conn.RemoteAddr()).Port
session := s.tcpNat.LookupBack(connPort)
if session == nil {
s.logger.Trace(E.New("unknown session with port ", connPort))
continue
}
destination := M.SocksaddrFromNetIP(session.Destination)
if destination.Addr.Is4() {
for _, prefix := range s.inet4Prefixes {
if prefix.Contains(destination.Addr) {
destination.Addr = netip.AddrFrom4([4]byte{127, 0, 0, 1})
break
}
}
} else {
for _, prefix := range s.inet6Prefixes {
if prefix.Contains(destination.Addr) {
destination.Addr = netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1})
break
}
}
}
go s.handler.NewConnectionEx(s.ctx, conn, M.SocksaddrFromNetIP(session.Source), destination, nil)
}
}
func (s *System) processIPv4(ipHdr header.IPv4) (writeBack bool, err error) {
destination := ipHdr.DestinationAddr()
if destination == s.broadcastAddr || !destination.IsGlobalUnicast() {
return
}
writeBack = true
switch ipHdr.TransportProtocol() {
case header.TCPProtocolNumber:
writeBack, err = s.processIPv4TCP(ipHdr, ipHdr.Payload())
case header.UDPProtocolNumber:
writeBack = false
err = s.processIPv4UDP(ipHdr, ipHdr.Payload())
case header.ICMPv4ProtocolNumber:
err = s.processIPv4ICMP(ipHdr, ipHdr.Payload())
}
return
}
func (s *System) processIPv6(ipHdr header.IPv6) (writeBack bool, err error) {
if !ipHdr.DestinationAddr().IsGlobalUnicast() {
return
}
writeBack = true
switch ipHdr.TransportProtocol() {
case header.TCPProtocolNumber:
writeBack, err = s.processIPv6TCP(ipHdr, ipHdr.Payload())
case header.UDPProtocolNumber:
err = s.processIPv6UDP(ipHdr, ipHdr.Payload())
case header.ICMPv6ProtocolNumber:
err = s.processIPv6ICMP(ipHdr, ipHdr.Payload())
}
return
}
func (s *System) processIPv4TCP(ipHdr header.IPv4, tcpHdr header.TCP) (bool, error) {
source := netip.AddrPortFrom(ipHdr.SourceAddr(), tcpHdr.SourcePort())
destination := netip.AddrPortFrom(ipHdr.DestinationAddr(), tcpHdr.DestinationPort())
if !destination.Addr().IsGlobalUnicast() {
return false, nil
} else if source.Addr() == s.inet4ServerAddress && source.Port() == s.tcpPort {
session := s.tcpNat.LookupBack(destination.Port())
if session == nil {
return false, E.New("ipv4: tcp: session not found: ", destination.Port())
}
ipHdr.SetSourceAddr(session.Destination.Addr())
tcpHdr.SetSourcePort(session.Destination.Port())
ipHdr.SetDestinationAddr(session.Source.Addr())
tcpHdr.SetDestinationPort(session.Source.Port())
} else {
natPort, err := s.tcpNat.Lookup(source, destination, s.handler)
if err != nil {
if err == ErrDrop {
return false, nil
} else {
return false, s.resetIPv4TCP(ipHdr, tcpHdr)
}
}
ipHdr.SetSourceAddr(s.inet4Address)
tcpHdr.SetSourcePort(natPort)
ipHdr.SetDestinationAddr(s.inet4ServerAddress)
tcpHdr.SetDestinationPort(s.tcpPort)
}
if !s.txChecksumOffload {
tcpHdr.SetChecksum(0)
tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum(
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
)))
} else {
tcpHdr.SetChecksum(0)
}
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
return true, nil
}
func (s *System) resetIPv4TCP(origIPHdr header.IPv4, origTCPHdr header.TCP) error {
frontHeadroom := s.frontHeadroom + PacketOffset
newPacket := buf.NewSize(frontHeadroom + header.IPv4MinimumSize + header.TCPMinimumSize)
defer newPacket.Release()
newPacket.Resize(frontHeadroom, header.IPv4MinimumSize+header.TCPMinimumSize)
ipHdr := header.IPv4(newPacket.Bytes())
ipHdr.Encode(&header.IPv4Fields{
TotalLength: uint16(newPacket.Len()),
Protocol: uint8(header.TCPProtocolNumber),
SrcAddr: origIPHdr.DestinationAddr(),
DstAddr: origIPHdr.SourceAddr(),
})
tcpHdr := header.TCP(ipHdr.Payload())
fields := header.TCPFields{
SrcPort: origTCPHdr.DestinationPort(),
DstPort: origTCPHdr.SourcePort(),
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagRst,
}
if origTCPHdr.Flags()&header.TCPFlagAck != 0 {
fields.SeqNum = origTCPHdr.AckNumber()
} else {
fields.Flags |= header.TCPFlagAck
ackNum := origTCPHdr.SequenceNumber() + uint32(len(origTCPHdr.Payload()))
if origTCPHdr.Flags()&header.TCPFlagSyn != 0 {
ackNum++
}
if origTCPHdr.Flags()&header.TCPFlagFin != 0 {
ackNum++
}
fields.AckNum = ackNum
}
tcpHdr.Encode(&fields)
if !s.txChecksumOffload {
tcpHdr.SetChecksum(^tcpHdr.CalculateChecksum(header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), header.TCPMinimumSize)))
}
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
if PacketOffset > 0 {
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version)
} else {
newPacket.Advance(-s.frontHeadroom)
}
return common.Error(s.tun.Write(newPacket.Bytes()))
}
func (s *System) processIPv6TCP(ipHdr header.IPv6, tcpHdr header.TCP) (bool, error) {
source := netip.AddrPortFrom(ipHdr.SourceAddr(), tcpHdr.SourcePort())
destination := netip.AddrPortFrom(ipHdr.DestinationAddr(), tcpHdr.DestinationPort())
if !destination.Addr().IsGlobalUnicast() {
return false, nil
} else if source.Addr() == s.inet6ServerAddress && source.Port() == s.tcpPort6 {
session := s.tcpNat.LookupBack(destination.Port())
if session == nil {
return false, E.New("ipv6: tcp: session not found: ", destination.Port())
}
ipHdr.SetSourceAddr(session.Destination.Addr())
tcpHdr.SetSourcePort(session.Destination.Port())
ipHdr.SetDestinationAddr(session.Source.Addr())
tcpHdr.SetDestinationPort(session.Source.Port())
} else {
natPort, err := s.tcpNat.Lookup(source, destination, s.handler)
if err != nil {
if err == ErrDrop {
return false, nil
} else {
return false, s.resetIPv6TCP(ipHdr, tcpHdr)
}
}
ipHdr.SetSourceAddr(s.inet6Address)
tcpHdr.SetSourcePort(natPort)
ipHdr.SetDestinationAddr(s.inet6ServerAddress)
tcpHdr.SetDestinationPort(s.tcpPort6)
}
if !s.txChecksumOffload {
tcpHdr.SetChecksum(0)
tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum(
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
)))
} else {
tcpHdr.SetChecksum(0)
}
return true, nil
}
func (s *System) resetIPv6TCP(origIPHdr header.IPv6, origTCPHdr header.TCP) error {
frontHeadroom := s.frontHeadroom + PacketOffset
newPacket := buf.NewSize(frontHeadroom + header.IPv6MinimumSize + header.TCPMinimumSize)
defer newPacket.Release()
newPacket.Resize(frontHeadroom, header.IPv6MinimumSize+header.TCPMinimumSize)
ipHdr := header.IPv6(newPacket.Bytes())
ipHdr.Encode(&header.IPv6Fields{
PayloadLength: uint16(header.TCPMinimumSize),
TransportProtocol: header.TCPProtocolNumber,
SrcAddr: origIPHdr.DestinationAddr(),
DstAddr: origIPHdr.SourceAddr(),
})
tcpHdr := header.TCP(ipHdr.Payload())
fields := header.TCPFields{
SrcPort: origTCPHdr.DestinationPort(),
DstPort: origTCPHdr.SourcePort(),
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagRst,
}
if origTCPHdr.Flags()&header.TCPFlagAck != 0 {
fields.SeqNum = origTCPHdr.AckNumber()
} else {
fields.Flags |= header.TCPFlagAck
ackNum := origTCPHdr.SequenceNumber() + uint32(len(origTCPHdr.Payload()))
if origTCPHdr.Flags()&header.TCPFlagSyn != 0 {
ackNum++
}
if origTCPHdr.Flags()&header.TCPFlagFin != 0 {
ackNum++
}
fields.AckNum = ackNum
}
tcpHdr.Encode(&fields)
if !s.txChecksumOffload {
tcpHdr.SetChecksum(^tcpHdr.CalculateChecksum(header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), header.TCPMinimumSize)))
}
if PacketOffset > 0 {
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv6Version)
} else {
newPacket.Advance(-s.frontHeadroom)
}
return common.Error(s.tun.Write(newPacket.Bytes()))
}
func (s *System) processIPv4UDP(ipHdr header.IPv4, udpHdr header.UDP) error {
if ipHdr.Flags()&header.IPv4FlagMoreFragments != 0 {
return E.New("ipv4: fragment dropped")
}
if ipHdr.FragmentOffset() != 0 {
return E.New("ipv4: udp: fragment dropped")
}
source := M.SocksaddrFrom(ipHdr.SourceAddr(), udpHdr.SourcePort())
destination := M.SocksaddrFrom(ipHdr.DestinationAddr(), udpHdr.DestinationPort())
if !destination.Addr.IsGlobalUnicast() {
return nil
}
s.udpNat.NewPacket([][]byte{udpHdr.Payload()}, source, destination, ipHdr)
return nil
}
func (s *System) processIPv6UDP(ipHdr header.IPv6, udpHdr header.UDP) error {
source := M.SocksaddrFrom(ipHdr.SourceAddr(), udpHdr.SourcePort())
destination := M.SocksaddrFrom(ipHdr.DestinationAddr(), udpHdr.DestinationPort())
if !destination.Addr.IsGlobalUnicast() {
return nil
}
s.udpNat.NewPacket([][]byte{udpHdr.Payload()}, source, destination, ipHdr)
return nil
}
func (s *System) preparePacketConnection(source M.Socksaddr, destination M.Socksaddr, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) {
pErr := s.handler.PrepareConnection(N.NetworkUDP, source, destination)
if pErr != nil {
if pErr != ErrDrop {
if source.IsIPv4() {
ipHdr := userData.(header.IPv4)
s.rejectIPv4WithICMP(ipHdr, header.ICMPv4PortUnreachable)
} else {
ipHdr := userData.(header.IPv6)
s.rejectIPv6WithICMP(ipHdr, header.ICMPv6PortUnreachable)
}
}
return false, nil, nil, nil
}
var writer N.PacketWriter
if source.IsIPv4() {
packet := userData.(header.IPv4)
headerLen := packet.HeaderLength() + header.UDPMinimumSize
headerCopy := make([]byte, headerLen)
copy(headerCopy, packet[:headerLen])
writer = &systemUDPPacketWriter4{
s.tun,
s.frontHeadroom + PacketOffset,
headerCopy,
source.AddrPort(),
s.txChecksumOffload,
}
} else {
packet := userData.(header.IPv6)
headerLen := len(packet) - int(packet.PayloadLength()) + header.UDPMinimumSize
headerCopy := make([]byte, headerLen)
copy(headerCopy, packet[:headerLen])
writer = &systemUDPPacketWriter6{
s.tun,
s.frontHeadroom + PacketOffset,
headerCopy,
source.AddrPort(),
s.txChecksumOffload,
}
}
return true, s.ctx, writer, nil
}
func (s *System) processIPv4ICMP(ipHdr header.IPv4, icmpHdr header.ICMPv4) error {
if icmpHdr.Type() != header.ICMPv4Echo || icmpHdr.Code() != 0 {
return nil
}
icmpHdr.SetType(header.ICMPv4EchoReply)
sourceAddress := ipHdr.SourceAddr()
ipHdr.SetSourceAddr(ipHdr.DestinationAddr())
ipHdr.SetDestinationAddr(sourceAddress)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(icmpHdr.Payload(), 0)))
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
return nil
}
func (s *System) rejectIPv4WithICMP(ipHdr header.IPv4, code header.ICMPv4Code) error {
frontHeadroom := s.frontHeadroom + PacketOffset
mtu := s.mtu
const maxIPData = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize
if mtu > maxIPData {
mtu = maxIPData
}
available := mtu - header.ICMPv4MinimumSize
if available < len(ipHdr)+header.ICMPv4MinimumErrorPayloadSize {
return nil
}
payload := ipHdr
if len(payload) > available {
payload = payload[:available]
}
newPacket := buf.NewSize(frontHeadroom + header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(payload))
defer newPacket.Release()
newPacket.Resize(frontHeadroom, header.IPv4MinimumSize+header.ICMPv4MinimumSize+len(payload))
newIPHdr := header.IPv4(newPacket.Bytes())
newIPHdr.Encode(&header.IPv4Fields{
TotalLength: uint16(newPacket.Len()),
Protocol: uint8(header.ICMPv4ProtocolNumber),
SrcAddr: ipHdr.DestinationAddr(),
DstAddr: ipHdr.SourceAddr(),
})
newIPHdr.SetChecksum(^newIPHdr.CalculateChecksum())
icmpHdr := header.ICMPv4(newIPHdr.Payload())
icmpHdr.SetType(header.ICMPv4DstUnreachable)
icmpHdr.SetCode(code)
icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr[:header.ICMPv4MinimumSize], checksum.Checksum(payload, 0)))
copy(icmpHdr.Payload(), payload)
if PacketOffset > 0 {
newPacket.ExtendHeader(PacketOffset)[3] = syscall.AF_INET
} else {
newPacket.Advance(-s.frontHeadroom)
}
return common.Error(s.tun.Write(newPacket.Bytes()))
}
func (s *System) processIPv6ICMP(ipHdr header.IPv6, icmpHdr header.ICMPv6) error {
if icmpHdr.Type() != header.ICMPv6EchoRequest || icmpHdr.Code() != 0 {
return nil
}
icmpHdr.SetType(header.ICMPv6EchoReply)
sourceAddress := ipHdr.SourceAddr()
ipHdr.SetSourceAddr(ipHdr.DestinationAddr())
ipHdr.SetDestinationAddr(sourceAddress)
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr,
Src: ipHdr.SourceAddress(),
Dst: ipHdr.DestinationAddress(),
}))
return nil
}
func (s *System) rejectIPv6WithICMP(ipHdr header.IPv6, code header.ICMPv6Code) error {
frontHeadroom := s.frontHeadroom + PacketOffset
mtu := s.mtu
const maxIPv6Data = header.IPv6MinimumMTU - header.IPv6FixedHeaderSize
if mtu > maxIPv6Data {
mtu = maxIPv6Data
}
available := mtu - header.ICMPv6ErrorHeaderSize
if available < header.IPv6MinimumSize {
return nil
}
payload := ipHdr
if len(payload) > available {
payload = payload[:available]
}
newPacket := buf.NewSize(frontHeadroom + header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + len(payload))
defer newPacket.Release()
newPacket.Resize(frontHeadroom, header.IPv6MinimumSize+header.ICMPv6DstUnreachableMinimumSize+len(payload))
newIPHdr := header.IPv6(newPacket.Bytes())
newIPHdr.Encode(&header.IPv6Fields{
PayloadLength: uint16(header.ICMPv6DstUnreachableMinimumSize + len(payload)),
TransportProtocol: header.ICMPv6ProtocolNumber,
SrcAddr: ipHdr.DestinationAddr(),
DstAddr: ipHdr.SourceAddr(),
})
icmpHdr := header.ICMPv6(newIPHdr.Payload())
icmpHdr.SetType(header.ICMPv6DstUnreachable)
icmpHdr.SetCode(code)
icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
Header: icmpHdr[:header.ICMPv6DstUnreachableMinimumSize],
Src: newIPHdr.SourceAddress(),
Dst: newIPHdr.DestinationAddress(),
PayloadCsum: checksum.Checksum(payload, 0),
PayloadLen: len(payload),
}))
copy(icmpHdr.Payload(), payload)
if PacketOffset > 0 {
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv6Version)
} else {
newPacket.Advance(-s.frontHeadroom)
}
return common.Error(s.tun.Write(newPacket.Bytes()))
}
type systemUDPPacketWriter4 struct {
tun Tun
frontHeadroom int
header []byte
source netip.AddrPort
txChecksumOffload bool
}
func (w *systemUDPPacketWriter4) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
newPacket := buf.NewSize(w.frontHeadroom + len(w.header) + buffer.Len())
defer newPacket.Release()
newPacket.Resize(w.frontHeadroom, 0)
newPacket.Write(w.header)
newPacket.Write(buffer.Bytes())
ipHdr := header.IPv4(newPacket.Bytes())
ipHdr.SetTotalLength(uint16(newPacket.Len()))
ipHdr.SetDestinationAddress(ipHdr.SourceAddress())
ipHdr.SetSourceAddr(destination.Addr)
udpHdr := header.UDP(ipHdr.Payload())
udpHdr.SetDestinationPort(udpHdr.SourcePort())
udpHdr.SetSourcePort(destination.Port)
udpHdr.SetLength(uint16(buffer.Len() + header.UDPMinimumSize))
if !w.txChecksumOffload {
udpHdr.SetChecksum(0)
udpHdr.SetChecksum(^checksum.Checksum(udpHdr.Payload(), udpHdr.CalculateChecksum(
header.PseudoHeaderChecksum(header.UDPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
)))
} else {
udpHdr.SetChecksum(0)
}
ipHdr.SetChecksum(0)
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
if PacketOffset > 0 {
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv4Version)
} else {
newPacket.Advance(-w.frontHeadroom)
}
return common.Error(w.tun.Write(newPacket.Bytes()))
}
type systemUDPPacketWriter6 struct {
tun Tun
frontHeadroom int
header []byte
source netip.AddrPort
txChecksumOffload bool
}
func (w *systemUDPPacketWriter6) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
newPacket := buf.NewSize(w.frontHeadroom + len(w.header) + buffer.Len())
defer newPacket.Release()
newPacket.Resize(w.frontHeadroom, 0)
newPacket.Write(w.header)
newPacket.Write(buffer.Bytes())
ipHdr := header.IPv6(newPacket.Bytes())
udpLen := uint16(header.UDPMinimumSize + buffer.Len())
ipHdr.SetPayloadLength(udpLen)
ipHdr.SetDestinationAddress(ipHdr.SourceAddress())
ipHdr.SetSourceAddr(destination.Addr)
udpHdr := header.UDP(ipHdr.Payload())
udpHdr.SetDestinationPort(udpHdr.SourcePort())
udpHdr.SetSourcePort(destination.Port)
udpHdr.SetLength(udpLen)
if !w.txChecksumOffload {
udpHdr.SetChecksum(0)
udpHdr.SetChecksum(^checksum.Checksum(udpHdr.Payload(), udpHdr.CalculateChecksum(
header.PseudoHeaderChecksum(header.UDPProtocolNumber, ipHdr.SourceAddressSlice(), ipHdr.DestinationAddressSlice(), ipHdr.PayloadLength()),
)))
} else {
udpHdr.SetChecksum(0)
}
if PacketOffset > 0 {
PacketFillHeader(newPacket.ExtendHeader(PacketOffset), header.IPv6Version)
} else {
newPacket.Advance(-w.frontHeadroom)
}
return common.Error(w.tun.Write(newPacket.Bytes()))
}

View File

@@ -1,17 +0,0 @@
//go:build !windows
package tun
import (
"errors"
"golang.org/x/sys/unix"
)
func fixWindowsFirewall() error {
return nil
}
func retryableListenError(err error) bool {
return errors.Is(err, unix.EADDRNOTAVAIL)
}

View File

@@ -1,34 +0,0 @@
package tun
import (
"net/netip"
"syscall"
"github.com/sagernet/sing-tun/internal/gtcpip/header"
)
func PacketIPVersion(packet []byte) int {
return header.IPVersion(packet)
}
func PacketFillHeader(packet []byte, ipVersion int) {
if PacketOffset > 0 {
switch ipVersion {
case header.IPv4Version:
packet[3] = syscall.AF_INET
case header.IPv6Version:
packet[3] = syscall.AF_INET6
}
}
}
func PacketDestination(packet []byte) netip.Addr {
switch ipVersion := header.IPVersion(packet); ipVersion {
case header.IPv4Version:
return header.IPv4(packet).DestinationAddr()
case header.IPv6Version:
return header.IPv6(packet).DestinationAddr()
default:
return netip.Addr{}
}
}

Some files were not shown because too many files have changed in this diff Show More