Skip to content

Commit d6a80ad

Browse files
committed
🐛 addressing greptile feedback
1 parent 51b1f52 commit d6a80ad

File tree

1 file changed

+3
-5
lines changed
  • nemoguardrails/library/embedding_topic_detector

1 file changed

+3
-5
lines changed

nemoguardrails/library/embedding_topic_detector/actions.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,8 @@ async def detect(self, query: str) -> Dict:
5151
[
5252
(
5353
cat,
54-
float(
55-
np.dot(query_emb, emb)
56-
/ (np.linalg.norm(query_emb) * np.linalg.norm(emb) or 1)
57-
),
54+
np.dot(query_emb, emb)
55+
/ ((np.linalg.norm(query_emb) * np.linalg.norm(emb)) or 1e-10),
5856
)
5957
for cat, embs in self.embeddings.items()
6058
for emb in embs
@@ -79,7 +77,7 @@ async def detect(self, query: str) -> Dict:
7977

8078
async def _check(context: Optional[dict], llm_task_manager, message_key: str) -> dict:
8179
config = llm_task_manager.config.rails.config.embedding_topic_detector
82-
cache_key = f"{config['embedding_model']}_{config['embedding_engine']}_{config.get('threshold', 0.75)}"
80+
cache_key = f"{config['embedding_model']}_{config['embedding_engine']}_{config.get('threshold', 0.75)}_{config.get('top_k', 3)}"
8381

8482
if cache_key not in _detector_cache:
8583
_detector_cache[cache_key] = EmbeddingTopicDetector(

0 commit comments

Comments
 (0)