Add loopback address support

This commit is contained in:
世界
2025-06-09 18:51:17 +08:00
parent f57754918d
commit 5e343c4b66
7 changed files with 524 additions and 183 deletions

View File

@@ -46,6 +46,11 @@ func (r *autoRedirect) setupNFTables() error {
return err return err
} }
err = r.nftablesCreateLoopbackAddressSets(nft, table)
if err != nil {
return err
}
skipOutput := len(r.tunOptions.IncludeInterface) > 0 && !common.Contains(r.tunOptions.IncludeInterface, "lo") || common.Contains(r.tunOptions.ExcludeInterface, "lo") skipOutput := len(r.tunOptions.IncludeInterface) > 0 && !common.Contains(r.tunOptions.IncludeInterface, "lo") || common.Contains(r.tunOptions.ExcludeInterface, "lo")
if !skipOutput { if !skipOutput {
chainOutput := nft.AddChain(&nftables.Chain{ chainOutput := nft.AddChain(&nftables.Chain{
@@ -61,8 +66,23 @@ func (r *autoRedirect) setupNFTables() error {
return err return err
} }
r.nftablesCreateUnreachable(nft, table, chainOutput) r.nftablesCreateUnreachable(nft, table, chainOutput)
r.nftablesCreateRedirect(nft, table, chainOutput) err = r.nftablesCreateRedirect(nft, table, chainOutput)
if err != nil {
return err
}
if len(r.tunOptions.Inet4LoopbackAddress) > 0 || len(r.tunOptions.Inet6LoopbackAddress) > 0 {
chainOutputRoute := nft.AddChain(&nftables.Chain{
Name: "output_route",
Table: table,
Hooknum: nftables.ChainHookOutput,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeRoute,
})
err = r.nftablesCreateLoopbackReroute(nft, table, chainOutputRoute)
if err != nil {
return err
}
}
chainOutputUDP := nft.AddChain(&nftables.Chain{ chainOutputUDP := nft.AddChain(&nftables.Chain{
Name: "output_udp_icmp", Name: "output_udp_icmp",
Table: table, Table: table,
@@ -77,7 +97,7 @@ func (r *autoRedirect) setupNFTables() error {
r.nftablesCreateUnreachable(nft, table, chainOutputUDP) r.nftablesCreateUnreachable(nft, table, chainOutputUDP)
r.nftablesCreateMark(nft, table, chainOutputUDP) r.nftablesCreateMark(nft, table, chainOutputUDP)
} else { } else {
r.nftablesCreateRedirect(nft, table, chainOutput, &expr.Meta{ err = r.nftablesCreateRedirect(nft, table, chainOutput, &expr.Meta{
Key: expr.MetaKeyOIFNAME, Key: expr.MetaKeyOIFNAME,
Register: 1, Register: 1,
}, &expr.Cmp{ }, &expr.Cmp{
@@ -85,6 +105,9 @@ func (r *autoRedirect) setupNFTables() error {
Register: 1, Register: 1,
Data: nftablesIfname(r.tunOptions.Name), Data: nftablesIfname(r.tunOptions.Name),
}) })
if err != nil {
return err
}
} }
} }
@@ -100,12 +123,25 @@ func (r *autoRedirect) setupNFTables() error {
return err return err
} }
r.nftablesCreateUnreachable(nft, table, chainPreRouting) r.nftablesCreateUnreachable(nft, table, chainPreRouting)
r.nftablesCreateRedirect(nft, table, chainPreRouting) err = r.nftablesCreateRedirect(nft, table, chainPreRouting)
if err != nil {
return err
}
if r.tunOptions.AutoRedirectMarkMode { if r.tunOptions.AutoRedirectMarkMode {
r.nftablesCreateMark(nft, table, chainPreRouting) r.nftablesCreateMark(nft, table, chainPreRouting)
} if len(r.tunOptions.Inet4LoopbackAddress) > 0 || len(r.tunOptions.Inet6LoopbackAddress) > 0 {
chainPreRoutingFilter := nft.AddChain(&nftables.Chain{
if r.tunOptions.AutoRedirectMarkMode { Name: "prerouting_filter",
Table: table,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityRef(*nftables.ChainPriorityNATDest + 1),
Type: nftables.ChainTypeFilter,
})
err = r.nftablesCreateLoopbackReroute(nft, table, chainPreRoutingFilter)
if err != nil {
return err
}
}
chainPreRoutingUDP := nft.AddChain(&nftables.Chain{ chainPreRoutingUDP := nft.AddChain(&nftables.Chain{
Name: "prerouting_udp", Name: "prerouting_udp",
Table: table, Table: table,

View File

@@ -7,6 +7,7 @@ import (
"github.com/sagernet/nftables" "github.com/sagernet/nftables"
"github.com/sagernet/nftables/expr" "github.com/sagernet/nftables/expr"
"github.com/sagernet/sing/common"
"go4.org/netipx" "go4.org/netipx"
) )
@@ -21,6 +22,20 @@ func nftablesCreateExcludeDestinationIPSet(
nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain, nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain,
id uint32, name string, family nftables.TableFamily, invert bool, id uint32, name string, family nftables.TableFamily, invert bool,
) { ) {
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: append(
nftablesCreateDestinationIPSetExprs(id, name, family, invert),
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictReturn,
},
),
})
}
func nftablesCreateDestinationIPSetExprs(id uint32, name string, family nftables.TableFamily, invert bool) []expr.Any {
exprs := []expr.Any{ exprs := []expr.Any{
&expr.Meta{ &expr.Meta{
Key: expr.MetaKeyNFPROTO, Key: expr.MetaKeyNFPROTO,
@@ -53,22 +68,63 @@ func nftablesCreateExcludeDestinationIPSet(
}, },
) )
} }
exprs = append(exprs, exprs = append(exprs, &expr.Lookup{
&expr.Lookup{ SourceRegister: 1,
SourceRegister: 1, SetID: id,
SetID: id, SetName: name,
SetName: name, Invert: invert,
Invert: invert,
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictReturn,
})
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: exprs,
}) })
return exprs
}
func nftablesCreateIPConst(
nft *nftables.Conn, table *nftables.Table, id uint32, name string, family nftables.TableFamily, addressList []netip.Addr,
) (*nftables.Set, error) {
var keyType nftables.SetDatatype
if family == nftables.TableFamilyIPv4 {
keyType = nftables.TypeIPAddr
} else {
keyType = nftables.TypeIP6Addr
}
mySet := &nftables.Set{
Table: table,
ID: id,
Name: name,
KeyType: keyType,
Constant: true,
}
if id == 0 {
mySet.Anonymous = true
}
setElements := common.Map(addressList, func(addr netip.Addr) nftables.SetElement { return nftables.SetElement{Key: addr.AsSlice()} })
if id == 0 {
err := nft.AddSet(mySet, setElements)
if err != nil {
return nil, err
}
return mySet, nil
} else {
err := nft.AddSet(mySet, nil)
if err != nil {
return nil, err
}
}
for len(setElements) > 0 {
toAdd := setElements
if len(toAdd) > 1000 {
toAdd = toAdd[:1000]
}
setElements = setElements[len(toAdd):]
err := nft.SetAddElements(mySet, toAdd)
if err != nil {
return nil, err
}
err = nft.Flush()
if err != nil {
return nil, err
}
}
return mySet, nil
} }
func nftablesCreateIPSet( func nftablesCreateIPSet(

View File

@@ -117,8 +117,61 @@ func (r *autoRedirect) nftablesCreateLocalAddressSets(
return nil return nil
} }
func (r *autoRedirect) nftablesCreateLoopbackAddressSets(
nft *nftables.Conn, table *nftables.Table,
) error {
if r.enableIPv4 && len(r.tunOptions.Inet4LoopbackAddress) > 0 {
_, err := nftablesCreateIPConst(nft, table, 7, "inet4_local_redirect_address_set", nftables.TableFamilyIPv4, r.tunOptions.Inet4LoopbackAddress)
if err != nil {
return err
}
}
if r.enableIPv6 && len(r.tunOptions.Inet6LoopbackAddress) > 0 {
_, err := nftablesCreateIPConst(nft, table, 8, "inet6_local_redirect_address_set", nftables.TableFamilyIPv6, r.tunOptions.Inet6LoopbackAddress)
if err != nil {
return err
}
}
return nil
}
func (r *autoRedirect) nftablesCreateExcludeRules(nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain) error { func (r *autoRedirect) nftablesCreateExcludeRules(nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain) error {
if r.tunOptions.AutoRedirectMarkMode && chain.Hooknum == nftables.ChainHookOutput { if r.tunOptions.AutoRedirectMarkMode && chain.Hooknum == nftables.ChainHookOutput {
if chain.Type == nftables.ChainTypeRoute {
ipProto := &nftables.Set{
Table: table,
Anonymous: true,
Constant: true,
KeyType: nftables.TypeInetProto,
}
err := nft.AddSet(ipProto, []nftables.SetElement{
{Key: []byte{unix.IPPROTO_UDP}},
{Key: []byte{unix.IPPROTO_ICMP}},
{Key: []byte{unix.IPPROTO_ICMPV6}},
})
if err != nil {
return err
}
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyL4PROTO,
Register: 1,
},
&expr.Lookup{
SourceRegister: 1,
SetID: ipProto.ID,
SetName: ipProto.Name,
Invert: true,
},
&expr.Verdict{
Kind: expr.VerdictReturn,
},
},
})
}
nft.AddRule(&nftables.Rule{ nft.AddRule(&nftables.Rule{
Table: table, Table: table,
Chain: chain, Chain: chain,
@@ -161,6 +214,25 @@ func (r *autoRedirect) nftablesCreateExcludeRules(nft *nftables.Conn, table *nft
} }
} }
if chain.Hooknum == nftables.ChainHookPrerouting { if chain.Hooknum == nftables.ChainHookPrerouting {
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyIIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: nftablesIfname(r.tunOptions.Name),
},
&expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictReturn,
},
},
})
if len(r.tunOptions.IncludeInterface) > 0 { if len(r.tunOptions.IncludeInterface) > 0 {
if len(r.tunOptions.IncludeInterface) > 1 { if len(r.tunOptions.IncludeInterface) > 1 {
includeInterface := &nftables.Set{ includeInterface := &nftables.Set{
@@ -436,44 +508,6 @@ func (r *autoRedirect) nftablesCreateExcludeRules(nft *nftables.Conn, table *nft
} }
} }
if r.tunOptions.AutoRedirectMarkMode &&
((chain.Hooknum == nftables.ChainHookOutput && chain.Type == nftables.ChainTypeRoute) ||
(chain.Hooknum == nftables.ChainHookPrerouting && chain.Type == nftables.ChainTypeFilter)) {
ipProto := &nftables.Set{
Table: table,
Anonymous: true,
Constant: true,
KeyType: nftables.TypeInetProto,
}
err := nft.AddSet(ipProto, []nftables.SetElement{
{Key: []byte{unix.IPPROTO_UDP}},
{Key: []byte{unix.IPPROTO_ICMP}},
{Key: []byte{unix.IPPROTO_ICMPV6}},
})
if err != nil {
return err
}
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{
Key: expr.MetaKeyL4PROTO,
Register: 1,
},
&expr.Lookup{
SourceRegister: 1,
SetID: ipProto.ID,
SetName: ipProto.Name,
Invert: true,
},
&expr.Verdict{
Kind: expr.VerdictReturn,
},
},
})
}
if r.enableIPv4 { if r.enableIPv4 {
nftablesCreateExcludeDestinationIPSet(nft, table, chain, 5, "inet4_local_address_set", nftables.TableFamilyIPv4, false) nftablesCreateExcludeDestinationIPSet(nft, table, chain, 5, "inet4_local_address_set", nftables.TableFamilyIPv4, false)
} }
@@ -527,6 +561,9 @@ func (r *autoRedirect) nftablesCreateMark(nft *nftables.Conn, table *nftables.Ta
SourceRegister: true, SourceRegister: true,
}, },
&expr.Counter{}, &expr.Counter{},
&expr.Verdict{
Kind: expr.VerdictReturn,
},
}, },
}) })
} }
@@ -534,57 +571,193 @@ func (r *autoRedirect) nftablesCreateMark(nft *nftables.Conn, table *nftables.Ta
func (r *autoRedirect) nftablesCreateRedirect( func (r *autoRedirect) nftablesCreateRedirect(
nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain, nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain,
exprs ...expr.Any, exprs ...expr.Any,
) { ) error {
if r.enableIPv4 && !r.enableIPv6 { exprsRedirect := []expr.Any{
exprs = append(exprs, &expr.Meta{
&expr.Meta{ Key: expr.MetaKeyL4PROTO,
Key: expr.MetaKeyNFPROTO, Register: 1,
Register: 1, },
}, &expr.Cmp{
&expr.Cmp{ Op: expr.CmpOpEq,
Op: expr.CmpOpEq, Register: 1,
Register: 1, Data: []byte{unix.IPPROTO_TCP},
Data: []byte{uint8(nftables.TableFamilyIPv4)}, },
}) &expr.Counter{},
} else if !r.enableIPv4 && r.enableIPv6 { &expr.Immediate{
exprs = append(exprs, Register: 1,
&expr.Meta{ Data: binaryutil.BigEndian.PutUint16(r.redirectPort()),
Key: expr.MetaKeyNFPROTO, },
Register: 1, &expr.Redir{
}, RegisterProtoMin: 1,
&expr.Cmp{ Flags: unix.NF_NAT_RANGE_PROTO_SPECIFIED,
Op: expr.CmpOpEq, },
Register: 1, &expr.Verdict{
Data: []byte{uint8(nftables.TableFamilyIPv6)}, Kind: expr.VerdictReturn,
}) },
} }
nft.AddRule(&nftables.Rule{ if len(r.tunOptions.Inet4LoopbackAddress) == 0 && len(r.tunOptions.Inet6LoopbackAddress) == 0 {
Table: table, if r.enableIPv4 && !r.enableIPv6 {
Chain: chain, exprs = append(exprs,
Exprs: append(exprs, &expr.Meta{
&expr.Meta{ Key: expr.MetaKeyNFPROTO,
Key: expr.MetaKeyL4PROTO, Register: 1,
Register: 1, },
}, &expr.Cmp{
&expr.Cmp{ Op: expr.CmpOpEq,
Op: expr.CmpOpEq, Register: 1,
Register: 1, Data: []byte{uint8(nftables.TableFamilyIPv4)},
Data: []byte{unix.IPPROTO_TCP}, })
}, } else if !r.enableIPv4 && r.enableIPv6 {
&expr.Counter{}, exprs = append(exprs,
&expr.Meta{
Key: expr.MetaKeyNFPROTO,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{uint8(nftables.TableFamilyIPv6)},
})
}
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: append(exprs, exprsRedirect...),
})
} else {
if r.enableIPv4 {
exprs4 := exprs
if len(r.tunOptions.Inet4LoopbackAddress) > 0 {
exprs4 = append(exprs4, nftablesCreateDestinationIPSetExprs(7, "inet4_local_redirect_address_set", nftables.TableFamilyIPv4, true)...)
} else {
exprs4 = append(exprs4, &expr.Meta{
Key: expr.MetaKeyNFPROTO,
Register: 1,
}, &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{uint8(nftables.TableFamilyIPv4)},
})
}
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: append(exprs4, exprsRedirect...),
})
}
if r.enableIPv6 {
exprs6 := exprs
if len(r.tunOptions.Inet6LoopbackAddress) > 0 {
exprs6 = append(exprs6, nftablesCreateDestinationIPSetExprs(8, "inet6_local_redirect_address_set", nftables.TableFamilyIPv6, true)...)
} else {
exprs6 = append(exprs6, &expr.Meta{
Key: expr.MetaKeyNFPROTO,
Register: 1,
}, &expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{uint8(nftables.TableFamilyIPv6)},
})
}
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: append(exprs6, exprsRedirect...),
})
}
}
return nil
}
func (r *autoRedirect) nftablesCreateLoopbackReroute(
nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain,
) error {
exprs := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyL4PROTO,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{unix.IPPROTO_TCP},
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectInputMark),
},
}
var exprs4 []expr.Any
if r.enableIPv4 && len(r.tunOptions.Inet4LoopbackAddress) > 0 {
exprs4 = append(exprs, nftablesCreateDestinationIPSetExprs(7, "inet4_local_redirect_address_set", nftables.TableFamilyIPv4, false)...)
}
var exprs6 []expr.Any
if r.enableIPv6 && len(r.tunOptions.Inet6LoopbackAddress) > 0 {
exprs6 = append(exprs, nftablesCreateDestinationIPSetExprs(8, "inet6_local_redirect_address_set", nftables.TableFamilyIPv6, false)...)
}
var exprsCreateMark []expr.Any
if chain.Hooknum == nftables.ChainHookPrerouting {
exprsCreateMark = []expr.Any{
&expr.Immediate{ &expr.Immediate{
Register: 1, Register: 1,
Data: binaryutil.BigEndian.PutUint16(r.redirectPort()), Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectInputMark),
}, },
&expr.Redir{ &expr.Meta{
RegisterProtoMin: 1, Key: expr.MetaKeyMARK,
Flags: unix.NF_NAT_RANGE_PROTO_SPECIFIED, Register: 1,
SourceRegister: true,
}, },
&expr.Verdict{ &expr.Counter{},
Kind: expr.VerdictReturn, }
} else {
exprsCreateMark = []expr.Any{
&expr.Immediate{
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectInputMark),
}, },
), &expr.Meta{
}) Key: expr.MetaKeyMARK,
Register: 1,
SourceRegister: true,
},
&expr.Meta{
Key: expr.MetaKeyMARK,
Register: 1,
},
&expr.Ct{
Key: expr.CtKeyMARK,
Register: 1,
SourceRegister: true,
},
&expr.Counter{},
}
}
if len(exprs4) > 0 {
exprs4 = append(exprs4, exprsCreateMark...)
}
if len(exprs6) > 0 {
exprs6 = append(exprs6, exprsCreateMark...)
}
if len(exprs4) > 0 {
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: exprs4,
})
}
if len(exprs6) > 0 {
nft.AddRule(&nftables.Rule{
Table: table,
Chain: chain,
Exprs: exprs6,
})
}
return nil
} }
func (r *autoRedirect) nftablesCreateDNSHijackRulesForFamily( func (r *autoRedirect) nftablesCreateDNSHijackRulesForFamily(

View File

@@ -26,14 +26,16 @@ const WithGVisor = true
const DefaultNIC tcpip.NICID = 1 const DefaultNIC tcpip.NICID = 1
type GVisor struct { type GVisor struct {
ctx context.Context ctx context.Context
tun GVisorTun tun GVisorTun
udpTimeout time.Duration inet4LoopbackAddress []netip.Addr
broadcastAddr netip.Addr inet6LoopbackAddress []netip.Addr
handler Handler udpTimeout time.Duration
logger logger.Logger broadcastAddr netip.Addr
stack *stack.Stack handler Handler
endpoint stack.LinkEndpoint logger logger.Logger
stack *stack.Stack
endpoint stack.LinkEndpoint
} }
type GVisorTun interface { type GVisorTun interface {
@@ -50,12 +52,14 @@ func NewGVisor(
} }
gStack := &GVisor{ gStack := &GVisor{
ctx: options.Context, ctx: options.Context,
tun: gTun, tun: gTun,
udpTimeout: options.UDPTimeout, inet4LoopbackAddress: options.TunOptions.Inet4LoopbackAddress,
broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address), inet6LoopbackAddress: options.TunOptions.Inet6LoopbackAddress,
handler: options.Handler, udpTimeout: options.UDPTimeout,
logger: options.Logger, broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address),
handler: options.Handler,
logger: options.Logger,
} }
return gStack, nil return gStack, nil
} }
@@ -70,7 +74,7 @@ func (t *GVisor) Start() error {
if err != nil { if err != nil {
return err return err
} }
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, NewTCPForwarder(t.ctx, ipStack, t.handler).HandlePacket) ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, NewTCPForwarderWithLoopback(t.ctx, ipStack, t.handler, t.inet4LoopbackAddress, t.inet6LoopbackAddress, t.tun).HandlePacket)
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket) ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(t.ctx, ipStack, t.handler, t.udpTimeout).HandlePacket)
t.stack = ipStack t.stack = ipStack
t.endpoint = linkEndpoint t.endpoint = linkEndpoint

View File

@@ -5,31 +5,75 @@ package tun
import ( import (
"context" "context"
"errors" "errors"
"net/netip"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack" "github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
"github.com/sagernet/sing-tun/internal/gtcpip/checksum"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
) )
type TCPForwarder struct { type TCPForwarder struct {
ctx context.Context ctx context.Context
stack *stack.Stack stack *stack.Stack
handler Handler handler Handler
forwarder *tcp.Forwarder inet4LoopbackAddress []tcpip.Address
inet6LoopbackAddress []tcpip.Address
tun GVisorTun
forwarder *tcp.Forwarder
} }
func NewTCPForwarder(ctx context.Context, stack *stack.Stack, handler Handler) *TCPForwarder { func NewTCPForwarder(ctx context.Context, stack *stack.Stack, handler Handler) *TCPForwarder {
return NewTCPForwarderWithLoopback(ctx, stack, handler, nil, nil, nil)
}
func NewTCPForwarderWithLoopback(ctx context.Context, stack *stack.Stack, handler Handler, inet4LoopbackAddress []netip.Addr, inet6LoopbackAddress []netip.Addr, tun GVisorTun) *TCPForwarder {
forwarder := &TCPForwarder{ forwarder := &TCPForwarder{
ctx: ctx, ctx: ctx,
stack: stack, stack: stack,
handler: handler, handler: handler,
inet4LoopbackAddress: common.Map(inet4LoopbackAddress, AddressFromAddr),
inet6LoopbackAddress: common.Map(inet6LoopbackAddress, AddressFromAddr),
tun: tun,
} }
forwarder.forwarder = tcp.NewForwarder(stack, 0, 1024, forwarder.Forward) forwarder.forwarder = tcp.NewForwarder(stack, 0, 1024, forwarder.Forward)
return forwarder return forwarder
} }
func (f *TCPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { func (f *TCPForwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
for _, inet4LoopbackAddress := range f.inet4LoopbackAddress {
if id.LocalAddress == inet4LoopbackAddress {
ipHdr := pkt.Network().(header.IPv4)
ipHdr.SetDestinationAddressWithChecksumUpdate(ipHdr.SourceAddress())
ipHdr.SetSourceAddressWithChecksumUpdate(inet4LoopbackAddress)
tcpHdr := header.TCP(pkt.TransportHeader().Slice())
tcpHdr.SetChecksum(0)
tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum(
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddress(), ipHdr.DestinationAddress(), ipHdr.PayloadLength()),
)))
bufio.WriteVectorised(f.tun, pkt.AsSlices())
return true
}
}
for _, inet6LoopbackAddress := range f.inet6LoopbackAddress {
if id.LocalAddress == inet6LoopbackAddress {
ipHdr := pkt.Network().(header.IPv6)
ipHdr.SetDestinationAddress(ipHdr.SourceAddress())
ipHdr.SetSourceAddress(inet6LoopbackAddress)
tcpHdr := header.TCP(pkt.TransportHeader().Slice())
tcpHdr.SetChecksum(0)
tcpHdr.SetChecksum(^checksum.Checksum(tcpHdr.Payload(), tcpHdr.CalculateChecksum(
header.PseudoHeaderChecksum(header.TCPProtocolNumber, ipHdr.SourceAddress(), ipHdr.DestinationAddress(), ipHdr.PayloadLength()),
)))
bufio.WriteVectorised(f.tun, pkt.AsSlices())
return true
}
}
return f.forwarder.HandlePacket(id, pkt) return f.forwarder.HandlePacket(id, pkt)
} }

View File

@@ -23,30 +23,32 @@ import (
var ErrIncludeAllNetworks = E.New("`system` and `mixed` stack are not available when `includeAllNetworks` is enabled. See https://github.com/SagerNet/sing-tun/issues/25") var ErrIncludeAllNetworks = E.New("`system` and `mixed` stack are not available when `includeAllNetworks` is enabled. See https://github.com/SagerNet/sing-tun/issues/25")
type System struct { type System struct {
ctx context.Context ctx context.Context
tun Tun tun Tun
tunName string tunName string
mtu int mtu int
handler Handler handler Handler
logger logger.Logger logger logger.Logger
inet4Prefixes []netip.Prefix inet4Prefixes []netip.Prefix
inet6Prefixes []netip.Prefix inet6Prefixes []netip.Prefix
inet4ServerAddress netip.Addr inet4ServerAddress netip.Addr
inet4Address netip.Addr inet4Address netip.Addr
inet6ServerAddress netip.Addr inet6ServerAddress netip.Addr
inet6Address netip.Addr inet6Address netip.Addr
broadcastAddr netip.Addr broadcastAddr netip.Addr
udpTimeout time.Duration inet4LoopbackAddress []netip.Addr
tcpListener net.Listener inet6LoopbackAddress []netip.Addr
tcpListener6 net.Listener udpTimeout time.Duration
tcpPort uint16 tcpListener net.Listener
tcpPort6 uint16 tcpListener6 net.Listener
tcpNat *TCPNat tcpPort uint16
udpNat *udpnat.Service tcpPort6 uint16
bindInterface bool tcpNat *TCPNat
interfaceFinder control.InterfaceFinder udpNat *udpnat.Service
frontHeadroom int bindInterface bool
txChecksumOffload bool interfaceFinder control.InterfaceFinder
frontHeadroom int
txChecksumOffload bool
} }
type Session struct { type Session struct {
@@ -58,18 +60,20 @@ type Session struct {
func NewSystem(options StackOptions) (Stack, error) { func NewSystem(options StackOptions) (Stack, error) {
stack := &System{ stack := &System{
ctx: options.Context, ctx: options.Context,
tun: options.Tun, tun: options.Tun,
tunName: options.TunOptions.Name, tunName: options.TunOptions.Name,
mtu: int(options.TunOptions.MTU), mtu: int(options.TunOptions.MTU),
udpTimeout: options.UDPTimeout, inet4LoopbackAddress: options.TunOptions.Inet4LoopbackAddress,
handler: options.Handler, inet6LoopbackAddress: options.TunOptions.Inet6LoopbackAddress,
logger: options.Logger, udpTimeout: options.UDPTimeout,
inet4Prefixes: options.TunOptions.Inet4Address, handler: options.Handler,
inet6Prefixes: options.TunOptions.Inet6Address, logger: options.Logger,
broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address), inet4Prefixes: options.TunOptions.Inet4Address,
bindInterface: options.ForwarderBindInterface, inet6Prefixes: options.TunOptions.Inet6Address,
interfaceFinder: options.InterfaceFinder, broadcastAddr: BroadcastAddr(options.TunOptions.Inet4Address),
bindInterface: options.ForwarderBindInterface,
interfaceFinder: options.InterfaceFinder,
} }
if len(options.TunOptions.Inet4Address) > 0 { if len(options.TunOptions.Inet4Address) > 0 {
if !HasNextAddress(options.TunOptions.Inet4Address[0], 1) { if !HasNextAddress(options.TunOptions.Inet4Address[0], 1) {
@@ -353,18 +357,29 @@ func (s *System) processIPv4TCP(ipHdr header.IPv4, tcpHdr header.TCP) (bool, err
ipHdr.SetDestinationAddr(session.Source.Addr()) ipHdr.SetDestinationAddr(session.Source.Addr())
tcpHdr.SetDestinationPort(session.Source.Port()) tcpHdr.SetDestinationPort(session.Source.Port())
} else { } else {
natPort, err := s.tcpNat.Lookup(source, destination, s.handler) var loopback bool
if err != nil { for _, inet4LoopbackAddress := range s.inet4LoopbackAddress {
if errors.Is(err, ErrDrop) { if destination.Addr() == inet4LoopbackAddress {
return false, nil ipHdr.SetDestinationAddr(ipHdr.SourceAddr())
} else { ipHdr.SetSourceAddr(inet4LoopbackAddress)
return false, s.resetIPv4TCP(ipHdr, tcpHdr) loopback = true
break
} }
} }
ipHdr.SetSourceAddr(s.inet4Address) if !loopback {
tcpHdr.SetSourcePort(natPort) natPort, err := s.tcpNat.Lookup(source, destination, s.handler)
ipHdr.SetDestinationAddr(s.inet4ServerAddress) if err != nil {
tcpHdr.SetDestinationPort(s.tcpPort) if errors.Is(err, ErrDrop) {
return false, nil
} else {
return false, s.resetIPv4TCP(ipHdr, tcpHdr)
}
}
ipHdr.SetSourceAddr(s.inet4Address)
tcpHdr.SetSourcePort(natPort)
ipHdr.SetDestinationAddr(s.inet4ServerAddress)
tcpHdr.SetDestinationPort(s.tcpPort)
}
} }
if !s.txChecksumOffload { if !s.txChecksumOffload {
tcpHdr.SetChecksum(0) tcpHdr.SetChecksum(0)
@@ -440,18 +455,29 @@ func (s *System) processIPv6TCP(ipHdr header.IPv6, tcpHdr header.TCP) (bool, err
ipHdr.SetDestinationAddr(session.Source.Addr()) ipHdr.SetDestinationAddr(session.Source.Addr())
tcpHdr.SetDestinationPort(session.Source.Port()) tcpHdr.SetDestinationPort(session.Source.Port())
} else { } else {
natPort, err := s.tcpNat.Lookup(source, destination, s.handler) var loopback bool
if err != nil { for _, inet6LoopbackAddress := range s.inet6LoopbackAddress {
if errors.Is(err, ErrDrop) { if destination.Addr() == inet6LoopbackAddress {
return false, nil ipHdr.SetDestinationAddr(ipHdr.SourceAddr())
} else { ipHdr.SetSourceAddr(inet6LoopbackAddress)
return false, s.resetIPv6TCP(ipHdr, tcpHdr) loopback = true
break
} }
} }
ipHdr.SetSourceAddr(s.inet6Address) if !loopback {
tcpHdr.SetSourcePort(natPort) natPort, err := s.tcpNat.Lookup(source, destination, s.handler)
ipHdr.SetDestinationAddr(s.inet6ServerAddress) if err != nil {
tcpHdr.SetDestinationPort(s.tcpPort6) if errors.Is(err, ErrDrop) {
return false, nil
} else {
return false, s.resetIPv6TCP(ipHdr, tcpHdr)
}
}
ipHdr.SetSourceAddr(s.inet6Address)
tcpHdr.SetSourcePort(natPort)
ipHdr.SetDestinationAddr(s.inet6ServerAddress)
tcpHdr.SetDestinationPort(s.tcpPort6)
}
} }
if !s.txChecksumOffload { if !s.txChecksumOffload {
tcpHdr.SetChecksum(0) tcpHdr.SetChecksum(0)

2
tun.go
View File

@@ -66,6 +66,8 @@ type Options struct {
AutoRedirectMarkMode bool AutoRedirectMarkMode bool
AutoRedirectInputMark uint32 AutoRedirectInputMark uint32
AutoRedirectOutputMark uint32 AutoRedirectOutputMark uint32
Inet4LoopbackAddress []netip.Addr
Inet6LoopbackAddress []netip.Addr
StrictRoute bool StrictRoute bool
Inet4RouteAddress []netip.Prefix Inet4RouteAddress []netip.Prefix
Inet6RouteAddress []netip.Prefix Inet6RouteAddress []netip.Prefix