diff --git a/rag/collection.go b/rag/collection.go
index cb4f651..87fa2e8 100644
--- a/rag/collection.go
+++ b/rag/collection.go
@@ -9,6 +9,7 @@ import (
 	"github.com/mudler/localrecall/pkg/xlog"
 	"github.com/mudler/localrecall/rag/engine"
 	"github.com/mudler/localrecall/rag/engine/localai"
+	"github.com/mudler/localrecall/rag/types"
 	"github.com/sashabaranov/go-openai"
 )
 
@@ -22,10 +23,17 @@ func NewPersistentChromeCollection(llmClient *openai.Client, collectionName, dbP
 		os.Exit(1)
 	}
 
+	// Create a hybrid search engine with the ChromemDB engine
+	hybridEngine, err := engine.NewHybridSearchEngine(chromemDB, types.NewBasicReranker(), dbPath)
+	if err != nil {
+		xlog.Error("Failed to create hybrid search engine", err)
+		os.Exit(1)
+	}
+
 	persistentKB, err := NewPersistentCollectionKB(
 		filepath.Join(dbPath, fmt.Sprintf("%s%s.json", collectionPrefix, collectionName)),
 		filePath,
-		chromemDB,
+		hybridEngine,
 		maxChunkSize)
 	if err != nil {
 		xlog.Error("Failed to create PersistentKB", err)
@@ -40,10 +48,17 @@ func NewPersistentLocalAICollection(llmClient *openai.Client, apiURL, apiKey, co
 	laiStore := localai.NewStoreClient(apiURL, apiKey)
 	ragDB := engine.NewLocalAIRAGDB(laiStore, llmClient, embeddingModel)
 
+	// Create a hybrid search engine with the LocalAI engine
+	hybridEngine, err := engine.NewHybridSearchEngine(ragDB, types.NewBasicReranker(), dbPath)
+	if err != nil {
+		xlog.Error("Failed to create hybrid search engine", err)
+		os.Exit(1)
+	}
+
 	persistentKB, err := NewPersistentCollectionKB(
 		filepath.Join(dbPath, fmt.Sprintf("%s%s.json", collectionPrefix, collectionName)),
 		filePath,
-		ragDB,
+		hybridEngine,
 		maxChunkSize)
 	if err != nil {
 		xlog.Error("Failed to create PersistentKB", err)
@@ -59,18 +74,16 @@ func NewPersistentLocalAICollection(llmClient *openai.Client, apiURL, apiKey, co
 
 // ListAllCollections lists all collections in the database
 func ListAllCollections(dbPath string) []string {
+	collections := []string{}
 	files, err := os.ReadDir(dbPath)
 	if err != nil {
-		xlog.Error("Failed to read directory", err)
-		return nil
+		return collections
 	}
 
-	var collections []string
-
-	for _, file := range files {
-		if !file.IsDir() && filepath.Ext(file.Name()) == ".json" && strings.HasPrefix(file.Name(), collectionPrefix) {
-			collectionName := strings.TrimPrefix(file.Name(), collectionPrefix)
-			collectionName = strings.TrimSuffix(collectionName, ".json")
-			collections = append(collections, collectionName)
+	// Keep filtering out directories and non-JSON files so stray artifacts
+	// in dbPath are not reported as collections.
+	for _, f := range files {
+		if !f.IsDir() && filepath.Ext(f.Name()) == ".json" && strings.HasPrefix(f.Name(), collectionPrefix) {
+			collections = append(collections, strings.TrimPrefix(strings.TrimSuffix(f.Name(), ".json"), collectionPrefix))
 		}
 	}
diff --git a/rag/engine.go b/rag/engine.go
index d88df32..a95d176 100644
--- a/rag/engine.go
+++ b/rag/engine.go
@@ -1,12 +1,12 @@
 package rag
 
 import (
+	"github.com/mudler/localrecall/rag/interfaces"
 	"github.com/mudler/localrecall/rag/types"
 )
 
-type Engine interface {
-	Store(s string, meta map[string]string) error
-	Reset() error
-	Search(s string, similarEntries int) ([]types.Result, error)
-	Count() int
-}
+// Engine is an alias for interfaces.Engine
+type Engine = interfaces.Engine
+
+// Result is an alias for types.Result
+type Result = types.Result
diff --git a/rag/engine/fulltext.go b/rag/engine/fulltext.go
new file mode 100644
index 0000000..6b2a156
--- /dev/null
+++ b/rag/engine/fulltext.go
@@ -0,0 +1,138 @@
+package engine
+
+import (
+	"encoding/json"
+	"fmt"
+	"os"
+	"path/filepath"
+	"sort"
+	"strings"
+	"sync"
+
+	"github.com/mudler/localrecall/rag/types"
+)
+
+// FullTextIndex manages the full-text search index
+type FullTextIndex struct {
+	path      string
+	documents map[string]string
+	mu        sync.RWMutex
+}
+
+// NewFullTextIndex creates a new full-text index
+func NewFullTextIndex(path string) (*FullTextIndex, error) {
+	index := &FullTextIndex{
+		path:      path,
+		documents: make(map[string]string),
+	}
+
+	// Load existing index if it exists
+	if err := index.load(); err != nil {
+		return nil, fmt.Errorf("failed to load full-text index: %w", err)
+	}
+
+	return index, nil
+}
+
+// Store adds a document to the index
+func (i *FullTextIndex) Store(id string, content string) error {
+	i.mu.Lock()
+	defer i.mu.Unlock()
+
+	i.documents[id] = content
+	return i.save()
+}
+
+// Delete removes a document from the index
+func (i *FullTextIndex) Delete(id string) error {
+	i.mu.Lock()
+	defer i.mu.Unlock()
+
+	delete(i.documents, id)
+	return i.save()
+}
+
+// Reset clears the index
+func (i *FullTextIndex) Reset() error {
+	i.mu.Lock()
+	defer i.mu.Unlock()
+
+	i.documents = make(map[string]string)
+	return i.save()
+}
+
+// Search performs full-text search on the index
+func (i *FullTextIndex) Search(query string, maxResults int) []types.Result {
+	i.mu.RLock()
+	defer i.mu.RUnlock()
+
+	queryTerms := strings.Fields(strings.ToLower(query))
+	scoredResults := make([]types.Result, 0)
+
+	// Score all documents
+	for id, content := range i.documents {
+		contentLower := strings.ToLower(content)
+		score := float32(0)
+
+		// Simple term frequency scoring
+		for _, term := range queryTerms {
+			if strings.Contains(contentLower, term) {
+				score += 1.0
+			}
+		}
+
+		// Normalize score to [0, 1] by the number of query terms
+		if len(queryTerms) > 0 {
+			score = score / float32(len(queryTerms))
+		}
+
+		// Only include documents with a score > 0
+		if score > 0 {
+			scoredResults = append(scoredResults, types.Result{
+				ID:            id,
+				Content:       content,
+				FullTextScore: score,
+			})
+		}
+	}
+
+	// Sort by full-text score, descending (O(n log n) instead of the
+	// previous O(n^2) selection sort)
+	sort.Slice(scoredResults, func(a, b int) bool {
+		return scoredResults[a].FullTextScore > scoredResults[b].FullTextScore
+	})
+
+	// Return top maxResults results
+	if len(scoredResults) > maxResults {
+		scoredResults = scoredResults[:maxResults]
+	}
+
+	return scoredResults
+}
+
+// load reads the index from disk
+func (i *FullTextIndex) load() error {
+	data, err := os.ReadFile(i.path)
+	if err != nil {
+		if os.IsNotExist(err) {
+			return nil // File doesn't exist yet, that's okay
+		}
+		return err
+	}
+
+	return json.Unmarshal(data, &i.documents)
+}
+
+// save writes the index to disk
+func (i *FullTextIndex) save() error {
+	data, err := json.Marshal(i.documents)
+	if err != nil {
+		return err
+	}
+
+	// Ensure directory exists
+	if err := os.MkdirAll(filepath.Dir(i.path), 0755); err != nil {
+		return err
+	}
+
+	return os.WriteFile(i.path, data, 0644)
+}
diff --git a/rag/engine/hybrid.go b/rag/engine/hybrid.go
new file mode 100644
index 0000000..ed5b15a
--- /dev/null
+++ b/rag/engine/hybrid.go
@@ -0,0 +1,120 @@
+package engine
+
+import (
+	"fmt"
+	"path/filepath"
+	"sort"
+
+	"github.com/mudler/localrecall/rag/interfaces"
+	"github.com/mudler/localrecall/rag/types"
+)
+
+// HybridSearchEngine combines semantic and full-text search
+type HybridSearchEngine struct {
+	semanticEngine interfaces.Engine
+	reranker       types.Reranker
+	fullTextIndex  *FullTextIndex
+}
+
+// NewHybridSearchEngine creates a new hybrid search engine
+func NewHybridSearchEngine(semanticEngine interfaces.Engine, reranker types.Reranker, dbPath string) (*HybridSearchEngine, error) {
+	// Create full-text index in the same directory as the semantic engine
+	fullTextIndex, err := NewFullTextIndex(filepath.Join(dbPath, "fulltext.json"))
+	if err != nil {
+		return nil, fmt.Errorf("failed to create full-text index: %w", err)
+	}
+
+	return &HybridSearchEngine{
+		semanticEngine: semanticEngine,
+		reranker:       reranker,
+		fullTextIndex:  fullTextIndex,
+	}, nil
+}
+
+// Store stores a document in both semantic and full-text indexes
+func (h *HybridSearchEngine) Store(s string, metadata map[string]string) error {
+	// Store in semantic engine
+	if err := h.semanticEngine.Store(s, metadata); err != nil {
+		return err
+	}
+
+	// Store in full-text index
+	// Use the content as the ID since we don't have a better identifier
+	return h.fullTextIndex.Store(s, s)
+}
+
+// Reset resets both semantic and full-text indexes
+func (h *HybridSearchEngine) Reset() error {
+	if err := h.semanticEngine.Reset(); err != nil {
+		return err
+	}
+	return h.fullTextIndex.Reset()
+}
+
+// Count returns the number of documents in the index
+func (h *HybridSearchEngine) Count() int {
+	return h.semanticEngine.Count()
+}
+
+// Search performs hybrid search by combining semantic and full-text search results
+func (h *HybridSearchEngine) Search(query string, similarEntries int) ([]types.Result, error) {
+	// Perform semantic search
+	semanticResults, err := h.semanticEngine.Search(query, similarEntries)
+	if err != nil {
+		return nil, fmt.Errorf("semantic search failed: %w", err)
+	}
+
+	// Perform full-text search on all documents
+	fullTextResults := h.fullTextIndex.Search(query, similarEntries)
+
+	// Combine results from both searches
+	combinedResults := h.combineResults(semanticResults, fullTextResults)
+
+	// Rerank the combined results
+	rerankedResults, err := h.reranker.Rerank(query, combinedResults)
+	if err != nil {
+		return nil, fmt.Errorf("reranking failed: %w", err)
+	}
+
+	return rerankedResults, nil
+}
+
+// combineResults combines semantic and full-text search results
+func (h *HybridSearchEngine) combineResults(semanticResults, fullTextResults []types.Result) []types.Result {
+	// Create a map to track unique results by content
+	resultMap := make(map[string]types.Result)
+
+	// Add semantic results. Seed CombinedScore with the semantic similarity so
+	// a document found only by semantic search is not left with a zero score
+	// and ranked below every full-text hit.
+	for _, result := range semanticResults {
+		result.CombinedScore = result.Similarity
+		resultMap[result.Content] = result
+	}
+
+	// Add full-text results, combining scores if the same content exists
+	for _, result := range fullTextResults {
+		if existing, exists := resultMap[result.Content]; exists {
+			// If the content exists in both results, combine the scores
+			existing.FullTextScore = result.FullTextScore
+			existing.CombinedScore = (existing.Similarity + result.FullTextScore) / 2
+			resultMap[result.Content] = existing
+		} else {
+			// If it's a new result, just add it
+			result.CombinedScore = result.FullTextScore
+			resultMap[result.Content] = result
+		}
+	}
+
+	// Convert map back to slice
+	combinedResults := make([]types.Result, 0, len(resultMap))
+	for _, result := range resultMap {
+		combinedResults = append(combinedResults, result)
+	}
+
+	// Sort by combined score, descending (O(n log n) instead of the
+	// previous O(n^2) selection sort)
+	sort.Slice(combinedResults, func(a, b int) bool {
+		return combinedResults[a].CombinedScore > combinedResults[b].CombinedScore
+	})
+
+	return combinedResults
+}
diff --git a/rag/interfaces/engine.go b/rag/interfaces/engine.go
new file mode 100644
index 0000000..a825af1
--- /dev/null
+++ b/rag/interfaces/engine.go
@@ -0,0 +1,11 @@
+package interfaces
+
+import "github.com/mudler/localrecall/rag/types"
+
+// Engine defines the interface for search engines
+type Engine interface {
+	Store(s string, meta map[string]string) error
+	Reset() error
+	Search(s string, similarEntries int) ([]types.Result, error)
+	Count() int
+}
diff --git a/rag/types/reranker.go b/rag/types/reranker.go
new file mode 100644
index 0000000..7bb3c1c
--- /dev/null
+++ b/rag/types/reranker.go
@@ -0,0 +1,22 @@
+package types
+
+// Reranker defines the interface for reranking search results
+type Reranker interface {
+	// Rerank takes a query and a list of results, and returns a reranked list
+	Rerank(query string, results []Result) ([]Result, error)
+}
+
+// BasicReranker implements a simple reranking strategy that combines semantic and full-text scores
+type BasicReranker struct{}
+
+// NewBasicReranker creates a new BasicReranker instance
+func NewBasicReranker() *BasicReranker {
+	return &BasicReranker{}
+}
+
+// Rerank implements a simple reranking strategy that combines semantic and full-text scores
+func (r *BasicReranker) Rerank(query string, results []Result) ([]Result, error) {
+	// For now, we'll just return the results as is
+	// In a real implementation, we would combine semantic and full-text scores
+	return results, nil
+}
diff --git a/rag/types/result.go b/rag/types/result.go
index 710187e..c1acbc6 100644
--- a/rag/types/result.go
+++ b/rag/types/result.go
@@ -11,4 +11,12 @@ type Result struct {
 	// The higher the value, the more similar the document is to the query.
 	// The value is in the range [-1, 1].
 	Similarity float32
+
+	// FullTextScore is the full-text relevance score: the fraction of the
+	// query terms that appear in the document, in the range [0, 1].
+	FullTextScore float32
+
+	// CombinedScore is the score used for final ranking; for a document found
+	// by both searches it is the mean of Similarity and FullTextScore.
+	CombinedScore float32
 }