From 3ade108cb40db934ca4c3faa3e5a815deea8bfa2 Mon Sep 17 00:00:00 2001 From: zhiheng huang Date: Fri, 6 Dec 2024 10:50:38 -0800 Subject: [PATCH] Legal aggregations --- denser_retriever/embeddings.py | 21 ++++ denser_retriever/keyword.py | 126 ++++++++++++---------- denser_retriever/retriever.py | 36 +++---- denser_retriever/utils.py | 9 +- denser_retriever/vectordb/milvus.py | 7 +- examples/denser_search.py | 3 +- examples/denser_search_cpws.py | 3 +- experiments/train_and_test.py | 161 +++++++++++++++++++++++----- experiments/utils.py | 4 +- pyproject.toml | 7 +- tests/test_keyword.py | 3 - tests/test_retriever.py | 7 -- 12 files changed, 260 insertions(+), 127 deletions(-) diff --git a/denser_retriever/embeddings.py b/denser_retriever/embeddings.py index abc8b15..9b6def4 100644 --- a/denser_retriever/embeddings.py +++ b/denser_retriever/embeddings.py @@ -40,6 +40,27 @@ def embed_query(self, text): return embeddings +class BGEEmbeddings(DenserEmbeddings): + def __init__(self, model_name: str, embedding_size: int): + try: + from FlagEmbedding import FlagICLModel + except ImportError as exc: + raise ImportError( + "Could not import FlagEmbedding python package." + ) from exc + + self.client = FlagICLModel(model_name, + query_instruction_for_retrieval="Given a web search query, retrieve relevant passages that answer the query.", + examples_for_task=None, # set `examples_for_task=None` to use model without examples + use_fp16=True) # Setting use_fp16 to True speeds up computation with a slight performance degradation + self.embedding_size = embedding_size + + def embed_documents(self, texts): + return self.client.encode_corpus(texts) + + def embed_query(self, text): + return self.client.encode_queries(text) + class VoyageAPIEmbeddings(DenserEmbeddings): def __init__(self, api_key: str, model_name: str, embedding_size: int): try: diff --git a/denser_retriever/keyword.py b/denser_retriever/keyword.py index b6cea41..4351bb4 100644 --- a/denser_retriever/keyword.py +++ b/denser_retriever/keyword.py @@ -83,9 +83,6 @@ def retrieve( def get_index_mappings(self) -> Dict[Any, Any]: raise NotImplementedError - @abstractmethod - def get_categories(self, field: str, k: int = 10) -> List[Any]: - raise NotImplementedError @abstractmethod def delete( @@ -128,11 +125,12 @@ def __init__( self.analysis = analysis self.client = es_connection - def create_index(self, index_name: str, search_fields: List[str], **args: Any): + def create_index(self, index_name: str, search_fields: List[str], date_fields: List[str], **args: Any): # Define the index settings and mappings self.index_name = index_name self.search_fields = FieldMapper(search_fields) + self.date_fields = date_fields logger.info("ES analysis %s", self.analysis) if self.analysis == "default": @@ -243,7 +241,8 @@ def add_documents( "source": metadata.get("source"), "pid": metadata.get("pid"), } - for filter in self.search_fields.get_keys(): + # for filter in self.search_fields.get_keys(): + for filter in metadata.keys(): value = metadata.get(filter, "") if isinstance(value, list): value = [v.strip() for v in value] @@ -276,37 +275,48 @@ def add_documents( return [] def retrieve( - self, - query: str, - k: int = 100, - filter: Dict[str, Any] = {}, - ) -> List[Tuple[Document, float]]: + self, + query: str, + k: int = 100, + filter: Dict[str, Any] = {}, + aggregation: bool = False, # Aggregate metadata + ) -> Tuple[List[Tuple[Document, float]], Dict]: assert self.client.indices.exists(index=self.index_name) start_time = time.time() + + # Build the query with title and content matching and a minimum_should_match condition query_dict = { "query": { "bool": { - "should": [ + "must": [ { - "match": { - "title": { - "query": query, - "boost": 2.0, - } + "bool": { + "should": [ + { + "match": { + "title": { + "query": query, + "boost": 2.0, + } + } + }, + { + "match": { + "content": query, + } + } + ], + "minimum_should_match": 1 # Ensure at least one of the should conditions is matched } - }, - { - "match": { - "content": query, - }, - }, - ], - "must": [], + } + ] } }, "_source": True, + "aggs": {}, # This will be populated with aggregations for fields } + # Add filters if provided for field in filter: category_or_date = filter.get(field) if category_or_date: @@ -318,7 +328,7 @@ def retrieve( "gte": category_or_date[0], "lte": category_or_date[1] if len(category_or_date) > 1 - else category_or_date[0], # type: ignore + else category_or_date[0], } } } @@ -328,32 +338,59 @@ def retrieve( {"term": {field: category_or_date}} ) + # Add aggregations for each field provided in 'fields' if aggregation is True + if aggregation: + for field in self.search_fields.get_keys(): + query_dict["aggs"][f"{field}_aggregation"] = { + "terms": { + "field": f"{field}", # Use keyword type for aggregations + "size": 50 # Adjust size as needed + } + } + + # Execute search query res = self.client.search( index=self.index_name, body=query_dict, size=k, ) + + # Process search hits (documents) top_k_used = min(len(res["hits"]["hits"]), k) docs = [] for id in range(top_k_used): _source = res["hits"]["hits"][id]["_source"] doc = Document( - page_content=_source["content"], - metadata={ - "source": _source["source"], - "title": _source["title"], - "pid": _source["pid"], - }, + page_content=_source.pop("content"), + metadata=_source, ) score = res["hits"]["hits"][id]["_score"] - for field in self.search_fields.get_keys(): - if _source.get(field): - doc.metadata[field] = _source.get(field) + # import pdb; pdb.set_trace() + # for field in self.search_fields.get_keys(): + # if _source.get(field): + # doc.metadata[field] = _source.get(field) docs.append((doc, score)) + + # Process aggregations for the specified fields + aggregations = {} + for field in self.search_fields.get_keys(): + field_agg = res.get("aggregations", {}).get(f"{field}_aggregation", {}).get("buckets", []) + cat_keys = [cat['key'] for cat in field_agg] + cat_counts = [cat['doc_count'] for cat in field_agg] + if len(cat_keys) > 0: + if field in self.date_fields: + sorted_data = sorted(zip(cat_keys, cat_counts), key=lambda x: x[0], reverse=True) + sorted_keys, sorted_counts = zip(*sorted_data) + cat_keys = list(sorted_keys) + cat_counts = list(sorted_counts) + aggregations[field] = (cat_keys, cat_counts) + retrieve_time_sec = time.time() - start_time logger.info(f"Keyword retrieve time: {retrieve_time_sec:.3f} sec.") logger.info(f"Retrieved {len(docs)} documents.") - return docs + + # Return both documents and aggregation results + return docs, aggregations def get_index_mappings(self): mapping = self.client.indices.get_mapping(index=self.index_name) @@ -382,25 +419,6 @@ def extract_fields(fields_dict, parent_name=""): all_fields = extract_fields(properties) return all_fields - def get_categories(self, field: str, k: int = 10): - query = { - "size": 0, # No actual documents are needed, just the aggregation results - "aggs": { - "all_categories": { - "terms": { - "field": field, - "size": 1000, # Adjust this value based on the expected number of unique categories - } - } - }, - } - response = self.client.search(index=self.index_name, body=query) - # Extract the aggregation results - categories = response["aggregations"]["all_categories"]["buckets"] - if k > 0: - categories = categories[:k] - res = [category["key"] for category in categories] - return res def delete( self, diff --git a/denser_retriever/retriever.py b/denser_retriever/retriever.py index 4bec3d7..1e0c6fc 100644 --- a/denser_retriever/retriever.py +++ b/denser_retriever/retriever.py @@ -42,7 +42,8 @@ def __init__( gradient_boost: Optional[DenserGradientBoost], combine_mode: str = "linear", xgb_model_features: str = "es+vs+rr_n", - search_fields: List[str] = [] + search_fields: List[str] = [], + date_fields: List[str] = [], ): # config parameters self.index_name = index_name @@ -61,7 +62,7 @@ def __init__( assert embeddings self.vector_db.create_index(index_name, embeddings, search_fields) if self.keyword_search: - self.keyword_search.create_index(index_name, search_fields) + self.keyword_search.create_index(index_name, search_fields, date_fields) def ingest(self, docs: List[Document], overwrite_pid: bool = True) -> List[str]: # add pid into metadata for each document @@ -80,22 +81,23 @@ def ingest(self, docs: List[Document], overwrite_pid: bool = True) -> List[str]: return [doc.metadata["pid"] for doc in docs] def retrieve( - self, query: str, k: int = 100, filter: Dict[str, Any] = {}, **kwargs: Any + self, query: str, k: int = 100, filter: Dict[str, Any]= {}, aggregation: bool = False, **kwargs: Any ): logger.info(f"Retrieve query: {query} top_k: {k}") if self.combine_mode in ["linear", "rank"]: - return self._retrieve_by_linear_or_rank(query, k, filter, **kwargs) + return self._retrieve_by_linear_or_rank(query, k, filter, aggregation, **kwargs) else: - return self._retrieve_by_model(query, k, filter, **kwargs) + return self._retrieve_by_model(query, k, filter, aggregation, **kwargs) def _retrieve_by_linear_or_rank( - self, query: str, k: int = 100, filter: Dict[str, Any] = {}, **kwargs: Any + self, query: str, k: int = 100, filter: Dict[str, Any] = {}, aggregation: bool = False, **kwargs: Any ): passages = [] + aggregations = None if self.keyword_search: - es_docs = self.keyword_search.retrieve( - query, self.keyword_search.top_k, filter=filter, **kwargs + es_docs, aggregations = self.keyword_search.retrieve( + query, self.keyword_search.top_k, filter=filter, aggregation=aggregation, **kwargs ) es_passages = scale_results(es_docs, self.keyword_search.weight) logger.info(f"Keyword search: {len(es_passages)}") @@ -125,10 +127,10 @@ def _retrieve_by_linear_or_rank( rerank_time_sec = time.time() - start_time logger.info(f"Rerank time: {rerank_time_sec:.3f} sec.") - return passages[:k] + return passages[:k], aggregations def _retrieve_by_model( - self, query: str, k: int = 100, filter: Dict[str, Any] = {}, **kwargs: Any + self, query: str, k: int = 100, filter: Dict[str, Any] = {}, aggregation: bool = False, **kwargs: Any ) -> List[Tuple[Document, float]]: docs, doc_features = self._retrieve_with_features(query, filter, **kwargs) @@ -262,20 +264,6 @@ def delete_all(self): if self.keyword_search: self.keyword_search.delete_all() - def get_field_categories(self, field, k: int = 10): - """ - Get the categories of a field. - - Args: - field: The field to get the categories of. - k: The number of categories to return. - - Returns: - A list of categories. - """ - if not self.keyword_search: - raise ValueError("Keyword search not initialized") - return self.keyword_search.get_categories(field, k) def get_filter_fields(self): """Get the filter fields.""" diff --git a/denser_retriever/utils.py b/denser_retriever/utils.py index bae90fb..f5a1d9a 100644 --- a/denser_retriever/utils.py +++ b/denser_retriever/utils.py @@ -2,7 +2,7 @@ from typing import Dict, List, Optional, Tuple import numpy as np -import pytrec_eval +# import pytrec_eval from scipy.sparse import csr_matrix from collections import defaultdict @@ -41,9 +41,10 @@ def evaluate( ndcg_string = "ndcg_cut." + ",".join([str(k) for k in k_values]) recall_string = "recall." + ",".join([str(k) for k in k_values]) precision_string = "P." + ",".join([str(k) for k in k_values]) - evaluator = pytrec_eval.RelevanceEvaluator( - qrels, {map_string, ndcg_string, recall_string, precision_string} - ) + # evaluator = pytrec_eval.RelevanceEvaluator( + # qrels, {map_string, ndcg_string, recall_string, precision_string} + # ) + evaluator = None scores = evaluator.evaluate(results) for query_id in scores.keys(): diff --git a/denser_retriever/vectordb/milvus.py b/denser_retriever/vectordb/milvus.py index dff3af9..950ed43 100644 --- a/denser_retriever/vectordb/milvus.py +++ b/denser_retriever/vectordb/milvus.py @@ -109,7 +109,7 @@ def create_index( self.embeddings = embeddings self.source_max_length = 500 self.title_max_length = 500 - self.text_max_length = 8000 + self.text_max_length = 30000 self.field_max_length = 500 self.connection_args = self.connection_args or DEFAULT_MILVUS_CONNECTION @@ -207,7 +207,10 @@ def add_documents( doc.metadata.get("source", "")[: self.source_max_length - 10] ) titles.append(doc.metadata.get("title", "")[: self.title_max_length - 10]) - texts.append(doc.page_content[: self.text_max_length - 1000]) # buffer + truncated_text = doc.page_content[:10000] + if len(truncated_text) >= self.text_max_length: + print(f"Truncated text length: {len(truncated_text)} longer than {self.text_max_length}") + texts.append(truncated_text) pid_list.append(doc.metadata.get("pid", "-1")) for i, field_original_key in enumerate( diff --git a/examples/denser_search.py b/examples/denser_search.py index 6683261..b4c08f3 100644 --- a/examples/denser_search.py +++ b/examples/denser_search.py @@ -65,7 +65,8 @@ def denser_search(): format="MM.DD.YYYY", ) else: - categories = retriever.get_field_categories(field, 10) + # categories = retriever.get_field_categories(field, 10) + _, categories = retriever.retrieve("", 0, {}, True) ## TODO option = st.sidebar.selectbox( field, tuple(categories), diff --git a/examples/denser_search_cpws.py b/examples/denser_search_cpws.py index 76b9a79..db8b6d7 100644 --- a/examples/denser_search_cpws.py +++ b/examples/denser_search_cpws.py @@ -75,7 +75,8 @@ def denser_search(): format="MM.DD.YYYY", ) else: - categories = retriever.get_field_categories(field, 10) + # categories = retriever.get_field_categories(field, 10) + _, categories = retriever.retrieve("", 0, {}, True) option = st.sidebar.selectbox( field, tuple(categories), diff --git a/experiments/train_and_test.py b/experiments/train_and_test.py index 708e809..50ad683 100644 --- a/experiments/train_and_test.py +++ b/experiments/train_and_test.py @@ -3,6 +3,8 @@ import sys import json import shutil +import pickle +import time from langchain_core.documents import Document import xgboost as xgb @@ -14,16 +16,17 @@ ElasticKeywordSearch, create_elasticsearch_client, ) -from denser_retriever.reranker import HFReranker +from denser_retriever.reranker import HFReranker, CohereReranker from denser_retriever.retriever import DenserRetriever from denser_retriever.vectordb.milvus import MilvusDenserVectorDB -from denser_retriever.embeddings import VoyageAPIEmbeddings +from denser_retriever.embeddings import SentenceTransformerEmbeddings, VoyageAPIEmbeddings, BGEEmbeddings from experiments.hf_data_loader import HFDataLoader from experiments.denser_data import DenserData from denser_retriever.utils import ( evaluate, save_queries, save_qrels, + save_qrels_from_trec, load_qrels, docs_to_dict, ) @@ -43,37 +46,74 @@ "es+vs+rr_n": ["1,2,3,4,5,6,7,8,9", "2,5,8"], } +candidate_passage_paths = { + "lecardv1": "/home/ubuntu/lecardv1_data/candidates", + "lecardv2": "/home/ubuntu/lecardv2_data/candidate_55192", +} +# lecard v1: ajId is the ID of the case, ajName is the case name, ajjbqk is the basic facts of the case, +# pjjg is the case judgment, qw is the full content, writId is the unique ID of this document, +# and writName is the document name. +# lecard v2: where pid is the ID of the case, qw is the full content, fact is the basic facts of the case, +# reason is the analysis process of judges, result is the case judgment, charge is the criminal charge(s) of the case, +# article is the criminal law article of this case document. +passage_content_field = { + "lecardv1": "ajjbqk", + "lecardv2": "fact", +} class Experiment: def __init__(self, dataset_name, drop_old): data_name = os.path.basename(dataset_name) self.output_prefix = os.path.join("exps", f"exp_{data_name}") - self.ingest_bs = 2000 + self.ingest_bs = 100 # 2000 index_name = data_name.replace("-", "_") self.retriever = DenserRetriever( index_name=index_name, keyword_search=ElasticKeywordSearch( top_k=100, - es_connection=create_elasticsearch_client(url="http://localhost:9200"), - drop_old=drop_old + es_connection=create_elasticsearch_client(url="http://35.161.137.119:9200", + username="elastic", + password="+elwyyp0XWfDqIjhSvHP", + ), + drop_old=drop_old, + analysis="default" # default or ik ), vector_db=MilvusDenserVectorDB( top_k=100, - connection_args={"uri": "http://localhost:19530"}, + connection_args={"uri": "http://44.242.11.63:19530", + "user": "root", + "password": "Milvus"}, auto_id=True, drop_old=drop_old ), - reranker=HFReranker(model_name="jinaai/jina-reranker-v2-base-multilingual", top_k=100, - automodel_args={"torch_dtype": "float32"}, trust_remote_code=True), - embeddings=VoyageAPIEmbeddings(api_key="YOUR_API_KEY", - model_name="voyage-2", embedding_size=1024), + # vector_db=None, + # reranker=HFReranker(model_name="cross-encoder/ms-marco-MiniLM-L-6-v2", top_k=100), + # reranker=HFReranker(model_name="jinaai/jina-reranker-v2-base-multilingual", top_k=100, + # automodel_args={"torch_dtype": "float32"}, trust_remote_code=True), + # reranker=HFReranker(model_name="BAAI/bge-reranker-v2-m3", top_k=100), + # reranker=HFReranker(model_name="BAAI/bge-reranker-base", top_k=80), + # reranker=HFReranker(model_name="BAAI/bge-reranker-large", top_k=100), + reranker=CohereReranker(api_key="vlHZEff43B6rZBCSuUOupiMSq95t5VkbymFLQygg", model_name="rerank-english-v3.0"), + # embeddings=SentenceTransformerEmbeddings( + # "Snowflake/snowflake-arctic-embed-m", 768, False + # ), + embeddings=VoyageAPIEmbeddings(api_key="pa-b76ti3S2pWuSl0go1S7f8-x150YAXUoh6UANO2LpHbI", model_name="voyage-2", embedding_size=1024), + # embeddings=VoyageAPIEmbeddings(api_key="pa-b76ti3S2pWuSl0go1S7f8-x150YAXUoh6UANO2LpHbI", model_name="voyage-law-2", embedding_size=1024), + # embeddings=SentenceTransformerEmbeddings( + # "chestnutlzj/ChatLaw-Text2Vec", 768, True + # ), + # embeddings=SentenceTransformerEmbeddings( + # "TencentBAC/Conan-embedding-v1", 1792, True + # ), + # embeddings=BGEEmbeddings(model_name="BAAI/bge-en-icl", embedding_size=4096), + # embeddings=None, gradient_boost=None ) self.max_query_size = 0 - self.max_query_len = 2000 + self.max_query_len = 4000 self.max_doc_size = 0 - self.max_doc_len = 8000 + self.max_doc_len = 4000 def ingest(self, dataset_name, split): exp_dir = os.path.join(self.output_prefix, split) @@ -81,12 +121,34 @@ def ingest(self, dataset_name, split): os.makedirs(exp_dir) passage_file = os.path.join(exp_dir, "passages.jsonl") - if dataset_name == 'anthropic_base': - copy_file('experiments/data/contextual-embeddings/data_base/passages.jsonl', passage_file, - self.max_doc_size) + if dataset_name in ["lecardv1", "lecardv2"]: + path = candidate_passage_paths[dataset_name] + content_field = passage_content_field[dataset_name] + + out = open(passage_file, "w") + num_docs = 0 + num_trimmed_docs = 0 + for root, dirs, files in os.walk(path): + for filename in files: + if filename.endswith(".json"): + filepath = os.path.join(root, filename) + with open(filepath, 'r') as json_file: + data = json.load(json_file) + page_content = data.pop(content_field).strip() + if self.max_doc_len > 0 and len(page_content) > self.max_doc_len: + page_content = page_content[:self.max_doc_len] + num_trimmed_docs += 1 + data["pid"] = filename.split(".")[0].strip() + doc = Document(page_content=page_content, metadata=data) + out.write(json.dumps(doc.dict(), ensure_ascii=False) + "\n") + num_docs += 1 + if self.max_doc_size > 0 and num_docs >= self.max_doc_size: + break + logger.info(f"Trimmed {num_trimmed_docs} from {num_docs} docs due to max_doc_len {self.max_doc_len}") + elif dataset_name == 'anthropic_base': + copy_file('experiments/data/contextual-embeddings/data_base/passages.jsonl', passage_file, self.max_doc_size) elif dataset_name == 'anthropic_context': - copy_file('experiments/data/contextual-embeddings/data_context/passages.jsonl', passage_file, - self.max_doc_size) + copy_file('experiments/data/contextual-embeddings/data_context/passages.jsonl', passage_file, self.max_doc_size) else: corpus, _, _ = HFDataLoader( hf_repo=dataset_name, @@ -96,7 +158,7 @@ def ingest(self, dataset_name, split): ).load(split=split) save_HF_corpus_as_docs( - corpus, passage_file, self.max_doc_size + corpus, passage_file, self.max_doc_size, self.max_doc_len ) out = open(passage_file, "r") @@ -120,7 +182,47 @@ def generate_feature_data(self, dataset_name, split): query_file = os.path.join(exp_dir, "queries.jsonl") qrels_file = os.path.join(exp_dir, "qrels.jsonl") - if dataset_name in ["anthropic_base", "anthropic_context"]: + if dataset_name == "lecardv1": + original_query_file = os.path.join("experiments/data/lecardv1/query.json") + queries = [] + num_trimmed_queries = 0 + with open(original_query_file, "r") as file: + for line in file: + query = json.loads(line) + query_str = query["q"] + if self.max_query_len > 0 and len(query_str) > self.max_query_len: + num_trimmed_queries += 1 + query_str = query_str[:self.max_query_len] + queries.append({"id": str(query["ridx"]), "text": query_str}) + logger.info( + f"Trimmed {num_trimmed_queries} from {len(queries)} queries due to max_query_len {self.max_query_len}") + if self.max_query_size > 0: + queries = queries[:self.max_query_size] + save_queries(queries, query_file) + with open("experiments/data/lecardv1/label_top30_dict.json", 'r') as infile: + qrels = json.load(infile) + save_qrels(qrels, qrels_file) + elif dataset_name == "lecardv2": + original_query_file = os.path.join("experiments/data/lecardv2", + "train_query.json" if split == "train" else "test_query.json") + queries = [] + num_trimmed_queries = 0 + with open(original_query_file, "r") as file: + for line in file: + query = json.loads(line) + # query_str = query["query"] + query_str = query["fact"] + if self.max_query_len > 0 and len(query_str) > self.max_query_len: + num_trimmed_queries += 1 + query_str = query_str[:self.max_query_len] + queries.append({"id": str(query["id"]), "text": query_str, "fact": query["fact"]}) + # TODO: use field "fact" + logger.info(f"Trimmed {num_trimmed_queries} from {len(queries)} queries due to max_query_len {self.max_query_len}") + if self.max_query_size > 0: + queries = queries[:self.max_query_size] + save_queries(queries, query_file) + qrels = save_qrels_from_trec("experiments/data/lecardv2/relevence.trec", qrels_file) + elif dataset_name in ["anthropic_base", "anthropic_context"]: shutil.copy('experiments/data/contextual-embeddings/data_context/queries.jsonl', query_file) shutil.copy('experiments/data/contextual-embeddings/data_context/qrels.jsonl', qrels_file) data = DenserData("experiments/data/contextual-embeddings/data_base") @@ -143,12 +245,15 @@ def generate_feature_data(self, dataset_name, split): if (self.max_query_size > 0 and i >= self.max_query_size): break logger.info(f"Processing query {i}") + query_str = q["text"] + if (self.max_query_len > 0 and len(query_str) > self.max_query_len): + query_str = query_str[:self.max_query_len] qid = q["id"] ks_docs = self.retriever.keyword_search.retrieve( - q["text"], self.retriever.keyword_search.top_k) + query_str, self.retriever.keyword_search.top_k) vs_docs = self.retriever.vector_db.similarity_search_with_score( - q["text"], self.retriever.vector_db.top_k) + query_str, self.retriever.vector_db.top_k) combined = [] seen = set() @@ -161,7 +266,7 @@ def generate_feature_data(self, dataset_name, split): reranked_docs = [] # import pdb; pdb.set_trace() if self.retriever.reranker: - reranked_docs = self.retriever.reranker.rerank(combined_docs, q["text"]) + reranked_docs = self.retriever.reranker.rerank(combined_docs, query_str) _, ks_score_dict, ks_rank_dict = docs_to_dict(ks_docs) @@ -254,7 +359,7 @@ def cross_validation_xgb(self, test_dir, retriever_config): groups = np.array(groups) # Prepare GroupKFold cross-validation - gkf = GroupKFold(n_splits=3) + gkf = GroupKFold(n_splits=5) # Initialize an array to hold all predictions x_data, y_data = load_svmlight_file(os.path.join(test_dir, retriever_config)) @@ -502,6 +607,8 @@ def report(self, eval_on, metric_str): # dataset = ["mteb/scifact", "train", "test"] # dataset = ["mteb/touche2020", "test", "test"] # dataset = ["mteb/trec-covid", "test", "test"] + # dataset = ["lecardv1", "test", "test"] + # dataset = ["lecardv2", "train", "test"] # dataset_name, train_on, eval_on = dataset # model_dir = "/home/ubuntu/denser_output_retriever/exp_msmarco/models/" @@ -514,14 +621,14 @@ def report(self, eval_on, metric_str): dataset_name = sys.argv[1] train_on = sys.argv[2] eval_on = sys.argv[3] - drop_old = True + drop_old = False experiment = Experiment(dataset_name, drop_old) if drop_old: experiment.ingest(dataset_name, train_on) # Generate retriever data, this takes time - experiment.generate_feature_data(dataset_name, train_on) - if eval_on != train_on: - experiment.generate_feature_data(dataset_name, eval_on) + # experiment.generate_feature_data(dataset_name, train_on) + # if eval_on != train_on: + # experiment.generate_feature_data(dataset_name, eval_on) experiment.compute_baselines(eval_on) if train_on == eval_on: experiment.cross_validation(eval_on) diff --git a/experiments/utils.py b/experiments/utils.py index a8b6fca..326931c 100644 --- a/experiments/utils.py +++ b/experiments/utils.py @@ -21,13 +21,15 @@ def copy_file(source_file, dest_file, top_k): else: break -def save_HF_corpus_as_docs(corpus, output_file: str, max_doc_size): +def save_HF_corpus_as_docs(corpus, output_file: str, max_doc_size, max_doc_len): out = open(output_file, "w") seen = set() for i, d in enumerate(corpus): if max_doc_size > 0 and i >= max_doc_size: break page_content = d.pop("text") + if max_doc_len > 0 and len(page_content) > max_doc_len: + page_content = page_content[:max_doc_len] d["pid"] = d.pop("id") assert d["pid"] not in seen seen.add(d["pid"]) diff --git a/pyproject.toml b/pyproject.toml index 384ef4b..b4da132 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,11 +39,12 @@ python = "^3.10.0" typer = {extras = ["all"], version = "^0.12.1"} rich = "^10.14.0" -pytrec-eval = "^0.5" +#pytrec-eval = "^0.5" sentence-transformers = "^2.7.0" # Specify the version of sentence-transformers -torch = [{markers = "sys_platform == 'darwin'", url = "https://download.pytorch.org/whl/cpu/torch-1.13.1-cp310-none-macosx_11_0_arm64.whl"}, -{markers = "sys_platform == 'linux'", url="https://download.pytorch.org/whl/cu117/torch-1.13.1%2Bcu117-cp310-cp310-linux_x86_64.whl"}] +#torch = [{markers = "sys_platform == 'darwin'", url = "https://download.pytorch.org/whl/cpu/torch-1.13.1-cp310-none-macosx_11_0_arm64.whl"}, +torch = "^1.13.1" +#{markers = "sys_platform == 'linux'", url="https://download.pytorch.org/whl/cu117/torch-1.13.1%2Bcu117-cp310-cp310-linux_x86_64.whl"}] elasticsearch = "^8.13.0" pymilvus = "^2.4.4" datasets = "^2.18.0" diff --git a/tests/test_keyword.py b/tests/test_keyword.py index 5a33366..fdfd333 100644 --- a/tests/test_keyword.py +++ b/tests/test_keyword.py @@ -47,6 +47,3 @@ def test_get_index_mappings(self, keyword_search): assert "field1" in mappings assert "field2" in mappings - def test_get_categories(self, keyword_search): - categories = keyword_search.get_categories("field2") - assert len(categories) == 0 diff --git a/tests/test_retriever.py b/tests/test_retriever.py index bfde61a..43d03da 100644 --- a/tests/test_retriever.py +++ b/tests/test_retriever.py @@ -53,13 +53,6 @@ def test_get_field_categories(self): Document(page_content="content2", metadata={"title": "title2", "source": "source_test2"}), ] self.denser_retriever.ingest(docs) - field = "category_field" - k = 10 - categories = self.denser_retriever.get_field_categories(field, k) - assert isinstance(categories, list) - assert len(categories) <= k - for category in categories: - assert isinstance(category, str) def test_get_metadata_fields(self): docs = [