
Commit

Merge pull request #107 from StampyAI/more-tests
More tests
mruwnik authored Sep 30, 2023
2 parents 1c91278 + 4bd60b2 commit d091bdc
Showing 6 changed files with 445 additions and 71 deletions.
30 changes: 15 additions & 15 deletions api/src/stampy_chat/followups.py
@@ -19,25 +19,27 @@ class Followup:
# https://nlp.stampy.ai/api/search?query=what%20is%20agi

def search_authored(query: str):
-    multisearch_authored([query])
+    return multisearch_authored([query])

-# search with multiple queries, combine results
-def multisearch_authored(queries: List[str]):
-
-    followups = {}
-
-    for query in queries:
+def get_followups(query):
+    url = 'https://nlp.stampy.ai/api/search?query=' + quote(query)
+    response = requests.get(url).json()
+    return [Followup(entry['title'], entry['pageid'], entry['score']) for entry in response]

-        url = 'https://nlp.stampy.ai/api/search?query=' + quote(query)
-        response = requests.get(url).json()
-        for entry in response:
-            followups[entry['pageid']] = Followup(entry['title'], entry['pageid'], entry['score'])

-    followups = list(followups.values())
+# search with multiple queries, combine results
+def multisearch_authored(queries: List[str]):
+    # sort the followups from lowest to highest score
+    followups = [entry for query in queries for entry in get_followups(query)]
+    followups = sorted(followups, key=lambda entry: entry.score)

-    followups.sort(key=lambda f: f.score, reverse=True)
+    # Remove any duplicates by making a map from the pageids. This should result in highest scored entry being used
+    followups = {entry.pageid: entry for entry in followups if entry.score > SIMILARITY_THRESHOLD}

-    followups = followups[:MAX_FOLLOWUPS]
+    # Get the first `MAX_FOLLOWUPS`
+    followups = sorted(followups.values(), reverse=True, key=lambda e: e.score)
+    followups = list(followups)[:MAX_FOLLOWUPS]

    if logger.is_debug():
        logger.debug(" ------------------------------ suggested followups: -----------------------------")
@@ -50,6 +52,4 @@ def multisearch_authored(queries: List[str]):
        logger.debug(followup.pageid)
        logger.debug('')

-    followups = [ f for f in followups if f.score > SIMILARITY_THRESHOLD ]
-
    return followups
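
Note: the rewritten multisearch_authored leans on a small dict idiom: sort ascending by score, then build a dict keyed by pageid, so the highest-scoring duplicate is written last and wins. A minimal standalone sketch of that idiom; the Followup fields match the constructor calls above, while the 0.4 threshold is only an assumed stand-in for the module's real SIMILARITY_THRESHOLD (its value is not shown in this diff):

from dataclasses import dataclass

@dataclass
class Followup:
    text: str
    pageid: str
    score: float

SIMILARITY_THRESHOLD = 0.4  # assumed value, for illustration only

hits = [
    Followup('result 1', '1', 0.423),
    Followup('result 1, better match', '1', 0.723),
    Followup('too weak to keep', '3', 0.1),
]

# Ascending sort means that for duplicate pageids the highest-scoring
# entry is inserted last, so it is the one the dict keeps.
deduped = {
    f.pageid: f
    for f in sorted(hits, key=lambda f: f.score)
    if f.score > SIMILARITY_THRESHOLD
}

assert deduped['1'].score == 0.723  # the weaker duplicate was overwritten
assert '3' not in deduped           # below the threshold, filtered out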
105 changes: 50 additions & 55 deletions api/src/stampy_chat/get_blocks.py
@@ -6,7 +6,8 @@
import regex as re
import requests
import time
-from typing import List, Tuple
+from itertools import groupby
+from typing import Iterable, List, Tuple
from stampy_chat.env import PINECONE_NAMESPACE, REMOTE_CHAT_INSTANCE, EMBEDDING_MODEL
from stampy_chat import logging

@@ -48,6 +49,51 @@ def get_embedding(text: str) -> np.ndarray:
time.sleep(min(max_wait_time, 2 ** attempt))


+def parse_block(match) -> Block:
+    metadata = match['metadata']
+
+    date = metadata.get('date_published') or metadata.get('date')
+
+    if isinstance(date, datetime.date):
+        date = date.isoformat()
+    elif isinstance(date, datetime.datetime):
+        date = date.date().isoformat()
+    elif isinstance(date, (int, float)):
+        date = datetime.datetime.fromtimestamp(date).isoformat()
+
+    authors = metadata.get('authors')
+    if not authors and metadata.get('author'):
+        authors = [metadata.get('author')]
+
+    return Block(
+        id = metadata.get('hash_id') or metadata.get('id'),
+        title = metadata['title'],
+        authors = authors,
+        date = date,
+        url = metadata['url'],
+        tags = metadata.get('tags'),
+        text = strip_block(metadata['text'])
+    )
+
+
+def join_blocks(blocks: Iterable[Block]) -> List[Block]:
+    # for all blocks that are "the same" (same title, author, date, url, tags),
+    # combine their text with "....." in between. Return them in order such
+    # that the combined block has the minimum index of the blocks combined.
+
+    def to_tuple(block):
+        return (block.id, block.title or "", block.authors or [], block.date or "", block.url or "", block.tags or "")
+
+    def merge_texts(blocks):
+        return "\n.....\n".join(sorted(block.text for block in blocks))
+
+    unified_blocks = [
+        Block(*key, merge_texts(group))
+        for key, group in groupby(blocks, key=to_tuple)
+    ]
+    return sorted(unified_blocks, key=to_tuple)


# Get the k blocks most semantically similar to the query using Pinecone.
def get_top_k_blocks(index, user_query: str, k: int) -> List[Block]:

@@ -69,7 +115,7 @@ def get_top_k_blocks(index, user_query: str, k: int) -> List[Block]:
}
)

-    return [Block(**block) for block in response.json()]
+    return [parse_block({'metadata': block}) for block in response.json()]

# print time
t = time.time()
@@ -87,63 +133,12 @@ def get_top_k_blocks(index, user_query: str, k: int) -> List[Block]:
include_metadata=True,
vector=query_embedding
)
-    blocks = []
-    for match in query_response['matches']:
-        metadata = match['metadata']
-
-        date = metadata.get('date_published') or metadata.get('date')
-
-        if isinstance(date, datetime.date):
-            date = date.isoformat()
-        elif isinstance(date, datetime.datetime):
-            date = date.date().isoformat()
-        elif isinstance(date, float):
-            date = datetime.datetime.fromtimestamp(date).date().isoformat()
-
-        authors = metadata.get('authors')
-        if not authors and metadata.get('author'):
-            authors = [metadata.get('author')]
-
-        blocks.append(Block(
-            id = metadata.get('hash_id'),
-            title = metadata['title'],
-            authors = authors,
-            date = date,
-            url = metadata['url'],
-            tags = metadata.get('tags'),
-            text = strip_block(metadata['text'])
-        ))
-
+    blocks = [parse_block(match) for match in query_response['matches']]
    t2 = time.time()

    logger.debug(f'Time to get top-k blocks: {t2-t1:.2f}s')

-    # for all blocks that are "the same" (same title, author, date, url, tags),
-    # combine their text with "....." in between. Return them in order such
-    # that the combined block has the minimum index of the blocks combined.
-
-    key = lambda bi: (bi[0].id, bi[0].title or "", bi[0].authors or [], bi[0].date or "", bi[0].url or "", bi[0].tags or "")
-
-    blocks_plus_old_index = [(block, i) for i, block in enumerate(blocks)]
-    blocks_plus_old_index.sort(key=key)
-
-    unified_blocks: List[Tuple[Block, int]] = []
-
-    for key, group in itertools.groupby(blocks_plus_old_index, key=key):
-        group = list(group)
-        if not group:
-            continue
-
-        # group = group[:3] # limit to a max of 3 blocks from any one source
-
-        text = "\n.....\n".join([block[0].text for block in group])
-
-        min_index = min([block[1] for block in group])
-
-        unified_blocks.append((Block(*key, text), min_index))
-
-    unified_blocks.sort(key=lambda bi: bi[1])
-    return [block for block, _ in unified_blocks]
+    return join_blocks(blocks)


# we add the title and authors inside the contents of the block, so that
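Note: unlike the loop it replaces, which sorted the blocks before grouping, join_blocks groups the incoming blocks as they arrive, and itertools.groupby only merges runs of adjacent equal keys, so duplicate chunks are assumed to sit next to each other. A minimal sketch of the merge behaviour, using a namedtuple stand-in for the project's Block type with the (id, title, authors, date, url, tags, text) field order implied by the Block(*key, ...) construction above:

from collections import namedtuple
from itertools import groupby
from typing import Iterable, List

# Assumed stand-in for the project's Block type.
Block = namedtuple('Block', ['id', 'title', 'authors', 'date', 'url', 'tags', 'text'])

def to_tuple(block):
    # A block's identity: every field except its text.
    return (block.id, block.title or "", block.authors or [],
            block.date or "", block.url or "", block.tags or "")

def join_blocks(blocks: Iterable[Block]) -> List[Block]:
    # Merge adjacent blocks with the same identity, joining their texts.
    return sorted(
        (Block(*key, "\n.....\n".join(sorted(b.text for b in group)))
         for key, group in groupby(blocks, key=to_tuple)),
        key=to_tuple,
    )

chunks = [
    Block('a', 'Post A', ['ann'], '2023-01-01', 'http://x/a', [], 'first chunk'),
    Block('a', 'Post A', ['ann'], '2023-01-01', 'http://x/a', [], 'second chunk'),
    Block('b', 'Post B', ['bob'], '2023-02-01', 'http://x/b', [], 'unrelated text'),
]

merged = join_blocks(chunks)
assert len(merged) == 2
assert merged[0].text == 'first chunk\n.....\nsecond chunk'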
2 changes: 1 addition & 1 deletion api/src/stampy_chat/logging.py
@@ -41,7 +41,7 @@ def __init__(self, *args, **kwargs):
    def is_debug(self):
        return self.isEnabledFor(DEBUG)

-    def interaction(self, session_id, query, response, history, prompt, blocks):
+    def interaction(self, session_id: str, query: str, response: str, history, prompt, blocks):
        prompt = [i for i in prompt if i.get('role') == 'system']
        prompt = prompt[0].get('content') if prompt else None
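Note: the first two lines of interaction pull the system message out of a chat-style prompt list. A standalone sketch of that extraction, assuming OpenAI-style role/content dicts:

prompt = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': 'what is agi?'},
]

# Keep only the system messages, then take the first one's content (or None).
system = [m for m in prompt if m.get('role') == 'system']
system_prompt = system[0].get('content') if system else None

assert system_prompt == 'You are a helpful assistant.'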
67 changes: 67 additions & 0 deletions api/tests/stampy_chat/test_followups.py
@@ -0,0 +1,67 @@
import pytest
from unittest.mock import patch, Mock

from stampy_chat.followups import Followup, search_authored, multisearch_authored


@pytest.mark.parametrize("query, expected_result", [
    ("what is agi", [Followup("agi title", "agi", 0.5)],),
    ("what is ai", [Followup("ai title", "ai", 0.5)],)])
def test_search_authored(query, expected_result):
    response = Mock(json=lambda: [
        {'title': r.text, 'pageid': r.pageid, 'score': r.score}
        for r in expected_result
    ])

    with patch('requests.get', return_value=response):
        assert search_authored(query) == expected_result


@patch('stampy_chat.followups.logger')
def test_multisearch_authored(_logger):
    results = [
        {'pageid': '1', 'title': f'result 1', 'score': 0.423},
        {'pageid': '2', 'title': f'result 2', 'score': 0.623},
        {'pageid': '3', 'title': f'this should be skipped', 'score': 0.323},
        {'pageid': '4', 'title': f'this should also be skipped', 'score': 0.1},
        {'pageid': '5', 'title': f'result 5', 'score': 0.543},
    ]

    response = Mock(json=lambda: results)
    with patch('requests.get', return_value=response):
        assert multisearch_authored(["what is this?", "how about this?"]) == [
            Followup('result 2', '2', 0.623),
            Followup('result 5', '5', 0.543),
            Followup('result 1', '1', 0.423),
        ]


@patch('stampy_chat.followups.logger')
def test_multisearch_authored_duplicates(_logger):
    results = {
        'query1': [
            {'pageid': '1', 'title': f'result 1', 'score': 0.423},
            {'pageid': '2', 'title': f'result 2', 'score': 0.623},
            {'pageid': '3', 'title': f'this should be skipped', 'score': 0.323},
            {'pageid': '4', 'title': f'this should also be skipped', 'score': 0.1},
            {'pageid': '5', 'title': f'result 5', 'score': 0.543},
        ],
        'query2': [
            {'pageid': '1', 'title': f'result 1', 'score': 0.723},
            {'pageid': '2', 'title': f'this should be skipped', 'score': 0.323},
            {'pageid': '5', 'title': f'this should also be skipped', 'score': 0.1},
        ],
        'query3': [
            {'pageid': '5', 'title': f'result 5', 'score': 0.511},
        ],
    }

    def getter(url):
        query = url.split('query=')[-1]
        return Mock(json=lambda: results[query])

    with patch('requests.get', getter):
        assert multisearch_authored(["query1", "query2", "query3"]) == [
            Followup('result 1', '1', 0.723),
            Followup('result 2', '2', 0.623),
            Followup('result 5', '5', 0.543),
        ]
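
Note: all three tests stub the network layer by patching requests.get, so no real HTTP traffic happens. A minimal standalone sketch of the same pattern, with a hypothetical fetch_titles helper standing in for the real search functions:

from unittest.mock import Mock, patch

import requests

def fetch_titles(url):
    # Hypothetical wrapper standing in for search_authored and friends.
    return [row['title'] for row in requests.get(url).json()]

# A Mock whose .json() returns canned data, mimicking a requests response.
response = Mock(json=lambda: [{'title': 'result 1'}, {'title': 'result 2'}])

with patch('requests.get', return_value=response):
    assert fetch_titles('https://nlp.stampy.ai/api/search?query=agi') == ['result 1', 'result 2']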