diff --git a/common/domain/matcher.go b/common/domain/matcher.go index 1f328b8c..52c32a35 100644 --- a/common/domain/matcher.go +++ b/common/domain/matcher.go @@ -12,7 +12,7 @@ type Matcher struct { set *succinctSet } -func NewMatcher(domains []string, domainSuffix []string) *Matcher { +func NewMatcher(domains []string, domainSuffix []string, generateLegacy bool) *Matcher { domainList := make([]string, 0, len(domains)+2*len(domainSuffix)) seen := make(map[string]bool, len(domainList)) for _, domain := range domainSuffix { @@ -22,9 +22,15 @@ func NewMatcher(domains []string, domainSuffix []string) *Matcher { seen[domain] = true if domain[0] == '.' { domainList = append(domainList, reverseDomainSuffix(domain)) - } else { + } else if generateLegacy { domainList = append(domainList, reverseDomain(domain)) - domainList = append(domainList, reverseRootDomainSuffix(domain)) + suffixDomain := "." + domain + if !seen[suffixDomain] { + seen[suffixDomain] = true + domainList = append(domainList, reverseDomainSuffix(suffixDomain)) + } + } else { + domainList = append(domainList, reverseDomainRoot(domain)) } } for _, domain := range domains { @@ -79,6 +85,8 @@ func (m *Matcher) Dump() (domainList []string, prefixList []string) { key = reverseDomain(key) if key[0] == prefixLabel { prefixMap[key[1:]] = true + } else if key[0] == rootLabel { + prefixList = append(prefixList, key[1:]) } else { domainMap[key] = true } @@ -124,15 +132,14 @@ func reverseDomainSuffix(domain string) string { return string(b) } -func reverseRootDomainSuffix(domain string) string { +func reverseDomainRoot(domain string) string { l := len(domain) - b := make([]byte, l+2) + b := make([]byte, l+1) for i := 0; i < l; { r, n := utf8.DecodeRuneInString(domain[i:]) i += n utf8.EncodeRune(b[l-i:], r) } - b[l] = '.' - b[l+1] = prefixLabel + b[l] = rootLabel return string(b) } diff --git a/common/domain/matcher_test.go b/common/domain/matcher_test.go index f7b4cc17..5d12e91a 100644 --- a/common/domain/matcher_test.go +++ b/common/domain/matcher_test.go @@ -14,7 +14,26 @@ import ( func TestMatcher(t *testing.T) { testDomain := []string{"example.com", "example.org"} testDomainSuffix := []string{".com.cn", ".org.cn", "sagernet.org"} - matcher := domain.NewMatcher(testDomain, testDomainSuffix) + matcher := domain.NewMatcher(testDomain, testDomainSuffix, false) + require.NotNil(t, matcher) + require.True(t, matcher.Match("example.com")) + require.True(t, matcher.Match("example.org")) + require.False(t, matcher.Match("example.cn")) + require.True(t, matcher.Match("example.com.cn")) + require.True(t, matcher.Match("example.org.cn")) + require.False(t, matcher.Match("com.cn")) + require.False(t, matcher.Match("org.cn")) + require.True(t, matcher.Match("sagernet.org")) + require.True(t, matcher.Match("sing-box.sagernet.org")) + dDomain, dDomainSuffix := matcher.Dump() + require.Equal(t, testDomain, dDomain) + require.Equal(t, testDomainSuffix, dDomainSuffix) +} + +func TestMatcherLegacy(t *testing.T) { + testDomain := []string{"example.com", "example.org"} + testDomainSuffix := []string{".com.cn", ".org.cn", "sagernet.org"} + matcher := domain.NewMatcher(testDomain, testDomainSuffix, true) require.NotNil(t, matcher) require.True(t, matcher.Match("example.com")) require.True(t, matcher.Match("example.org")) @@ -50,7 +69,7 @@ func TestDumpLarge(t *testing.T) { require.True(t, len(domainList)+len(domainSuffixList) > 0) sort.Strings(domainList) sort.Strings(domainSuffixList) - matcher := domain.NewMatcher(domainList, domainSuffixList) + matcher := domain.NewMatcher(domainList, domainSuffixList, false) require.NotNil(t, matcher) dDomain, dDomainSuffix := matcher.Dump() require.Equal(t, domainList, dDomain) diff --git a/common/domain/set.go b/common/domain/set.go index ae1e00b1..2072e1db 100644 --- a/common/domain/set.go +++ b/common/domain/set.go @@ -4,7 +4,10 @@ import ( "math/bits" ) -const prefixLabel = '\r' +const ( + prefixLabel = '\r' + rootLabel = '\n' +) // mod from https://github.com/openacid/succinct @@ -54,6 +57,13 @@ func (ss *succinctSet) Has(key string) bool { if nextLabel == prefixLabel { return true } + if nextLabel == rootLabel { + nextNodeId := countZeros(ss.labelBitmap, ss.ranks, bmIdx+1) + hasNext := getBit(ss.leaves, nextNodeId) != 0 + if currentChar == '.' && hasNext { + return true + } + } if nextLabel == currentChar { break } @@ -68,7 +78,8 @@ func (ss *succinctSet) Has(key string) bool { if getBit(ss.labelBitmap, bmIdx) != 0 { return false } - if ss.labels[bmIdx-nodeId] == prefixLabel { + nextLabel := ss.labels[bmIdx-nodeId] + if nextLabel == prefixLabel || nextLabel == rootLabel { return true } }