Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

[NeuralChat] RAG evaluation #1333

Open
wants to merge 158 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
158 commits
Select commit Hold shift + click to select a range
f820019
add retrieval dataset construction codes
Liangyx2 Mar 1, 2024
06f8162
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 1, 2024
5ef0332
Update llm_generate_raw_data.py
Liangyx2 Mar 1, 2024
ee1db83
Delete intel_extension_for_transformers/neural_chat/tools/evaluation/…
Liangyx2 Mar 1, 2024
89597f2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 1, 2024
b132d66
Delete intel_extension_for_transformers/neural_chat/tools/evaluation/…
Liangyx2 Mar 1, 2024
8e955ce
update
Liangyx2 Mar 1, 2024
635b906
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 1, 2024
d7d3d03
Delete intel_extension_for_transformers/neural_chat/tools/evaluation/…
Liangyx2 Mar 1, 2024
c9fec02
Delete intel_extension_for_transformers/neural_chat/tools/evaluation/…
Liangyx2 Mar 1, 2024
5e32113
Delete intel_extension_for_transformers/neural_chat/tools/evaluation/…
Liangyx2 Mar 1, 2024
f67622c
Delete intel_extension_for_transformers/neural_chat/tools/evaluation/…
Liangyx2 Mar 1, 2024
f2e344a
Delete intel_extension_for_transformers/neural_chat/tools/evaluation/…
Liangyx2 Mar 1, 2024
383e5b3
Update prompt.py
Liangyx2 Mar 4, 2024
81014d1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
4b7bec7
Update llm_generate_raw_data.py
Liangyx2 Mar 4, 2024
0df51a6
Update llm_generate_raw_data.py
Liangyx2 Mar 4, 2024
95b16bd
Update retrieval_dataset_construction.py
Liangyx2 Mar 4, 2024
80dd21b
Update llm_generate_raw_data.py
Liangyx2 Mar 4, 2024
f495b22
Update mine_hard_negatives_check_similarity.py
Liangyx2 Mar 4, 2024
593dee3
add test_evaluation.py to nightly test
Liangyx2 Mar 4, 2024
cf59b18
Update and rename requirements.txt to requirements_cpu.txt
Liangyx2 Mar 4, 2024
40e0b0e
Create requirements_cuda.txt
Liangyx2 Mar 4, 2024
bf1b1aa
Update requirements.txt
Liangyx2 Mar 4, 2024
5552ebc
Update retrieval_dataset_construction.py
Liangyx2 Mar 4, 2024
d3b7579
Update llm_generate_raw_data.py
Liangyx2 Mar 4, 2024
f500b2b
Update retrieval_dataset_construction.py
Liangyx2 Mar 4, 2024
b65c4bf
Update llm_generate_raw_data.py
Liangyx2 Mar 4, 2024
c43ab73
Update test_evaluation.py
Liangyx2 Mar 4, 2024
feda3c0
Update retrieval_dataset_construction.py
Liangyx2 Mar 4, 2024
1c2c22c
Update mine_hard_negatives_check_similarity.py
Liangyx2 Mar 4, 2024
55a5cda
add README.md
Liangyx2 Mar 6, 2024
7a74f86
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 6, 2024
39754d0
Update README.md
Liangyx2 Mar 7, 2024
d7e95f0
add evaluate_retrieval.py
Liangyx2 Mar 8, 2024
186ab43
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 8, 2024
1496219
Update test_evaluation.py
Liangyx2 Mar 11, 2024
03a768e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 11, 2024
128d587
Update test_evaluation.py
Liangyx2 Mar 11, 2024
25177bd
Merge branch 'main' into yuxiang/evaluation
XuehaoSun Mar 11, 2024
705752a
add README.md
Liangyx2 Mar 11, 2024
675fe2e
Update prompt.py
Liangyx2 Mar 12, 2024
988e542
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 12, 2024
d0c3c34
add llm_generate_truth.py and data
Liangyx2 Mar 12, 2024
be1106b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 12, 2024
48788d4
add ragas_evaluation.py
Liangyx2 Mar 12, 2024
54cc6c0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 12, 2024
e1b5585
Create requirements.txt
Liangyx2 Mar 12, 2024
88a4293
Update llm_generate_truth.py
Liangyx2 Mar 12, 2024
83060f9
Update evaluate_retrieval.py
Liangyx2 Mar 12, 2024
76b1175
Update ragas_evaluation.py
Liangyx2 Mar 12, 2024
b775095
Update test_evaluation.py
Liangyx2 Mar 12, 2024
edbb32c
Update llm_generate_truth.py
Liangyx2 Mar 12, 2024
8962abf
Update README.md
Liangyx2 Mar 14, 2024
2ef4e05
Update README.md
Liangyx2 Mar 14, 2024
d2ab7d8
add README.md
Liangyx2 Mar 14, 2024
bcdf209
Update README.md
Liangyx2 Mar 14, 2024
102649b
Update README.md
Liangyx2 Mar 14, 2024
36a28a4
Update README.md
Liangyx2 Mar 14, 2024
548fdd9
Add files via upload
Liangyx2 Mar 15, 2024
36448ea
Delete intel_extension_for_transformers/neural_chat/tests/ci/tools/te…
Liangyx2 Mar 15, 2024
26e3e9d
Update requirements.txt
Liangyx2 Mar 15, 2024
e4793d3
Update README.md
Liangyx2 Mar 15, 2024
0569b54
Update hn_mine.py
Liangyx2 Mar 15, 2024
2d15ec0
Update README.md
Liangyx2 Mar 15, 2024
e8127e9
Update ragas_evaluation.py
Liangyx2 Mar 18, 2024
321e9b6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 18, 2024
f9b4dab
Update requirements.txt
Liangyx2 Mar 18, 2024
76dc219
Update README.md
Liangyx2 Mar 18, 2024
b9db553
Update README.md
Liangyx2 Mar 18, 2024
d7b68cb
Update README.md
Liangyx2 Mar 18, 2024
48de606
Update requirements.txt
Liangyx2 Mar 18, 2024
415ebc8
Update ragas_evaluation.py
Liangyx2 Mar 18, 2024
f03badd
Update test_evaluation.py
Liangyx2 Mar 18, 2024
2b92e74
Update README.md
Liangyx2 Mar 18, 2024
9091729
Update retrieval_dataset_construction.py
Liangyx2 Mar 18, 2024
be32736
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 18, 2024
2c4f452
Update hn_mine.py
Liangyx2 Mar 18, 2024
c48f66a
Update llm_generate_raw_data.py
Liangyx2 Mar 18, 2024
654c44a
Update mine_hard_negatives_check_similarity.py
Liangyx2 Mar 18, 2024
5208c98
Update hn_mine.py
Liangyx2 Mar 18, 2024
ace1090
Update test_evaluation.py
Liangyx2 Mar 18, 2024
83f10e9
Update ragas_evaluation.py
Liangyx2 Mar 18, 2024
ac0aef1
Update README.md
Liangyx2 Mar 18, 2024
8deaabd
Update README.md
Liangyx2 Mar 19, 2024
2eb084c
Update README.md
Liangyx2 Mar 19, 2024
510e801
Update README.md
Liangyx2 Mar 19, 2024
dd1f37c
Update README.md
Liangyx2 Mar 19, 2024
ed95d2d
Update prompt.py
Liangyx2 Mar 19, 2024
e253f41
Update ragas_evaluation.py
Liangyx2 Mar 19, 2024
fc0b6b9
add evaluate_retrieval_auto.py
Liangyx2 Mar 20, 2024
6f081b5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 20, 2024
746adec
Update evaluate_retrieval_auto.py
Liangyx2 Mar 21, 2024
100322e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 21, 2024
5e07789
Update evaluate_retrieval.py
Liangyx2 Mar 21, 2024
0a2f742
Update ragas_evaluation.py
Liangyx2 Mar 21, 2024
1752684
Update test_evaluation.py
Liangyx2 Mar 21, 2024
2a2238e
Update ragas_evaluation.py
Liangyx2 Mar 22, 2024
e8f0f9c
Update README.md
Liangyx2 Mar 22, 2024
8d65078
Update and rename evaluate_retrieval_auto.py to evaluate_retrieval_be…
Liangyx2 Mar 22, 2024
a951a89
Update evaluate_retrieval_benchmark.py
Liangyx2 Mar 25, 2024
13921f6
add retrieval_benchmark.py
Liangyx2 Mar 25, 2024
02c0813
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 25, 2024
d212d66
Update retrieval_benchmark.py
Liangyx2 Mar 25, 2024
20529a4
add ragas_benchmark ragas_evaluation_benchmark
Liangyx2 Mar 26, 2024
5026421
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 26, 2024
cfa7d9c
Update retrieval_benchmark.py
Liangyx2 Mar 26, 2024
8d1215e
Update evaluate_retrieval_benchmark.py
Liangyx2 Mar 26, 2024
3458a8e
Update retrieval_benchmark.py
Liangyx2 Mar 26, 2024
4effd37
Update ragas_evaluation_benchmark.py
Liangyx2 Mar 26, 2024
3c38ae6
Update ragas_benchmark.py
Liangyx2 Mar 26, 2024
b02da07
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 26, 2024
a2a7de1
Update ragas_evaluation_benchmark.py
Liangyx2 Mar 26, 2024
4191f4b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 26, 2024
35b2d7d
Update evaluate_retrieval_benchmark.py
Liangyx2 Mar 27, 2024
56037b9
Update ragas_evaluation_benchmark.py
Liangyx2 Mar 27, 2024
de44f0d
add retrieval_benchmark.sh
Liangyx2 Mar 27, 2024
67456e4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 27, 2024
2a91336
add ragas_benchmark.sh
Liangyx2 Mar 27, 2024
8f05a34
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 27, 2024
c64ca3c
add data.txt
Liangyx2 Mar 27, 2024
fbef1f6
Update ragas_benchmark.sh
Liangyx2 Mar 27, 2024
f50aeb4
Update ragas_evaluation_benchmark.py
Liangyx2 Mar 28, 2024
84aea7c
Update ragas_benchmark.sh
Liangyx2 Mar 28, 2024
ad1814a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2024
932562d
Update and rename ragas_benchmark.py to ragas_superbenchmark.py
Liangyx2 Mar 28, 2024
50d8c83
Update evaluate_retrieval_benchmark.py
Liangyx2 Mar 28, 2024
a4ea5dd
Update retrieval_benchmark.sh
Liangyx2 Mar 28, 2024
6e29d43
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2024
702f9a9
Update and rename retrieval_benchmark.py to retrieval_superbenchmark.py
Liangyx2 Mar 28, 2024
0452526
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2024
008a892
add README.md
Liangyx2 Mar 28, 2024
5303837
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2024
8957b18
Update README.md
Liangyx2 Mar 28, 2024
96f477c
Update README.md
Liangyx2 Mar 29, 2024
c99856d
Update README.md
Liangyx2 Mar 29, 2024
19dfb93
Update README.md
Liangyx2 Apr 1, 2024
99940f3
Update README.md
Liangyx2 Apr 1, 2024
464d52b
Update README.md
Liangyx2 Apr 1, 2024
da2e829
Update README.md
Liangyx2 Apr 1, 2024
3ce2cb2
Update README.md
Liangyx2 Apr 1, 2024
268d89c
Update README.md
Liangyx2 Apr 1, 2024
40fc2e9
Update README.md
Liangyx2 Apr 1, 2024
13bb3b8
Update README.md
Liangyx2 Apr 1, 2024
763bd1d
add config file form rag evaluation
xmx-521 Apr 10, 2024
092e951
complete config superbenchmark
xmx-521 Apr 15, 2024
e931143
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 15, 2024
f0a0cd6
Merge branch 'main' into yuxiang/evaluation
XuhuiRen May 8, 2024
895075b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 8, 2024
6b60154
Create test_evaluation.py in CI
Liangyx2 May 10, 2024
c73a68f
Update requirements.txt
Liangyx2 May 11, 2024
c6f8906
Merge branch 'main' into yuxiang/evaluation
Liangyx2 May 11, 2024
7c80ce2
Merge branch 'main' into yuxiang/evaluation
VincyZhang May 13, 2024
576ce57
Merge branch 'main' into yuxiang/evaluation
Liangyx2 May 14, 2024
2a3ddd9
Merge branch 'main' into yuxiang/evaluation
Liangyx2 May 15, 2024
b4c0e67
Update ragas_evaluation_benchmark.py
Liangyx2 Jun 3, 2024
e75bbe4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2024
a0853a8
Merge branch 'main' into yuxiang/evaluation
Liangyx2 Jun 3, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
# !/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unicodedata
import pandas as pd
import re, json
from langchain.document_loaders import UnstructuredMarkdownLoader
from docx import Document as DDocument
from bs4 import BeautifulSoup
import fitz
import easyocr
from PIL import Image
import numpy as np
import io

def uni_pro(text):
"""Check if the character is ASCII or falls in the category of non-spacing marks."""
normalized_text = unicodedata.normalize('NFKD', text)
filtered_text = ''
for char in normalized_text:
if ord(char) < 128 or unicodedata.category(char) == 'Mn':
filtered_text += char
return filtered_text


def read_pdf(pdf_path):
Liangyx2 marked this conversation as resolved.
Show resolved Hide resolved
"""Read the pdf file."""
doc = fitz.open(pdf_path)
reader = easyocr.Reader(['en'])
result =''
for i in range(doc.page_count):
page = doc.load_page(i)
pagetext = page.get_text().strip()
if pagetext:
if pagetext.endswith('!') or pagetext.endswith('?') or pagetext.endswith('.'):
result=result+pagetext
else:
result=result+pagetext+'.'
if len(doc.get_page_images(i)) > 0 :
for img in doc.get_page_images(i):
if img:
pageimg=''
xref = img[0]
img_data = doc.extract_image(xref)
img_bytes = img_data['image']
pil_image = Image.open(io.BytesIO(img_bytes))
img = np.array(pil_image)
img_result = reader.readtext(img, paragraph=True, detail=0)
pageimg=pageimg + ', '.join(img_result).strip()
if pageimg.endswith('!') or pageimg.endswith('?') or pageimg.endswith('.'):
pass
else:
pageimg=pageimg+'.'
result=result+pageimg
return result


def read_html(html_path):
"""Read the html file."""
with open(html_path, 'r', encoding="utf-8") as file:
html = file.read()
soup = BeautifulSoup(html, 'html.parser')
text = soup.get_text(strip=True)
return text


def read_txt(txt_path):
"""Read txt file."""
with open(txt_path, 'r') as file:
text = file.read()
return text


def read_docx(doc_path):
"""Read docx file."""
doc = DDocument(doc_path)
text = ''
for paragraph in doc.paragraphs:
text += paragraph.text
return text


def read_md(md_path):
"""Read docx file."""
loader = UnstructuredMarkdownLoader(md_path)
text = loader.load()[0].page_content
return text


def load_json(input, process, max_length, min_length):
"""Load and process json file."""
data = []
with open(input, 'r') as file:
for line in file:
json_obj = json.loads(line)
data.append(json_obj)

new_sens = []
new_collect = []
for sub in data:
sub['content'].replace('#', " ")
sub['content'] = re.sub(r'\s+', ' ', sub['content'])
if not process:
if len(sub['content']) < min_length:
continue
new_doc = [sub['content'], sub['link']]
new_collect.append(new_doc)
else:
for sub in data:
sub['content'].replace('#', " ")
if len(sub['content'])<min_length:
continue
split_sen = re.split(r'[.?!]', sub['content'])
for num in range(len(split_sen)):
split_sen[num] = re.sub(r'\s+', ' ', split_sen[num])
if num +1 < len(split_sen):
if len(split_sen[num]) >max_length:
new_sens.append(split_sen[num].strip())
else:
split_sen[num +1] =split_sen[num] +split_sen[num+1]
else:
new_sens.append(split_sen[num])

paragraphs = list(set(new_sens))
for paragraph in paragraphs:
new_doc = [paragraph, sub['link']]
new_collect.append(new_doc)
return new_collect


def load_xlsx(input):
"""Load and process xlsx file."""
df = pd.read_excel(input)
header = df.columns.tolist()
all_data = []
if 'Questions' in header and 'Answers' in header:
for index, row in df.iterrows():
sub = row["Answers"]
sub=sub.replace('#', " ")
sub = sub.replace(r'\t', " ")
sub = sub.replace('\n', ' ')
sub = sub.replace('\n\n', ' ')
sub = re.sub(r'\s+', ' ', sub)
new_doc = [sub, input]
all_data.append(new_doc)
elif 'question' in header and 'answer' in header and 'link' in header:
for index, row in df.iterrows():
sub = row["answer"]
sub = sub.replace('#', " ")
sub = sub.replace(r'\t', " ")
sub = sub.replace('\n', ' ')
sub = sub.replace('\n\n', ' ')
sub = re.sub(r'\s+', ' ', sub)
all_data.append([sub, row['link']])
elif 'context' in header and 'link' in header:
for index, row in df.iterrows():
sub = row['context']
sub = sub.replace('#', " ")
sub = sub.replace(r'\t', " ")
sub = sub.replace('\n', ' ')
sub = sub.replace('\n\n', ' ')
sub = re.sub(r'\s+', ' ', sub)
all_data.append([sub, row['link']])
return all_data

def load_csv(input):
""" Load the csv file."""
df = pd.read_csv(input)
all_data = []
documents = []
for index, row in df.iterrows():
sub = row["correct_answer"]
all_data.append(sub)

for data in all_data:
data.replace('#', " ")
data = re.sub(r'\s+', ' ', data)
new_doc = [data, input]
documents.append(new_doc)
return documents

def load_structured_data(input, process, max_length, min_length):
"""Load structured context."""
if input.endswith("jsonl") or input.endswith("json"):
content = load_json(input, process, max_length, min_length)
elif input.endswith("xlsx"):
content = load_xlsx(input)
elif input.endswith("csv"):
content = load_csv(input)
return content

def load_unstructured_data(input):
"""Load unstructured context."""
if input.endswith("pdf"):
text = read_pdf(input)
elif input.endswith("docx"):
text = read_docx(input)
elif input.endswith("html"):
text = read_html(input)
elif input.endswith("txt"):
text = read_txt(input)
elif input.endswith("md"):
text = read_md(input)

text = text.replace('\n', ' ')
text = text.replace('\n\n', ' ')
text = uni_pro(text)
text = re.sub(r'\s+', ' ', text)
return text

def get_chuck_data(content, max_length, min_length, input):
"""Process the context to make it maintain a suitable length for the generation."""
sentences = re.split('(?<=[!.?])', content)

paragraphs = []
current_length = 0
count = 0
current_paragraph = ""
for sub_sen in sentences:
count +=1
sentence_length = len(sub_sen)
if current_length + sentence_length <= max_length:
current_paragraph += sub_sen
current_length += sentence_length
if count == len(sentences) and len(current_paragraph.strip())>min_length:
paragraphs.append([current_paragraph.strip() ,input])
else:
paragraphs.append([current_paragraph.strip() ,input])
current_paragraph = sub_sen
current_length = sentence_length

return paragraphs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import random
import numpy as np
import faiss
from tqdm import tqdm

def create_index(embeddings, use_gpu):
index = faiss.IndexFlatIP(len(embeddings[0]))
embeddings = np.asarray(embeddings, dtype=np.float32)
if use_gpu:
co = faiss.GpuMultipleClonerOptions()
co.shard = True
co.useFloat16 = True
index = faiss.index_cpu_to_all_gpus(index, co=co)
index.add(embeddings)
return index

def batch_search(index,
query,
topk: int = 200,
batch_size: int = 64):
all_scores, all_inxs = [], []
for start_index in tqdm(range(0, len(query), batch_size), desc="Batches", disable=len(query) < 256):
batch_query = query[start_index:start_index + batch_size]
batch_scores, batch_inxs = index.search(np.asarray(batch_query, dtype=np.float32), k=topk)
all_scores.extend(batch_scores.tolist())
all_inxs.extend(batch_inxs.tolist())
return all_scores, all_inxs

def get_corpus(candidate_pool):
corpus = []
for line in open(candidate_pool):
line = json.loads(line.strip())
corpus.append(line['text'])
return corpus

def find_knn_neg(model, input_file, candidate_pool, output_file, sample_range, negative_number, use_gpu):
corpus = []
queries = []
train_data = []
for line in open(input_file):
line = json.loads(line.strip())
train_data.append(line)
corpus.extend(line['pos'])
if 'neg' in line:
corpus.extend(line['neg'])
queries.append(line['query'])

if candidate_pool is not None:
if not isinstance(candidate_pool, list):
candidate_pool = get_corpus(candidate_pool)
corpus = list(set(candidate_pool))
else:
corpus = list(set(corpus))

p_vecs = model.encode(corpus, batch_size=256)
q_vecs = model.encode(queries, batch_size=256)

index = create_index(p_vecs, use_gpu=use_gpu)
_, all_inxs = batch_search(index, q_vecs, topk=sample_range[-1])
assert len(all_inxs) == len(train_data)

for i, data in enumerate(train_data):
query = data['query']
inxs = all_inxs[i][sample_range[0]:sample_range[1]]
filtered_inx = []
for inx in inxs:
if inx == -1: break
if corpus[inx] not in data['pos'] and corpus[inx] != query:
filtered_inx.append(inx)

if len(filtered_inx) > negative_number:
filtered_inx = random.sample(filtered_inx, negative_number)
data['neg'] = [corpus[inx] for inx in filtered_inx]

with open(output_file, 'w') as f:
for data in train_data:
if len(data['neg']) < negative_number:
data['neg'].extend(random.sample(corpus, negative_number - len(data['neg'])))
f.write(json.dumps(data, ensure_ascii=False) + '\n')
Loading
Loading