- extractTTL: use 0 as the initial value and take the actual minimum TTL from the answer records; previously the hardcoded initial value of 300 capped higher TTLs.
- TCP handler: add a 90 s read deadline so idle connections cannot block goroutines indefinitely.
- Server.Close: close upstream groups to release connection-pool resources.
339 lines
7.6 KiB
Go
339 lines
7.6 KiB
Go
package minidns
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/netkits-dev/mini-dns/cache"
|
|
"github.com/netkits-dev/mini-dns/resolver"
|
|
"github.com/netkits-dev/mini-dns/rules"
|
|
"github.com/netkits-dev/mini-dns/speedtest"
|
|
"github.com/netkits-dev/mini-dns/upstream"
|
|
"github.com/miekg/dns"
|
|
)
|
|
|
|
type Server struct {
|
|
config *Config
|
|
resolver *resolver.Resolver
|
|
cache *cache.Cache
|
|
revCache *cache.ReverseCache
|
|
groups map[string]*upstream.Group
|
|
udpConn net.PacketConn
|
|
tcpLn net.Listener
|
|
logger Logger
|
|
done chan struct{}
|
|
wg sync.WaitGroup
|
|
}
|
|
|
|
// Option customizes a Server. New applies every Option, in order, right after
// initializing the config, default logger and done channel, and before any
// upstream, group, cache or resolver is built.
type Option func(*Server)
|
|
|
|
func WithProtectFunc(f func(fd int)) Option {
|
|
return func(s *Server) {
|
|
upstream.ProtectFunc = f
|
|
}
|
|
}
|
|
|
|
func WithLogger(l Logger) Option {
|
|
return func(s *Server) {
|
|
s.logger = l
|
|
}
|
|
}
|
|
|
|
func New(config *Config, opts ...Option) (*Server, error) {
|
|
s := &Server{
|
|
config: config,
|
|
logger: NewLogger(config.Log.Level),
|
|
done: make(chan struct{}),
|
|
}
|
|
for _, opt := range opts {
|
|
opt(s)
|
|
}
|
|
|
|
// Build upstreams
|
|
upstreamMap := make(map[string]upstream.Upstream)
|
|
for _, uc := range config.Upstreams {
|
|
// Determine dial function: Dial field > Detour string > nil (direct)
|
|
var dial upstream.DialFunc
|
|
if uc.Dial != nil {
|
|
dial = upstream.DialFunc(uc.Dial)
|
|
} else if uc.Detour != "" {
|
|
proxyDial, err := upstream.ProxyDialFunc(uc.Detour, 5*time.Second)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("upstream %q detour: %w", uc.Name, err)
|
|
}
|
|
dial = proxyDial
|
|
}
|
|
var u upstream.Upstream
|
|
switch uc.Protocol {
|
|
case "udp":
|
|
if dial != nil {
|
|
return nil, fmt.Errorf("upstream %q: UDP cannot be tunneled through a proxy, use protocol \"tcp\" or \"dot\" instead", uc.Name)
|
|
}
|
|
u = upstream.NewUDP(uc.Name, uc.Addr, uc.Port, nil)
|
|
case "tcp":
|
|
u = upstream.NewTCP(uc.Name, uc.Addr, uc.Port, dial)
|
|
case "dot":
|
|
u = upstream.NewDoT(uc.Name, uc.Addr, uc.Port, dial)
|
|
case "doh":
|
|
if uc.Detour != "" || uc.Dial != nil {
|
|
u = upstream.NewDoH(uc.Name, uc.Addr, uc.Port, uc.URL, dial)
|
|
} else {
|
|
u = upstream.NewDoH(uc.Name, uc.Addr, uc.Port, uc.URL)
|
|
}
|
|
default:
|
|
return nil, fmt.Errorf("unknown protocol %q for upstream %q", uc.Protocol, uc.Name)
|
|
}
|
|
upstreamMap[uc.Name] = u
|
|
}
|
|
|
|
// Build groups
|
|
groups := make(map[string]*upstream.Group)
|
|
for _, gc := range config.Groups {
|
|
var members []upstream.Upstream
|
|
for _, name := range gc.Upstreams {
|
|
u, ok := upstreamMap[name]
|
|
if !ok {
|
|
return nil, fmt.Errorf("upstream %q not found (referenced by group %q)", name, gc.Name)
|
|
}
|
|
members = append(members, u)
|
|
}
|
|
groups[gc.Name] = upstream.NewGroup(gc.Name, members, upstream.ParseStrategy(gc.Strategy))
|
|
}
|
|
|
|
// Default group is the first one
|
|
defaultGroup := ""
|
|
if len(config.Groups) > 0 {
|
|
defaultGroup = config.Groups[0].Name
|
|
}
|
|
|
|
// Build rules router — validate group references
|
|
var ruleConfigs []rules.Rule
|
|
for _, rc := range config.Rules {
|
|
if rc.Group != "" {
|
|
if _, ok := groups[rc.Group]; !ok {
|
|
return nil, fmt.Errorf("rule references non-existent group %q", rc.Group)
|
|
}
|
|
}
|
|
ruleConfigs = append(ruleConfigs, rules.Rule{
|
|
Domain: rc.Domain,
|
|
DomainSuffix: rc.DomainSuffix,
|
|
DomainKeyword: rc.DomainKeyword,
|
|
Group: rc.Group,
|
|
})
|
|
}
|
|
router := rules.NewRouter(rules.Config{
|
|
Rules: ruleConfigs,
|
|
DefaultGroup: defaultGroup,
|
|
})
|
|
|
|
// Build cache
|
|
var dnsCache *cache.Cache
|
|
if config.Cache.Enabled {
|
|
dnsCache = cache.New(cache.Config{
|
|
MaxSize: config.Cache.Size,
|
|
MinTTL: time.Duration(config.Cache.MinTTL) * time.Second,
|
|
MaxTTL: time.Duration(config.Cache.MaxTTL) * time.Second,
|
|
Prefetch: config.Cache.Prefetch,
|
|
ServeStale: config.Cache.ServeStale,
|
|
})
|
|
}
|
|
|
|
// Build reverse cache
|
|
revCache := cache.NewReverseCache()
|
|
|
|
// Build pollution filter
|
|
var filter *resolver.PollutionFilter
|
|
if config.Pollution.BogonFilter || len(config.Pollution.Blacklist) > 0 {
|
|
filter = resolver.NewPollutionFilter(config.Pollution.BogonFilter, config.Pollution.Blacklist)
|
|
}
|
|
|
|
// Build speed tester
|
|
tester := speedtest.New(speedtest.Config{
|
|
Enabled: config.SpeedTest.Enabled,
|
|
TimeoutMs: config.SpeedTest.TimeoutMs,
|
|
Port: config.SpeedTest.Port,
|
|
})
|
|
|
|
// Build resolver
|
|
s.resolver = resolver.New(resolver.Config{
|
|
Groups: groups,
|
|
Rules: router,
|
|
Cache: dnsCache,
|
|
RevCache: revCache,
|
|
Filter: filter,
|
|
SpeedTester: tester,
|
|
Hosts: config.Hosts,
|
|
PreferIPv4: config.PreferIPv4,
|
|
Logger: s.logger,
|
|
})
|
|
s.cache = dnsCache
|
|
s.revCache = revCache
|
|
s.groups = groups
|
|
|
|
return s, nil
|
|
}
|
|
|
|
func NewFromJSON(data []byte, opts ...Option) (*Server, error) {
|
|
config, err := ParseConfig(data)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return New(config, opts...)
|
|
}
|
|
|
|
func (s *Server) Start() error {
|
|
addr := s.config.Listen
|
|
|
|
// UDP listener
|
|
udpConn, err := net.ListenPacket("udp", addr)
|
|
if err != nil {
|
|
return fmt.Errorf("listen udp %s: %w", addr, err)
|
|
}
|
|
s.udpConn = udpConn
|
|
|
|
// TCP listener
|
|
tcpLn, err := net.Listen("tcp", addr)
|
|
if err != nil {
|
|
udpConn.Close()
|
|
return fmt.Errorf("listen tcp %s: %w", addr, err)
|
|
}
|
|
s.tcpLn = tcpLn
|
|
|
|
s.logger.Info("mini-dns listening on ", addr)
|
|
|
|
go s.serveUDP()
|
|
go s.serveTCP()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) Exchange(query []byte) ([]byte, error) {
|
|
msg := new(dns.Msg)
|
|
if err := msg.Unpack(query); err != nil {
|
|
return nil, err
|
|
}
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
resp, err := s.resolver.Resolve(ctx, msg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return resp.Pack()
|
|
}
|
|
|
|
func (s *Server) Resolve(domain string) ([]netip.Addr, error) {
|
|
return s.resolver.ResolveAddr(domain)
|
|
}
|
|
|
|
func (s *Server) ReverseCache() *cache.ReverseCache {
|
|
return s.revCache
|
|
}
|
|
|
|
func (s *Server) Addr() string {
|
|
if s.udpConn != nil {
|
|
return s.udpConn.LocalAddr().String()
|
|
}
|
|
return s.config.Listen
|
|
}
|
|
|
|
func (s *Server) Close() error {
|
|
close(s.done)
|
|
if s.udpConn != nil {
|
|
s.udpConn.Close()
|
|
}
|
|
if s.tcpLn != nil {
|
|
s.tcpLn.Close()
|
|
}
|
|
// Wait for in-flight requests (with timeout)
|
|
done := make(chan struct{})
|
|
go func() { s.wg.Wait(); close(done) }()
|
|
select {
|
|
case <-done:
|
|
case <-time.After(5 * time.Second):
|
|
}
|
|
if s.cache != nil {
|
|
s.cache.Close()
|
|
}
|
|
if s.revCache != nil {
|
|
s.revCache.Close()
|
|
}
|
|
for _, g := range s.groups {
|
|
g.Close()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) serveUDP() {
|
|
buf := make([]byte, 4096)
|
|
for {
|
|
n, addr, err := s.udpConn.ReadFrom(buf)
|
|
if err != nil {
|
|
select {
|
|
case <-s.done:
|
|
return
|
|
default:
|
|
continue
|
|
}
|
|
}
|
|
s.wg.Add(1)
|
|
go func(query []byte, addr net.Addr) {
|
|
defer s.wg.Done()
|
|
resp, err := s.Exchange(query)
|
|
if err != nil {
|
|
s.logger.Debug("udp resolve error: ", err)
|
|
return
|
|
}
|
|
s.udpConn.WriteTo(resp, addr)
|
|
}(append([]byte(nil), buf[:n]...), addr)
|
|
}
|
|
}
|
|
|
|
func (s *Server) serveTCP() {
|
|
for {
|
|
conn, err := s.tcpLn.Accept()
|
|
if err != nil {
|
|
select {
|
|
case <-s.done:
|
|
return
|
|
default:
|
|
continue
|
|
}
|
|
}
|
|
s.wg.Add(1)
|
|
go func() {
|
|
defer s.wg.Done()
|
|
s.handleTCP(conn)
|
|
}()
|
|
}
|
|
}
|
|
|
|
func (s *Server) handleTCP(conn net.Conn) {
|
|
defer conn.Close()
|
|
dnsConn := &dns.Conn{Conn: conn}
|
|
for {
|
|
conn.SetReadDeadline(time.Now().Add(90 * time.Second))
|
|
msg, err := dnsConn.ReadMsg()
|
|
if err != nil {
|
|
return
|
|
}
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
resp, err := s.resolver.Resolve(ctx, msg)
|
|
cancel()
|
|
if err != nil {
|
|
s.logger.Debug("tcp resolve error: ", err)
|
|
// Return SERVFAIL instead of dropping the connection
|
|
resp = new(dns.Msg)
|
|
resp.SetRcode(msg, dns.RcodeServerFailure)
|
|
}
|
|
if err := dnsConn.WriteMsg(resp); err != nil {
|
|
s.logger.Debug("tcp write error: ", err)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|