Files
mini-dns/server.go
NeoMody ec9cdb2784 Fix cache TTL extraction, add TCP idle timeout, close upstreams on shutdown
- extractTTL: use 0 as initial value and take actual minimum TTL from
  answer records. Previously hardcoded 300 would cap higher TTLs.
- TCP handler: add 90s read deadline to prevent idle connections from
  blocking goroutines indefinitely.
- Server.Close: close upstream groups to release connection pool resources.
2026-04-02 05:42:20 +08:00

339 lines
7.6 KiB
Go

package minidns
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"time"
"github.com/netkits-dev/mini-dns/cache"
"github.com/netkits-dev/mini-dns/resolver"
"github.com/netkits-dev/mini-dns/rules"
"github.com/netkits-dev/mini-dns/speedtest"
"github.com/netkits-dev/mini-dns/upstream"
"github.com/miekg/dns"
)
// Server is a DNS forwarder that listens on UDP and TCP and answers queries
// through a resolver built from a Config.
type Server struct {
	// config is the configuration passed to New.
	config *Config

	// Query pipeline, built by New.
	resolver *resolver.Resolver
	cache    *cache.Cache // nil when config.Cache.Enabled is false
	revCache *cache.ReverseCache
	groups   map[string]*upstream.Group // keyed by group name; closed by Close

	// Listeners, set by Start and closed by Close.
	udpConn net.PacketConn
	tcpLn   net.Listener

	logger Logger

	// done is closed by Close to tell the serve loops to stop.
	done chan struct{}
	// wg counts goroutines that handle individual UDP queries and TCP
	// connections; Close waits on it (with a timeout).
	wg sync.WaitGroup
}
// Option customizes a Server in New. Options are applied after the default
// fields (including the logger built from config.Log.Level) are set and
// before any upstream, cache or resolver is constructed.
type Option func(*Server)

// WithProtectFunc assigns f to upstream.ProtectFunc.
//
// NOTE: upstream.ProtectFunc is a package-level variable, so this affects
// every Server in the process, not only the one being built, and the last
// assignment wins.
func WithProtectFunc(f func(fd int)) Option {
	return func(*Server) {
		upstream.ProtectFunc = f
	}
}

// WithLogger replaces the default logger with l. A nil l is ignored and the
// default logger is kept; installing nil would make the first log call panic.
func WithLogger(l Logger) Option {
	return func(s *Server) {
		if l == nil {
			return
		}
		s.logger = l
	}
}
// New builds a Server from config. It applies opts after the defaults are
// set. It then creates the upstreams, upstream groups, rule router, DNS cache
// (when config.Cache.Enabled), reverse cache, pollution filter, speed tester
// and resolver. New opens no sockets; call Start to begin serving.
//
// Every cross reference in config (group → upstream, rule → group) is
// checked before the first upstream group is created. Groups hold
// connection-pool resources that are released by Group.Close. Validating
// first means a configuration error never returns with groups allocated
// and unreachable.
//
// The first entry of config.Groups becomes the router's DefaultGroup.
func New(config *Config, opts ...Option) (*Server, error) {
	s := &Server{
		config: config,
		logger: NewLogger(config.Log.Level),
		done:   make(chan struct{}),
	}
	for _, opt := range opts {
		opt(s)
	}

	upstreamMap, err := buildUpstreamMap(config)
	if err != nil {
		return nil, err
	}

	// Validate all references before any upstream.NewGroup call (see above).
	members, err := resolveGroupMembers(config, upstreamMap)
	if err != nil {
		return nil, err
	}
	ruleConfigs, err := buildRuleList(config)
	if err != nil {
		return nil, err
	}

	// members is indexed parallel to config.Groups. A later group with a
	// duplicate name replaces an earlier one in the map.
	groups := make(map[string]*upstream.Group, len(config.Groups))
	for i, gc := range config.Groups {
		groups[gc.Name] = upstream.NewGroup(gc.Name, members[i], upstream.ParseStrategy(gc.Strategy))
	}

	defaultGroup := ""
	if len(config.Groups) > 0 {
		defaultGroup = config.Groups[0].Name
	}
	router := rules.NewRouter(rules.Config{
		Rules:        ruleConfigs,
		DefaultGroup: defaultGroup,
	})

	// The forward cache is optional; nil disables caching in the resolver.
	var dnsCache *cache.Cache
	if config.Cache.Enabled {
		dnsCache = cache.New(cache.Config{
			MaxSize:    config.Cache.Size,
			MinTTL:     time.Duration(config.Cache.MinTTL) * time.Second,
			MaxTTL:     time.Duration(config.Cache.MaxTTL) * time.Second,
			Prefetch:   config.Cache.Prefetch,
			ServeStale: config.Cache.ServeStale,
		})
	}

	revCache := cache.NewReverseCache()

	// The pollution filter is only built when at least one check is configured.
	var filter *resolver.PollutionFilter
	if config.Pollution.BogonFilter || len(config.Pollution.Blacklist) > 0 {
		filter = resolver.NewPollutionFilter(config.Pollution.BogonFilter, config.Pollution.Blacklist)
	}

	tester := speedtest.New(speedtest.Config{
		Enabled:   config.SpeedTest.Enabled,
		TimeoutMs: config.SpeedTest.TimeoutMs,
		Port:      config.SpeedTest.Port,
	})

	s.resolver = resolver.New(resolver.Config{
		Groups:      groups,
		Rules:       router,
		Cache:       dnsCache,
		RevCache:    revCache,
		Filter:      filter,
		SpeedTester: tester,
		Hosts:       config.Hosts,
		PreferIPv4:  config.PreferIPv4,
		Logger:      s.logger,
	})
	s.cache = dnsCache
	s.revCache = revCache
	s.groups = groups
	return s, nil
}

// buildUpstreamMap constructs one upstream per config.Upstreams entry, keyed
// by name. The dial function is chosen in this order: the explicit Dial
// field, then a proxy dialer built from Detour, then nil (direct).
func buildUpstreamMap(config *Config) (map[string]upstream.Upstream, error) {
	upstreamMap := make(map[string]upstream.Upstream, len(config.Upstreams))
	for _, uc := range config.Upstreams {
		var dial upstream.DialFunc
		switch {
		case uc.Dial != nil:
			dial = upstream.DialFunc(uc.Dial)
		case uc.Detour != "":
			proxyDial, err := upstream.ProxyDialFunc(uc.Detour, 5*time.Second)
			if err != nil {
				return nil, fmt.Errorf("upstream %q detour: %w", uc.Name, err)
			}
			dial = proxyDial
		}

		var u upstream.Upstream
		switch uc.Protocol {
		case "udp":
			if dial != nil {
				return nil, fmt.Errorf("upstream %q: UDP cannot be tunneled through a proxy, use protocol \"tcp\" or \"dot\" instead", uc.Name)
			}
			u = upstream.NewUDP(uc.Name, uc.Addr, uc.Port, nil)
		case "tcp":
			u = upstream.NewTCP(uc.Name, uc.Addr, uc.Port, dial)
		case "dot":
			u = upstream.NewDoT(uc.Name, uc.Addr, uc.Port, dial)
		case "doh":
			if uc.Detour != "" || uc.Dial != nil {
				u = upstream.NewDoH(uc.Name, uc.Addr, uc.Port, uc.URL, dial)
			} else {
				u = upstream.NewDoH(uc.Name, uc.Addr, uc.Port, uc.URL)
			}
		default:
			return nil, fmt.Errorf("unknown protocol %q for upstream %q", uc.Protocol, uc.Name)
		}
		upstreamMap[uc.Name] = u
	}
	return upstreamMap, nil
}

// resolveGroupMembers looks up the member upstreams of every configured
// group. The result is indexed parallel to config.Groups. It fails on the
// first member name that is not a key of upstreamMap.
func resolveGroupMembers(config *Config, upstreamMap map[string]upstream.Upstream) ([][]upstream.Upstream, error) {
	members := make([][]upstream.Upstream, len(config.Groups))
	for i, gc := range config.Groups {
		for _, name := range gc.Upstreams {
			u, ok := upstreamMap[name]
			if !ok {
				return nil, fmt.Errorf("upstream %q not found (referenced by group %q)", name, gc.Name)
			}
			members[i] = append(members[i], u)
		}
	}
	return members, nil
}

// buildRuleList converts config.Rules to router rules. It fails if a rule
// names a group that is not in config.Groups. An empty Group is accepted.
func buildRuleList(config *Config) ([]rules.Rule, error) {
	known := make(map[string]bool, len(config.Groups))
	for _, gc := range config.Groups {
		known[gc.Name] = true
	}
	var out []rules.Rule
	for _, rc := range config.Rules {
		if rc.Group != "" && !known[rc.Group] {
			return nil, fmt.Errorf("rule references non-existent group %q", rc.Group)
		}
		out = append(out, rules.Rule{
			Domain:        rc.Domain,
			DomainSuffix:  rc.DomainSuffix,
			DomainKeyword: rc.DomainKeyword,
			Group:         rc.Group,
		})
	}
	return out, nil
}
// NewFromJSON decodes data with ParseConfig and passes the resulting
// configuration, together with opts, to New. A parse error is returned as is.
func NewFromJSON(data []byte, opts ...Option) (*Server, error) {
	cfg, parseErr := ParseConfig(data)
	if parseErr != nil {
		return nil, parseErr
	}
	return New(cfg, opts...)
}
// Start binds a UDP socket and a TCP listener on s.config.Listen and serves
// both in background goroutines. It returns once both are bound.
//
// If the configured port is 0 (or empty), the kernel chooses the UDP port
// and TCP is bound to that same port. Addr then reports an address that is
// valid for both transports. This can fail if that port is already taken for
// TCP. On failure nothing stays open and s.udpConn/s.tcpLn are left
// untouched.
//
// The two serve loops are counted in s.wg. Their per-request wg.Add calls
// therefore never happen while the counter is zero and Close is waiting,
// which sync.WaitGroup forbids.
func (s *Server) Start() error {
	addr := s.config.Listen

	udpConn, err := net.ListenPacket("udp", addr)
	if err != nil {
		return fmt.Errorf("listen udp %s: %w", addr, err)
	}

	// With an ephemeral port, reuse the port UDP was just given; otherwise
	// the two transports would end up on unrelated random ports.
	tcpAddr := addr
	if host, port, splitErr := net.SplitHostPort(addr); splitErr == nil && (port == "" || port == "0") {
		if _, udpPort, portErr := net.SplitHostPort(udpConn.LocalAddr().String()); portErr == nil {
			tcpAddr = net.JoinHostPort(host, udpPort)
		}
	}

	tcpLn, err := net.Listen("tcp", tcpAddr)
	if err != nil {
		udpConn.Close()
		return fmt.Errorf("listen tcp %s: %w", tcpAddr, err)
	}

	// Publish the listeners only after both are bound.
	s.udpConn = udpConn
	s.tcpLn = tcpLn
	s.logger.Info("mini-dns listening on ", addr)

	s.wg.Add(2)
	go func() {
		defer s.wg.Done()
		s.serveUDP()
	}()
	go func() {
		defer s.wg.Done()
		s.serveTCP()
	}()
	return nil
}
// Exchange answers one wire-format DNS query and returns the packed
// wire-format reply. Resolution is bounded by a 10-second timeout. It returns
// an error if query cannot be unpacked, if resolution fails, or if the reply
// cannot be packed.
func (s *Server) Exchange(query []byte) ([]byte, error) {
	var req dns.Msg
	if unpackErr := req.Unpack(query); unpackErr != nil {
		return nil, unpackErr
	}

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	reply, resolveErr := s.resolver.Resolve(ctx, &req)
	if resolveErr != nil {
		return nil, resolveErr
	}
	return reply.Pack()
}
// Resolve looks up domain through the server's resolver and returns the
// resulting addresses.
func (s *Server) Resolve(domain string) ([]netip.Addr, error) {
	r := s.resolver
	return r.ResolveAddr(domain)
}

// ReverseCache returns the reverse cache that New created for this server.
func (s *Server) ReverseCache() *cache.ReverseCache { return s.revCache }

// Addr returns the local address of the UDP socket if Start has set one, and
// the configured listen address otherwise.
func (s *Server) Addr() string {
	if s.udpConn == nil {
		return s.config.Listen
	}
	return s.udpConn.LocalAddr().String()
}
// Close shuts the server down. It signals shutdown via s.done, closes both
// listeners, and waits at most 5 seconds for tracked goroutines to finish.
// It then closes the DNS cache, the reverse cache and every upstream group.
// It always returns nil.
//
// Calling Close again after it has returned is a no-op; previously the second
// call panicked on closing s.done twice. Concurrent calls to Close are not
// supported.
func (s *Server) Close() error {
	select {
	case <-s.done:
		// Already closed by an earlier call.
		return nil
	default:
		close(s.done)
	}

	if s.udpConn != nil {
		s.udpConn.Close()
	}
	if s.tcpLn != nil {
		s.tcpLn.Close()
	}

	// Wait for in-flight work, but never longer than 5 seconds.
	drained := make(chan struct{})
	go func() {
		s.wg.Wait()
		close(drained)
	}()
	select {
	case <-drained:
	case <-time.After(5 * time.Second):
		// Requests still running past this point may touch the caches and
		// groups that are closed below.
	}

	if s.cache != nil {
		s.cache.Close()
	}
	if s.revCache != nil {
		s.revCache.Close()
	}
	for _, g := range s.groups {
		g.Close()
	}
	return nil
}
// serveUDP reads queries from the UDP socket until the server is closed. It
// answers each query in its own goroutine, tracked by s.wg.
//
// Like the TCP path, a query whose resolution fails is answered with SERVFAIL
// instead of being dropped, which would leave the client waiting for its own
// timeout. Unparseable datagrams are still dropped.
//
// Replies larger than the client's advertised EDNS0 UDP size (at least 512
// bytes, or exactly 512 without EDNS0) are truncated with the TC bit set.
// The client can then retry over TCP instead of receiving an oversized
// datagram.
func (s *Server) serveUDP() {
	buf := make([]byte, 4096)
	for {
		n, addr, err := s.udpConn.ReadFrom(buf)
		if err != nil {
			select {
			case <-s.done:
				return
			default:
				continue
			}
		}
		s.wg.Add(1)
		go func(query []byte, addr net.Addr) {
			defer s.wg.Done()

			req := new(dns.Msg)
			if err := req.Unpack(query); err != nil {
				// No parseable question, so there is nothing to answer.
				s.logger.Debug("udp resolve error: ", err)
				return
			}

			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			resp, err := s.resolver.Resolve(ctx, req)
			cancel()
			if err != nil {
				s.logger.Debug("udp resolve error: ", err)
				resp = new(dns.Msg)
				resp.SetRcode(req, dns.RcodeServerFailure)
			}

			limit := dns.MinMsgSize
			if opt := req.IsEdns0(); opt != nil && int(opt.UDPSize()) > limit {
				limit = int(opt.UDPSize())
			}
			if resp.Len() > limit {
				// Truncate edits the message in place. Copy first because it
				// is not visible here whether resp is shared (e.g. with a cache).
				resp = resp.Copy()
				resp.Truncate(limit)
			}

			out, err := resp.Pack()
			if err != nil {
				s.logger.Debug("udp resolve error: ", err)
				return
			}
			if _, err := s.udpConn.WriteTo(out, addr); err != nil {
				s.logger.Debug("udp write error: ", err)
			}
		}(append([]byte(nil), buf[:n]...), addr)
	}
}
// serveTCP accepts DNS-over-TCP connections until the server is closed. Each
// connection is handled in its own goroutine, tracked by s.wg.
//
// An Accept error while the server is still running (for example running
// out of file descriptors) is retried with exponential backoff from 5ms up
// to 1s. The original code retried immediately, in a loop that could pin a
// CPU. The backoff sleep ends early when the server is closed.
func (s *Server) serveTCP() {
	const (
		minBackoff = 5 * time.Millisecond
		maxBackoff = time.Second
	)
	var backoff time.Duration
	for {
		conn, err := s.tcpLn.Accept()
		if err != nil {
			select {
			case <-s.done:
				return
			default:
			}
			backoff *= 2
			if backoff < minBackoff {
				backoff = minBackoff
			}
			if backoff > maxBackoff {
				backoff = maxBackoff
			}
			s.logger.Debug("tcp accept error: ", err)
			select {
			case <-s.done:
				return
			case <-time.After(backoff):
			}
			continue
		}
		backoff = 0

		s.wg.Add(1)
		go func(c net.Conn) {
			defer s.wg.Done()
			s.handleTCP(c)
		}(conn)
	}
}
// handleTCP serves length-prefixed DNS queries on conn, one at a time. It
// stops when any of the following happens:
//   - the client closes the connection;
//   - a read is idle for 90 seconds;
//   - a write fails or exceeds its deadline;
//   - the server is closed.
//
// A failed resolution is answered with SERVFAIL rather than dropping the
// connection.
//
// Without closing conn on shutdown, an idle keep-alive client would block
// ReadMsg until the idle timeout. That kept Close's wait running for its full
// 5 seconds and the goroutine alive after Close.
func (s *Server) handleTCP(conn net.Conn) {
	const (
		idleTimeout  = 90 * time.Second
		writeTimeout = 10 * time.Second
	)

	finished := make(chan struct{})
	defer close(finished)
	defer conn.Close()

	// Unblock a pending ReadMsg as soon as the server shuts down.
	go func() {
		select {
		case <-s.done:
			conn.Close()
		case <-finished:
		}
	}()

	dnsConn := &dns.Conn{Conn: conn}
	for {
		conn.SetReadDeadline(time.Now().Add(idleTimeout))
		msg, err := dnsConn.ReadMsg()
		if err != nil {
			return
		}

		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		resp, err := s.resolver.Resolve(ctx, msg)
		cancel()
		if err != nil {
			s.logger.Debug("tcp resolve error: ", err)
			// Return SERVFAIL instead of dropping the connection.
			resp = new(dns.Msg)
			resp.SetRcode(msg, dns.RcodeServerFailure)
		}

		// A client that stops reading must not pin this goroutine forever.
		conn.SetWriteDeadline(time.Now().Add(writeTimeout))
		if err := dnsConn.WriteMsg(resp); err != nil {
			s.logger.Debug("tcp write error: ", err)
			return
		}
	}
}