diff --git a/internal/cli/generate.go b/internal/cli/generate.go index b776c13..d7d7c68 100644 --- a/internal/cli/generate.go +++ b/internal/cli/generate.go @@ -5,6 +5,7 @@ import ( "os" "path/filepath" "rulekit/internal/engine" + "rulekit/internal/model" "rulekit/internal/writer" "strings" @@ -20,6 +21,7 @@ func init() { } generateCmd.Flags().StringP("format", "f", "singbox,clash,surge", "output formats (comma-separated: singbox,srs,clash,surge)") generateCmd.Flags().StringP("output", "o", "", "output directory (default: rules_dir/output)") + generateCmd.Flags().BoolP("optimize", "O", false, "optimize: merge covered domains, aggregate CIDRs") rootCmd.AddCommand(generateCmd) } @@ -27,6 +29,7 @@ func runGenerate(cmd *cobra.Command, args []string) error { categoryName := args[0] formatStr, _ := cmd.Flags().GetString("format") outputDir, _ := cmd.Flags().GetString("output") + optimize, _ := cmd.Flags().GetBool("optimize") cfg := loadConfig() @@ -37,11 +40,24 @@ func runGenerate(cmd *cobra.Command, args []string) error { return err } - merged, err := engine.Merge(cfg, categoryName) - if err != nil { - return err + var merged *model.MergedRuleSet + if optimize { + var optResult *engine.OptimizeResult + var err error + merged, optResult, err = engine.MergeOptimized(cfg, categoryName) + if err != nil { + return err + } + fmt.Printf("Merged %s: %d rules (optimized %d -> %d)\n", + categoryName, len(merged.Rules), optResult.Before, optResult.After) + } else { + var err error + merged, err = engine.Merge(cfg, categoryName) + if err != nil { + return err + } + fmt.Printf("Merged %s: %d rules\n", categoryName, len(merged.Rules)) } - fmt.Printf("Merged %s: %d rules\n", categoryName, len(merged.Rules)) formats := strings.Split(formatStr, ",") for _, fmtName := range formats { diff --git a/internal/cli/merge.go b/internal/cli/merge.go index 3b6d77c..e3d1fcb 100644 --- a/internal/cli/merge.go +++ b/internal/cli/merge.go @@ -19,6 +19,7 @@ func init() { } 
mergeCmd.Flags().IntP("limit", "n", 0, "limit output rows (0 = all)") mergeCmd.Flags().BoolP("stats", "s", false, "show statistics only") + mergeCmd.Flags().BoolP("optimize", "O", false, "optimize: merge covered domains, aggregate CIDRs") rootCmd.AddCommand(mergeCmd) } @@ -26,12 +27,25 @@ func runMerge(cmd *cobra.Command, args []string) error { categoryName := args[0] limit, _ := cmd.Flags().GetInt("limit") statsOnly, _ := cmd.Flags().GetBool("stats") + optimize, _ := cmd.Flags().GetBool("optimize") cfg := loadConfig() - merged, err := engine.Merge(cfg, categoryName) - if err != nil { - return err + var merged *model.MergedRuleSet + var optResult *engine.OptimizeResult + + if optimize { + var err error + merged, optResult, err = engine.MergeOptimized(cfg, categoryName) + if err != nil { + return err + } + } else { + var err error + merged, err = engine.Merge(cfg, categoryName) + if err != nil { + return err + } } // Stats @@ -41,6 +55,11 @@ func runMerge(cmd *cobra.Command, args []string) error { } fmt.Printf("Merged: %s (%d rules)\n", merged.Name, len(merged.Rules)) + if optResult != nil { + fmt.Printf("Optimized: %d -> %d (-%d domains by suffix, -%d by keyword, -%d CIDRs merged)\n", + optResult.Before, optResult.After, + optResult.DomainsMerged, optResult.KeywordMerged, optResult.CIDRsMerged) + } for t, c := range types { fmt.Printf(" %s: %d\n", t, c) } diff --git a/internal/engine/merge.go b/internal/engine/merge.go index 7b9e4bf..2566e34 100644 --- a/internal/engine/merge.go +++ b/internal/engine/merge.go @@ -60,6 +60,17 @@ func Merge(cfg *config.Config, categoryName string) (*model.MergedRuleSet, error }, nil } +// MergeOptimized merges and then optimizes (domain/suffix/CIDR dedup). 
+func MergeOptimized(cfg *config.Config, categoryName string) (*model.MergedRuleSet, *OptimizeResult, error) { + merged, err := Merge(cfg, categoryName) + if err != nil { + return nil, nil, err + } + optimized, result := Optimize(merged.Rules) + merged.Rules = optimized + return merged, result, nil +} + func typeOrder(t model.RuleType) int { switch t { case model.RuleDomain: diff --git a/internal/engine/optimize.go b/internal/engine/optimize.go new file mode 100644 index 0000000..461e23b --- /dev/null +++ b/internal/engine/optimize.go @@ -0,0 +1,365 @@ +package engine + +import ( + "encoding/binary" + "fmt" + "net" + "rulekit/internal/model" + "sort" + "strings" +) + +// OptimizeResult tracks what the optimizer did. +type OptimizeResult struct { + Before int + After int + DomainsMerged int // domain entries removed because a domain_suffix covers them + KeywordMerged int // domain/suffix entries removed because a keyword covers them + CIDRsMerged int // CIDR entries merged into larger blocks + Removals []string // human-readable list of what was removed/merged +} + +// Optimize performs semantic deduplication on a merged rule set: +// 1. domain_suffix subsumes matching domain entries +// 2. domain_keyword subsumes matching domain/suffix entries +// 3. 
CIDR aggregation: merge adjacent/contained IP ranges +func Optimize(rules []model.Rule) ([]model.Rule, *OptimizeResult) { + result := &OptimizeResult{Before: len(rules)} + + // Separate by type + var domains, suffixes, keywords, regexes, cidrs, procs []model.Rule + for _, r := range rules { + switch r.Type { + case model.RuleDomain: + domains = append(domains, r) + case model.RuleDomainSuffix: + suffixes = append(suffixes, r) + case model.RuleDomainKeyword: + keywords = append(keywords, r) + case model.RuleDomainRegex: + regexes = append(regexes, r) + case model.RuleIPCIDR: + cidrs = append(cidrs, r) + case model.RuleProcessName: + procs = append(procs, r) + } + } + + // Build suffix set for fast lookup + suffixSet := map[string]bool{} + for _, s := range suffixes { + suffixSet[s.Value] = true + } + + // Build keyword list + keywordVals := make([]string, len(keywords)) + for i, k := range keywords { + keywordVals[i] = k.Value + } + + // 1. Remove domains covered by domain_suffix + var filteredDomains []model.Rule + for _, d := range domains { + if coveredBySuffix(d.Value, suffixSet) { + result.DomainsMerged++ + result.Removals = append(result.Removals, + fmt.Sprintf("domain:%s (covered by suffix)", d.Value)) + } else { + filteredDomains = append(filteredDomains, d) + } + } + + // 2. Remove domains/suffixes covered by keyword + var filteredDomains2 []model.Rule + for _, d := range filteredDomains { + if coveredByKeyword(d.Value, keywordVals) { + result.KeywordMerged++ + result.Removals = append(result.Removals, + fmt.Sprintf("domain:%s (covered by keyword)", d.Value)) + } else { + filteredDomains2 = append(filteredDomains2, d) + } + } + + var filteredSuffixes []model.Rule + for _, s := range suffixes { + if coveredByKeyword(s.Value, keywordVals) { + result.KeywordMerged++ + result.Removals = append(result.Removals, + fmt.Sprintf("domain_suffix:%s (covered by keyword)", s.Value)) + } else { + filteredSuffixes = append(filteredSuffixes, s) + } + } + + // 2b. 
Remove suffixes covered by a parent suffix + // e.g., "a.bilibili.com" is redundant if "bilibili.com" exists + filteredSuffixes = removeCoveredSuffixes(filteredSuffixes, result) + + // 3. CIDR aggregation + optimizedCIDRs := aggregateCIDRs(cidrs, result) + + // Reassemble + var out []model.Rule + out = append(out, filteredDomains2...) + out = append(out, filteredSuffixes...) + out = append(out, keywords...) + out = append(out, regexes...) + out = append(out, optimizedCIDRs...) + out = append(out, procs...) + + sort.Slice(out, func(i, j int) bool { + if out[i].Type != out[j].Type { + return typeOrder(out[i].Type) < typeOrder(out[j].Type) + } + return out[i].Value < out[j].Value + }) + + result.After = len(out) + return out, result +} + +// coveredBySuffix checks if a domain is matched by any suffix in the set. +// e.g., "www.bilibili.com" is covered by suffix "bilibili.com" +func coveredBySuffix(domain string, suffixSet map[string]bool) bool { + // Check exact match first + if suffixSet[domain] { + return true + } + // Walk up the domain tree + parts := strings.Split(domain, ".") + for i := 1; i < len(parts); i++ { + parent := strings.Join(parts[i:], ".") + if suffixSet[parent] { + return true + } + } + return false +} + +// coveredByKeyword checks if a domain/suffix contains any keyword. +func coveredByKeyword(value string, keywords []string) bool { + for _, kw := range keywords { + if strings.Contains(value, kw) { + return true + } + } + return false +} + +// removeCoveredSuffixes removes suffixes that are subdomains of other suffixes. +// e.g., "api.bilibili.com" is redundant if "bilibili.com" exists as a suffix. 
+func removeCoveredSuffixes(suffixes []model.Rule, result *OptimizeResult) []model.Rule { + suffixSet := map[string]bool{} + for _, s := range suffixes { + suffixSet[s.Value] = true + } + + var filtered []model.Rule + for _, s := range suffixes { + parts := strings.Split(s.Value, ".") + covered := false + for i := 1; i < len(parts); i++ { + parent := strings.Join(parts[i:], ".") + if suffixSet[parent] { + covered = true + result.DomainsMerged++ + result.Removals = append(result.Removals, + fmt.Sprintf("domain_suffix:%s (covered by suffix:%s)", s.Value, parent)) + break + } + } + if !covered { + filtered = append(filtered, s) + } + } + return filtered +} + +// aggregateCIDRs merges adjacent and contained CIDR blocks. +func aggregateCIDRs(cidrs []model.Rule, result *OptimizeResult) []model.Rule { + if len(cidrs) == 0 { + return nil + } + + // Separate IPv4 and IPv6 + var v4nets, v6nets []*net.IPNet + cidrSource := map[string]model.Rule{} // keep first rule for metadata + + for _, r := range cidrs { + _, ipnet, err := net.ParseCIDR(r.Value) + if err != nil { + continue + } + key := ipnet.String() + if _, exists := cidrSource[key]; !exists { + cidrSource[key] = r + } + if ipnet.IP.To4() != nil { + v4nets = append(v4nets, ipnet) + } else { + v6nets = append(v6nets, ipnet) + } + } + + // Remove contained CIDRs and merge adjacent + v4merged := mergeCIDRList(v4nets) + v6merged := mergeCIDRList(v6nets) + + mergedCount := (len(v4nets) + len(v6nets)) - (len(v4merged) + len(v6merged)) + result.CIDRsMerged = mergedCount + + var out []model.Rule + for _, n := range v4merged { + cidrStr := n.String() + if r, ok := cidrSource[cidrStr]; ok { + out = append(out, r) + } else { + out = append(out, model.Rule{ + Type: model.RuleIPCIDR, + Value: cidrStr, + Source: "optimized", + }) + } + } + for _, n := range v6merged { + cidrStr := n.String() + if r, ok := cidrSource[cidrStr]; ok { + out = append(out, r) + } else { + out = append(out, model.Rule{ + Type: model.RuleIPCIDR, + Value: 
cidrStr, + Source: "optimized", + }) + } + } + return out +} + +// mergeCIDRList removes contained CIDRs and merges adjacent ones. +func mergeCIDRList(nets []*net.IPNet) []*net.IPNet { + if len(nets) == 0 { + return nil + } + + // Sort by IP then prefix length + sort.Slice(nets, func(i, j int) bool { + cmp := compareIPs(nets[i].IP, nets[j].IP) + if cmp != 0 { + return cmp < 0 + } + iOnes, _ := nets[i].Mask.Size() + jOnes, _ := nets[j].Mask.Size() + return iOnes < jOnes // shorter prefix (larger range) first + }) + + // Remove contained + var deduped []*net.IPNet + for _, n := range nets { + contained := false + for _, existing := range deduped { + if existing.Contains(n.IP) { + onesE, _ := existing.Mask.Size() + onesN, _ := n.Mask.Size() + if onesE <= onesN { // existing has equal or larger range + contained = true + break + } + } + } + if !contained { + deduped = append(deduped, n) + } + } + + // Try merging adjacent pairs + changed := true + for changed { + changed = false + var merged []*net.IPNet + skip := map[int]bool{} + for i := 0; i < len(deduped); i++ { + if skip[i] { + continue + } + didMerge := false + for j := i + 1; j < len(deduped); j++ { + if skip[j] { + continue + } + if combined := tryCombine(deduped[i], deduped[j]); combined != nil { + merged = append(merged, combined) + skip[i] = true + skip[j] = true + changed = true + didMerge = true + break + } + } + if !didMerge && !skip[i] { + merged = append(merged, deduped[i]) + } + } + deduped = merged + } + + return deduped +} + +// tryCombine tries to merge two adjacent CIDRs into one. 
// e.g., 1.0.0.0/24 + 1.0.1.0/24 = 1.0.0.0/23
//
// It returns nil unless a and b have the same prefix length and address
// family and are the two distinct halves of one parent block. Networks with
// the same network number are never combined, because widening the prefix
// would cover addresses that neither input contains. nil inputs and masks
// that are not in canonical form also yield nil.
func tryCombine(a, b *net.IPNet) *net.IPNet {
	if a == nil || b == nil {
		return nil
	}
	onesA, bitsA := a.Mask.Size()
	onesB, bitsB := b.Mask.Size()
	if onesA != onesB || bitsA != bitsB {
		return nil
	}
	// A /0 has no parent; Size also reports 0 ones for a non-canonical mask.
	if onesA == 0 {
		return nil
	}

	// Compare network numbers rather than raw IPs so that host bits left set
	// in IPNet.IP cannot make one network look like two.
	baseA := a.IP.Mask(a.Mask)
	baseB := b.IP.Mask(b.Mask)
	if baseA == nil || baseB == nil || baseA.Equal(baseB) {
		return nil
	}

	// The parent prefix is one bit shorter
	parentMask := net.CIDRMask(onesA-1, bitsA)
	parentA := baseA.Mask(parentMask)
	parentB := baseB.Mask(parentMask)
	if parentA == nil || parentB == nil || !parentA.Equal(parentB) {
		return nil
	}

	return &net.IPNet{
		IP:   parentA,
		Mask: parentMask,
	}
}

// compareIPs orders two IP addresses numerically and returns -1, 0 or 1.
//
// Two addresses that both have a 4-byte form are compared as 32-bit
// big-endian integers; everything else is compared byte by byte in 16-byte
// form. An address that is neither 4 nor 16 bytes long has no 16-byte form;
// such a value sorts before any valid address and equal to another malformed
// one instead of causing an out-of-range panic.
func compareIPs(a, b net.IP) int {
	if a4, b4 := a.To4(), b.To4(); a4 != nil && b4 != nil {
		ai := binary.BigEndian.Uint32(a4)
		bi := binary.BigEndian.Uint32(b4)
		switch {
		case ai < bi:
			return -1
		case ai > bi:
			return 1
		}
		return 0
	}

	a16, b16 := a.To16(), b.To16()
	switch {
	case a16 == nil && b16 == nil:
		return 0
	case a16 == nil:
		return -1
	case b16 == nil:
		return 1
	}
	for i := range a16 {
		switch {
		case a16[i] < b16[i]:
			return -1
		case a16[i] > b16[i]:
			return 1
		}
	}
	return 0
}