diff --git a/cache.go b/cache.go index e27cd28..085b04c 100644 --- a/cache.go +++ b/cache.go @@ -1,42 +1,81 @@ package htmlquery import ( + "errors" "sync" "github.com/antchfx/xpath" "github.com/golang/groupcache/lru" ) -// DisableSelectorCache will disable caching for the query selector if value is true. -var DisableSelectorCache = false +type XpathQueryLookup interface { + GetQuery(expr string) (*xpath.Expr, error) +} + +var ( + // DisableSelectorCache will disable caching for the query selector if value is true. + DisableSelectorCache = false -// SelectorCacheMaxEntries allows how many selector object can be caching. Default is 50. -// Will disable caching if SelectorCacheMaxEntries <= 0. -var SelectorCacheMaxEntries = 50 + // SelectorCacheMaxEntries is the maximum number of selector objects that can be cached. Default is 50. + // Will disable caching if SelectorCacheMaxEntries <= 0. + SelectorCacheMaxEntries = 50 +) var ( - cacheOnce sync.Once + xpcache XpathQueryLookup +) + +// NewXpathQueryLookup returns an XpathQueryLookup whose cache is created with lru.New(max). +// Will disable caching if max <= 0. 
+func NewXpathQueryLookup(max int) XpathQueryLookup { + if max <= 0 { + return &nocacheXpathQueryLookup{} + } + return &lruXpathQueryLookup{ + cache: lru.New(max), + } +} + +type lruXpathQueryLookup struct { cache *lru.Cache cacheMutex sync.Mutex -) +} -func getQuery(expr string) (*xpath.Expr, error) { - if DisableSelectorCache || SelectorCacheMaxEntries <= 0 { +func (lxpl *lruXpathQueryLookup) GetQuery(expr string) (*xpath.Expr, error) { + if lxpl.cache == nil || DisableSelectorCache { return xpath.Compile(expr) } - cacheOnce.Do(func() { - cache = lru.New(SelectorCacheMaxEntries) - }) - cacheMutex.Lock() - defer cacheMutex.Unlock() - if v, ok := cache.Get(expr); ok { - return v.(*xpath.Expr), nil + + lxpl.cacheMutex.Lock() + defer lxpl.cacheMutex.Unlock() + if v, ok := lxpl.cache.Get(expr); ok { + e, ok := v.(*xpath.Expr) + if !ok { + return nil, errors.New("type assertion failed") + } + return e, nil } + v, err := xpath.Compile(expr) if err != nil { return nil, err } - cache.Add(expr, v) + lxpl.cache.Add(expr, v) return v, nil } + +type nocacheXpathQueryLookup struct{} + +func (*nocacheXpathQueryLookup) GetQuery(expr string) (*xpath.Expr, error) { + return xpath.Compile(expr) +} + +// SetCache installs x as the lookup that Query and QueryAll call to compile expressions; x must be non-nil. +func SetCache(x XpathQueryLookup) { + xpcache = x +} + +func init() { + SetCache(NewXpathQueryLookup(SelectorCacheMaxEntries)) +} diff --git a/query.go b/query.go index c1c6457..f7f32d8 100644 --- a/query.go +++ b/query.go @@ -49,7 +49,7 @@ func FindOne(top *html.Node, expr string) *html.Node { // QueryAll searches the html.Node that matches by the specified XPath expr. // Return an error if the expression `expr` cannot be parsed. func QueryAll(top *html.Node, expr string) ([]*html.Node, error) { - exp, err := getQuery(expr) + exp, err := xpcache.GetQuery(expr) if err != nil { return nil, err } @@ -62,7 +62,7 @@ func QueryAll(top *html.Node, expr string) ([]*html.Node, error) { // // Returns an error if the expression `expr` cannot be parsed. 
func Query(top *html.Node, expr string) (*html.Node, error) { - exp, err := getQuery(expr) + exp, err := xpcache.GetQuery(expr) if err != nil { return nil, err } diff --git a/query_test.go b/query_test.go index 959c62c..b9e35e5 100644 --- a/query_test.go +++ b/query_test.go @@ -49,23 +49,29 @@ var testDoc = loadHTML(htmlSample) func BenchmarkSelectorCache(b *testing.B) { DisableSelectorCache = false for i := 0; i < b.N; i++ { - getQuery("/AAA/BBB/DDD/CCC/EEE/ancestor::*") + xpcache.GetQuery("/AAA/BBB/DDD/CCC/EEE/ancestor::*") } } func BenchmarkDisableSelectorCache(b *testing.B) { DisableSelectorCache = true for i := 0; i < b.N; i++ { - getQuery("/AAA/BBB/DDD/CCC/EEE/ancestor::*") + xpcache.GetQuery("/AAA/BBB/DDD/CCC/EEE/ancestor::*") } } +// TestSelectorCache requests three distinct selectors through a lookup built with max 2, then repeats the last one. func TestSelectorCache(t *testing.T) { + xpcache = NewXpathQueryLookup(2) SelectorCacheMaxEntries = 2 for i := 1; i <= 3; i++ { - getQuery(fmt.Sprintf("//a[position()=%d]", i)) + if _, err := xpcache.GetQuery(fmt.Sprintf("//a[position()=%d]", i)); err != nil { + t.Fatal(err) + } } - getQuery("//a[position()=3]") + if _, err := xpcache.GetQuery("//a[position()=3]"); err != nil { + t.Fatal(err) + } }