diff --git a/docs/source/conversation/index.rst b/docs/source/conversation/index.rst
new file mode 100644
index 00000000..be4ff550
--- /dev/null
+++ b/docs/source/conversation/index.rst
@@ -0,0 +1,19 @@
+Conversation
+============
+
+
+Learning
+--------
+
+.. autoxpmconfig:: xpmir.conversation.learning.DatasetConversationEntrySampler
+.. autoxpmconfig:: xpmir.conversation.learning.reformulation.ConversationRepresentationEncoder
+.. autoxpmconfig:: xpmir.conversation.learning.reformulation.DecontextualizedQueryConverter
+
+
+
+
+CoSPLADE
+--------
+
+.. autoxpmconfig:: xpmir.conversation.models.cosplade.AsymetricMSEContextualizedRepresentationLoss
+.. autoxpmconfig:: xpmir.conversation.models.cosplade.CoSPLADE
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 5f23d29a..d71510f8 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -50,9 +50,11 @@ Table of Contents
    evaluation
    learning/index
    letor/index
+   conversation/index
    neural
    hooks
    text/index
+   misc
    experiments
    papers/index
    pretrained
diff --git a/docs/source/learning/optimization.rst b/docs/source/learning/optimization.rst
index 78be1892..4f10f76a 100644
--- a/docs/source/learning/optimization.rst
+++ b/docs/source/learning/optimization.rst
@@ -33,8 +33,14 @@ Optimizers
 .. autoxpmconfig:: xpmir.learning.optim.RegexParameterFilter
 .. autoxpmconfig:: xpmir.learning.optim.OptimizationHook
 
+
+Hooks
+*****
+
 .. autoxpmconfig:: xpmir.learning.optim.GradientHook
 .. autoxpmconfig:: xpmir.learning.optim.GradientClippingHook
+.. autoxpmconfig:: xpmir.learning.optim.GradientLogHook
+
 
 Parameters
 ----------
@@ -94,4 +100,5 @@ Base classes
 
 .. autoxpmconfig:: xpmir.learning.base.Random
 .. autoxpmconfig:: xpmir.learning.base.Sampler
+.. autoxpmconfig:: xpmir.learning.base.BaseSampler
 .. autoxpmconfig:: xpmir.learning.trainers.Trainer
diff --git a/docs/source/letor/alignment.rst b/docs/source/letor/alignment.rst
new file mode 100644
index 00000000..a3535572
--- /dev/null
+++ b/docs/source/letor/alignment.rst
@@ -0,0 +1,6 @@
+Alignment
+*********
+
+.. autoxpmconfig:: xpmir.letor.trainers.alignment.AlignmentLoss
+.. autoxpmconfig:: xpmir.letor.trainers.alignment.AlignmentTrainer
+.. autoxpmconfig:: xpmir.letor.trainers.alignment.MSEAlignmentLoss
diff --git a/docs/source/letor/index.rst b/docs/source/letor/index.rst
index cc2f5b0f..72b5afe8 100644
--- a/docs/source/letor/index.rst
+++ b/docs/source/letor/index.rst
@@ -11,6 +11,7 @@ Learning to rank
    generative
    mlm
    generation
+   alignment
 
 
 Learning to rank is handled by various classes. Some are located
@@ -85,3 +86,4 @@ Adapters
 .. autoxpmconfig:: xpmir.letor.samplers.hydrators.SampleTransform
 .. autoxpmconfig:: xpmir.letor.samplers.hydrators.SampleHydrator
 .. autoxpmconfig:: xpmir.letor.samplers.hydrators.SamplePrefixAdding
+.. autoxpmconfig:: xpmir.letor.samplers.hydrators.SampleTransformList
diff --git a/docs/source/misc.rst b/docs/source/misc.rst
new file mode 100644
index 00000000..e67e8653
--- /dev/null
+++ b/docs/source/misc.rst
@@ -0,0 +1 @@
+.. autoxpmconfig:: xpmir.utils.convert.Converter
diff --git a/docs/source/text/huggingface.rst b/docs/source/text/huggingface.rst
index 1800be70..820d20f9 100644
--- a/docs/source/text/huggingface.rst
+++ b/docs/source/text/huggingface.rst
@@ -26,6 +26,7 @@ Tokenizers
 .. autoxpmconfig:: xpmir.text.huggingface.tokenizers.HFListTokenizer
 .. autoxpmconfig:: xpmir.text.huggingface.tokenizers.HFStringTokenizer
+.. autoxpmconfig:: xpmir.text.huggingface.tokenizers.HFTokenizerAdapter
 
 Encoders
 --------
 
diff --git a/requirements.txt b/requirements.txt
index e8224c4a..bc698621 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,7 +2,7 @@
 experimaestro>=1.5.0
 datamaestro>=1.0.0
-datamaestro_text>=2024.2.27.1
+datamaestro_text>=2024.2.27.2
 ir_datasets
 docstring_parser
 xpmir_rust == 0.20.*
 
diff --git a/src/xpmir/conversation/models/cosplade.py b/src/xpmir/conversation/models/cosplade.py
index 6553091c..81e53d45 100644
--- a/src/xpmir/conversation/models/cosplade.py
+++ b/src/xpmir/conversation/models/cosplade.py
@@ -89,9 +89,9 @@ def forward(self, records: List[TopicConversationRecord]):
         for ix, c_record in enumerate(records):
             # Adds q_n, q_1, ..., q_{n-1}
             queries.append(
-                [c_record[TextItem].get_text()]
+                [c_record[TextItem].text]
                 + [
-                    entry[TextItem].get_text()
+                    entry[TextItem].text
                     for entry in c_record[ConversationHistoryItem].history
                     if isinstance(entry, TopicRecord)
                 ]
@@ -105,7 +105,7 @@ def forward(self, records: List[TopicConversationRecord]):
         ):
             if isinstance(item, TopicRecord) and answer is not None:
                 query_answer_pairs.append(
-                    (item[TextItem].get_text(), answer[AnswerEntry].answer)
+                    (item[TextItem].text, answer[AnswerEntry].answer)
                 )
                 pair_origins.append(ix)
             elif isinstance(item, AnswerConversationRecord):
diff --git a/src/xpmir/datasets/adapters.py b/src/xpmir/datasets/adapters.py
index dc9130db..52c9dc37 100644
--- a/src/xpmir/datasets/adapters.py
+++ b/src/xpmir/datasets/adapters.py
@@ -229,7 +229,7 @@ def execute(self):
         with self.topics.open("wt") as fp:
             for topic in topics:
                 ids.add(topic[IDItem].id)
-                fp.write(f"""{topic[IDItem].id}\t{topic[TextItem].get_text()}\n""")
+                fp.write(f"""{topic[IDItem].id}\t{topic[TextItem].text}\n""")
 
         with self.assessments.open("wt") as fp:
             for qrels in self.dataset.assessments.iter():
@@ -367,7 +367,7 @@ def execute(self):
                 logger.warning(
                     "Skipping topic %s [%s], (no assessment)",
                     topic[IDItem].id,
-                    topic[TextItem].get_text(),
+                    topic[TextItem].text,
                 )
                 continue
 
@@ -391,7 +391,7 @@ def execute(self):
         # don't need to worry about the threshold here
         for retriever in self.retrievers:
             docids.update(
-                sd.document.get_id() for sd in retriever.retrieve(topic.text)
+                sd.document[IDItem].id for sd in retriever.retrieve(topic.text)
             )
 
         # Write the document IDs
diff --git a/src/xpmir/documents/samplers.py b/src/xpmir/documents/samplers.py
index 729af9e0..3510b091 100644
--- a/src/xpmir/documents/samplers.py
+++ b/src/xpmir/documents/samplers.py
@@ -3,7 +3,7 @@
 from experimaestro import Param, Config
 import torch
 import numpy as np
-from datamaestro_text.data.ir import DocumentStore
+from datamaestro_text.data.ir import DocumentStore, TextItem
 from datamaestro_text.data.ir.base import (
     SimpleTextTopicRecord,
     SimpleTextDocumentRecord,
@@ -139,11 +139,11 @@ def iter(random: np.random.RandomState):
 
             while True:
                 record_pos_qry = next(iter)
-                text_pos_qry = record_pos_qry.text
+                text_pos_qry = record_pos_qry[TextItem].text
                 spans_pos_qry = self.get_text_span(text_pos_qry, random)
 
                 record_neg = next(iter)
-                text_neg = record_neg.text
+                text_neg = record_neg[TextItem].text
                 spans_neg = self.get_text_span(text_neg, random)
 
                 if not (spans_pos_qry and spans_neg):
diff --git a/src/xpmir/evaluation.py b/src/xpmir/evaluation.py
index 560d6756..4797eac3 100644
--- a/src/xpmir/evaluation.py
+++ b/src/xpmir/evaluation.py
@@ -73,10 +73,7 @@ def print_line(fp, measure, scope, value):
 def get_run(retriever: Retriever, dataset: Adhoc):
     """Returns the scored documents for each topic in a dataset"""
     results = retriever.retrieve_all(
-        {
-            topic[IDItem].id: topic[TextItem].get_text()
-            for topic in dataset.topics.iter()
-        }
+        {topic[IDItem].id: topic[TextItem].text for topic in dataset.topics.iter()}
     )
     return {
         qid: {sd.document[IDItem].id: sd.score for sd in scoredocs}
diff --git a/src/xpmir/index/faiss.py b/src/xpmir/index/faiss.py
index a75fa92e..0b623065 100644
--- a/src/xpmir/index/faiss.py
+++ b/src/xpmir/index/faiss.py
@@ -11,7 +11,7 @@
 import numpy as np
 from experimaestro import Annotated, Meta, Task, pathgenerator, Param, tqdm
 import logging
-from datamaestro_text.data.ir import DocumentStore
+from datamaestro_text.data.ir import DocumentStore, TextItem
 from xpmir.rankers import Retriever, ScoredDocument
 from xpmir.learning.batchers import Batcher
 from xpmir.learning import ModuleInitMode
@@ -175,7 +175,9 @@ def batch_encoder(doc_iter: Iterator[str]):
             with torch.no_grad():
                 for batch in batchiter(self.batchsize, doc_iter):
                     batcher.process(
-                        [document.text for document in batch], self.index_documents, index
+                        [document[TextItem].text for document in batch],
+                        self.index_documents,
+                        index,
                     )
 
         logging.info("Writing FAISS index (%d documents)", index.ntotal)
@@ -192,7 +194,7 @@ def encode(self, batch: List[str], data: List):
             data.append(x)
 
     def index_documents(self, batch: List[str], index):
-        x = self.encoder(batch)
+        x = self.encoder(batch).value
         if self.normalize:
             x /= x.norm(2, keepdim=True, dim=1)
         index.add(np.ascontiguousarray(x.cpu().numpy()))
@@ -222,7 +224,7 @@ def retrieve(self, query: TopicRecord) -> List[ScoredDocument]:
         """Retrieves a documents, returning a list sorted by decreasing score"""
         with torch.no_grad():
             self.encoder.eval()  # pass the model to the evaluation model
-            encoded_query = self.encoder([query])
+            encoded_query = self.encoder([query[TextItem].text]).value
 
             if self.index.normalize:
                 encoded_query /= encoded_query.norm(2)
diff --git a/src/xpmir/index/sparse.py b/src/xpmir/index/sparse.py
index 1b8e7cc1..5b3c62d3 100644
--- a/src/xpmir/index/sparse.py
+++ b/src/xpmir/index/sparse.py
@@ -86,7 +86,8 @@ def reducer(
         progress,
     ):
         for (key, _), vector in zip(
-            batch, self.encoder([text for _, text in batch]).cpu().detach().numpy()
+            batch,
+            self.encoder([text for _, text in batch]).value.cpu().detach().numpy(),
         ):
             (ix,) = vector.nonzero()
             query = {ix: float(v) for ix, v in zip(ix, vector[ix])}
@@ -114,7 +115,7 @@ def retrieve(self, query: TopicRecord, top_k=None) -> List[ScoredDocument]:
         """
 
         # Build up iterators
-        vector = self.encoder([query])[0].cpu().detach().numpy()
+        vector = self.encoder([query]).value[0].cpu().detach().numpy()
         (ix,) = vector.nonzero()  # ix represents the position without 0 in the vector
         query = {
             ix: float(v) for ix, v in zip(ix, vector[ix])
@@ -215,7 +216,7 @@ def execute(self):
     def encode_documents(self, batch: List[Tuple[int, DocumentRecord]]):
         # Assumes for now dense vectors
         vectors = (
-            self.encoder([d[TextItem].get_text() for _, d in batch]).value.cpu().numpy()
+            self.encoder([d[TextItem].text for _, d in batch]).value.cpu().numpy()
         )  # bs * vocab
         for vector, (docid, _) in zip(vectors, batch):
             (nonzero_ix,) = vector.nonzero()
diff --git a/src/xpmir/interfaces/anserini.py b/src/xpmir/interfaces/anserini.py
index 0bbf4abb..9710027c 100644
--- a/src/xpmir/interfaces/anserini.py
+++ b/src/xpmir/interfaces/anserini.py
@@ -132,7 +132,7 @@ def _generator(out):
             json.dump(
                 {
                     "id": document[IDItem].id,
-                    "contents": document[TextItem].get_text(),
+                    "contents": document[TextItem].text,
                 },
                 out,
             )
diff --git a/src/xpmir/letor/records.py b/src/xpmir/letor/records.py
index a26621aa..116f41ad 100644
--- a/src/xpmir/letor/records.py
+++ b/src/xpmir/letor/records.py
@@ -426,6 +426,6 @@ def from_texts(
     def to_texts(self) -> List[str]:
         texts = []
         for doc in self.documents:
-            texts.append(doc.document[TextItem].get_text())
+            texts.append(doc.document[TextItem].text)
 
         return texts
diff --git a/src/xpmir/letor/samplers/__init__.py b/src/xpmir/letor/samplers/__init__.py
index 327d84a6..29262f06 100644
--- a/src/xpmir/letor/samplers/__init__.py
+++ b/src/xpmir/letor/samplers/__init__.py
@@ -140,7 +140,7 @@ def document(self, doc_id):
         return self._store.document_ext(doc_id)
 
     def document_text(self, doc_id):
-        return self.document(doc_id).get_text()
+        return self.document(doc_id).text
 
     @cache("run")
     def _itertopics(
@@ -175,11 +175,11 @@ def _itertopics(
 
             # Retrieve documents
             skipped = 0
            for query in tqdm(queries):
-                qassessments = assessments.get(query.get_id(), None)
+                qassessments = assessments.get(query[IDItem].id, None)
                 if not qassessments:
                     skipped += 1
                     self.logger.warning(
-                        "Skipping topic %s (no assessments)", query.get_id()
+                        "Skipping topic %s (no assessments)", query[IDItem].id
                     )
                     continue
@@ -188,41 +188,43 @@
                 for docno, rel in qassessments.items():
                     if rel > 0:
                         fp.write(
-                            f"{query.get_text() if not positives else ''}"
+                            f"{query.text if not positives else ''}"
                             f"\t{docno}\t0.\t{rel}\n"
                         )
                         positives.append((docno, rel, 0))
 
                 if not positives:
                     self.logger.warning(
-                        "Skipping topic %s (no relevant documents)", query.get_id()
+                        "Skipping topic %s (no relevant documents)",
+                        query[IDItem].id,
                     )
                     skipped += 1
                     continue
 
                 scoreddocuments: List[ScoredDocument] = self.retriever.retrieve(
-                    query.get_text()
+                    query.text
                 )
 
                 negatives = []
                 for rank, sd in enumerate(scoreddocuments):
                     # Get the assessment (assumes not relevant)
-                    rel = qassessments.get(sd.document.get_id(), 0)
+                    rel = qassessments.get(sd.document[IDItem].id, 0)
                     if rel > 0:
                         continue
 
-                    negatives.append((sd.document.get_id(), rel, sd.score))
-                    fp.write(f"\t{sd.document.get_id()}\t{sd.score}\t{rel}\n")
+                    negatives.append((sd.document[IDItem].id, rel, sd.score))
+                    fp.write(f"\t{sd.document[IDItem].id}\t{sd.score}\t{rel}\n")
 
                 if not negatives:
                     self.logger.warning(
-                        "Skipping topic %s (no negatives documents)", query.get_id()
+                        "Skipping topic %s (no negatives documents)",
+                        query[IDItem].id,
                     )
                     skipped += 1
                     continue
 
                 assert len(positives) > 0 and len(negatives) > 0
-                yield query.get_text(), positives, negatives
+                yield query.text, positives, negatives
 
             # Finally, move the cache file in place...
             self.logger.info(
@@ -332,7 +334,7 @@ def sample(self, samples: List[Tuple[str, int, float]]):
         while text is None:
             docid, rel, score = samples[self.random.randint(0, len(samples))]
             document = self.document(docid).add(ScoredItem(score))
-            text = document[TextItem].get_text()
+            text = document[TextItem].text
 
        return document
 
    def pairwise_iter(self) -> SerializableIterator[PairwiseRecord, Any]:
@@ -616,7 +618,7 @@ def execute(self):
         positives = []
         negatives = []
         scoreddocuments: List[ScoredDocument] = self.retriever.retrieve(
-            query.get_text()
+            query.text
         )
 
         for rank, sd in enumerate(scoreddocuments):
diff --git a/src/xpmir/letor/samplers/hydrators.py b/src/xpmir/letor/samplers/hydrators.py
index aa818135..194265cb 100644
--- a/src/xpmir/letor/samplers/hydrators.py
+++ b/src/xpmir/letor/samplers/hydrators.py
@@ -9,7 +9,6 @@
 from xpmir.letor.records import (
     PairwiseRecords,
     PairwiseRecord,
-    DocumentRecord,
 )
 from xpmir.utils.iter import (
     SerializableIterator,
@@ -45,7 +44,9 @@ def transform_topics(self, topics: List[ir.TopicRecord]):
         if self.querystore is None:
             return None
         return [
-            ir.GenericTopic(topic.get_id(), self.querystore[topic.get_id()])
+            ir.GenericTopicRecord.create(
+                topic[IDItem].id, self.querystore[topic[IDItem].id]
+            )
             for topic in topics
         ]
 
@@ -72,13 +73,11 @@ def transform_topics(
         if isinstance(topics[0], ir.GenericTopic):
             return [
-                ir.GenericTopic(topic.get_id(), self.query_prefix + topic.get_text())
+                ir.GenericTopic(topic[IDItem].id, self.query_prefix + topic.text)
                 for topic in topics
             ]
         elif isinstance(topics[0], ir.TextTopic):
-            return [
-                ir.TextTopic(self.query_prefix + topic.get_text()) for topic in topics
-            ]
+            return [ir.TextTopic(self.query_prefix + topic.text) for topic in topics]
 
     def transform_documents(
         self, documents: List[ir.DocumentRecord]
     ) ->
@@ -89,13 +88,13 @@
         if isinstance(documents[0], ir.GenericDocument):
             return [
                 ir.GenericDocument(
-                    document.get_id(), self.document_prefix + document.get_text()
+                    document[IDItem].id, self.document_prefix + document.text
                 )
                 for document in documents
             ]
         elif isinstance(documents[0], ir.TextDocument):
             return [
-                ir.TextDocument(self.document_prefix + document.get_text())
+                ir.TextDocument(self.document_prefix + document.text)
                 for document in documents
             ]
 
@@ -139,15 +138,13 @@ def initialize(self, random: Optional[np.random.RandomState] = None):
         self.sampler.initialize(random)
 
     def transform_record(self, record: PairwiseRecord) -> PairwiseRecord:
-        topics = [record.query.topic]
-        docs = [record.positive.document, record.negative.document]
+        topics = [record.query]
+        docs = [record.positive, record.negative]
 
         topics = self.adapter.transform_topics(topics) or topics
         docs = self.adapter.transform_documents(docs) or docs
 
-        return PairwiseRecord(
-            topics[0].as_record(), DocumentRecord(docs[0]), DocumentRecord(docs[1])
-        )
+        return PairwiseRecord(topics[0], docs[0], docs[1])
 
     def pairwise_iter(self) -> Iterator[PairwiseRecord]:
         iterator = self.sampler.pairwise_iter()
diff --git a/src/xpmir/mlm/samplers.py b/src/xpmir/mlm/samplers.py
index 0c9044fe..1bb9db10 100644
--- a/src/xpmir/mlm/samplers.py
+++ b/src/xpmir/mlm/samplers.py
@@ -44,12 +44,9 @@ def iter(random: np.random.RandomState):
                 document = self.datasets[choice].document_int(
                     self.random.randint(0, self.datasets[choice].documentcount)
                 )
-                yield DocumentRecord(document.get_id(), document.get_text())
+                yield document
             else:
                 # FIXME: it makes the iter not fully serializable
-                yield from (
-                    DocumentRecord(doc.get_id(), doc.get_text())
-                    for doc in self.datasets[choice].iter()
-                )
+                yield from self.datasets[choice].iter()
 
         return RandomSerializableIterator(self.random, iter)
diff --git a/src/xpmir/neural/__init__.py b/src/xpmir/neural/__init__.py
index 433f04ea..e28b77fb 100644
--- a/src/xpmir/neural/__init__.py
+++ b/src/xpmir/neural/__init__.py
@@ -2,6 +2,7 @@
 import itertools
 from typing import Iterable, Union, List, Optional, TypeVar, Generic
 import torch
+from datamaestro_text.data.ir import TextItem
 from xpmir.learning.batchers import Sliceable
 from xpmir.learning.context import TrainerContext
 from xpmir.letor.records import BaseRecords, ProductRecords, TopicRecord, DocumentRecord
@@ -51,7 +52,7 @@ def encode_documents(self, records: Iterable[DocumentRecord]) -> DocsRep:
         """Encode a list of texts (document or query)
 
         The return value is model dependent"""
-        return self.encode([record.document.get_text() for record in records])
+        return self.encode([record[TextItem].text for record in records])
 
     def encode_queries(self, records: Iterable[TopicRecord]) -> QueriesRep:
         """Encode a list of texts (document or query)
@@ -60,7 +61,7 @@ def encode_queries(self, records: Iterable[TopicRecord]) -> QueriesRep:
 
         By default, uses `merge`
         """
-        return self.encode([record.topic.get_text() for record in records])
+        return self.encode([record[TextItem].text for record in records])
 
     def merge_queries(self, queries: QueriesRep):
         """Merge query batches encoded with `encode_queries`
diff --git a/src/xpmir/neural/cross.py b/src/xpmir/neural/cross.py
index 5e09f808..a3949244 100644
--- a/src/xpmir/neural/cross.py
+++ b/src/xpmir/neural/cross.py
@@ -39,13 +39,13 @@ def __validate__(self):
     def __initialize__(self, options):
         super().__initialize__(options)
         self.encoder.initialize(options)
-        self.classifier = torch.nn.Linear(self.encoder.dimension(), 1)
+        self.classifier = torch.nn.Linear(self.encoder.dimension, 1)
 
     def forward(self, inputs: BaseRecords, info: TrainerContext = None):
         # Encode queries and documents
         pairs = self.encoder(
             [
-                (tr[TextItem].get_text(), dr[TextItem].get_text())
+                (tr[TextItem].text, dr[TextItem].text)
                 for tr, dr in zip(inputs.topics, inputs.documents)
             ]
         )  # shape (batch_size * dimension)
diff --git a/src/xpmir/neural/dual.py b/src/xpmir/neural/dual.py
index 0627c606..06d63ee4 100644
--- a/src/xpmir/neural/dual.py
+++ b/src/xpmir/neural/dual.py
@@ -1,6 +1,7 @@
 from typing import List, Optional
 import torch
 from experimaestro import Param
+from datamaestro_text.data.ir import TextItem
 from xpmir.distributed import DistributableModel
 from xpmir.learning.batchers import Batcher
 from xpmir.letor.records import TopicRecord, DocumentRecord
@@ -101,12 +102,12 @@ class CosineDense(Dense):
 
     def encode_queries(self, records: List[TopicRecord]):
         queries = (self.query_encoder or self.encoder)(
-            [record.topic.get_text() for record in records]
+            [record[TextItem].text for record in records]
         )
         return queries / queries.norm(dim=1, keepdim=True)
 
     def encode_documents(self, records: List[DocumentRecord]):
-        documents = self.encoder([record.document.get_text() for record in records])
+        documents = self.encoder([record[TextItem].text for record in records])
         return documents / documents.norm(dim=1, keepdim=True)
 
 
@@ -119,11 +120,11 @@ def __validate__(self):
 
     def encode_queries(self, records: List[TopicRecord]):
         """Encode the different queries"""
-        return self._query_encoder([record.topic.get_text() for record in records])
+        return self._query_encoder([record[TextItem].text for record in records])
 
     def encode_documents(self, records: List[DocumentRecord]):
         """Encode the different documents"""
-        return self.encoder([record.document.get_text() for record in records])
+        return self.encoder([record[TextItem].text for record in records])
 
     def getRetriever(
         self, retriever: "Retriever", batch_size: int, batcher: Batcher, device=None
diff --git a/src/xpmir/neural/generative/cross.py b/src/xpmir/neural/generative/cross.py
index 67d83ab8..936c12b1 100644
--- a/src/xpmir/neural/generative/cross.py
+++ b/src/xpmir/neural/generative/cross.py
@@ -1,4 +1,5 @@
 from experimaestro import Param, Constant
+from datamaestro_text.data.ir import TextItem
 from xpmir.neural.generative import ConditionalGenerator
 from xpmir.letor.records import (
     BaseRecords,
@@ -37,9 +38,7 @@ def __initialize__(self, options):
 
     def forward(self, inputs: BaseRecords, info: TrainerContext = None):
         # Encode queries and documents
         inputs = [
-            self.pattern.format(
-                query=tr.topic.get_text(), document=dr.document.get_text()
-            )
+            self.pattern.format(query=tr[TextItem].text, document=dr[TextItem].text)
             for tr, dr in zip(inputs.topics, inputs.documents)
         ]
diff --git a/src/xpmir/neural/interaction/__init__.py b/src/xpmir/neural/interaction/__init__.py
index 84211054..c64b8198 100644
--- a/src/xpmir/neural/interaction/__init__.py
+++ b/src/xpmir/neural/interaction/__init__.py
@@ -2,8 +2,8 @@
 from typing import Iterable, Optional, List
 import torch
 from experimaestro import Param
+from datamaestro_text.data.ir import TextItem
 from xpmir.learning.context import TrainerContext
-
 from xpmir.neural.dual import (
     DualVectorScorer,
     TopicRecord,
@@ -62,7 +62,7 @@ def _encode(
     def encode_documents(self, records: Iterable[DocumentRecord]) -> SimilarityInput:
         return self.similarity.preprocess(
             self._encode(
-                [record.document.get_text() for record in records],
+                [record[TextItem].text for record in records],
                 self.encoder,
                 TokenizerOptions(self.dlen),
             )
@@ -71,7 +71,7 @@ def encode_documents(self, records: Iterable[DocumentRecord]) -> SimilarityInput:
     def encode_queries(self, records: Iterable[TopicRecord]) -> SimilarityInput:
         return self.similarity.preprocess(
             self._encode(
-                [record.topic.get_text() for record in records],
+                [record[TextItem].text for record in records],
                 self._query_encoder,
                 TokenizerOptions(self.qlen),
             )
diff --git a/src/xpmir/neural/modules/interaction_matrix.py b/src/xpmir/neural/modules/interaction_matrix.py
index b494881e..bfddc5b9 100644
--- a/src/xpmir/neural/modules/interaction_matrix.py
+++ b/src/xpmir/neural/modules/interaction_matrix.py
@@ -1,7 +1,5 @@
 import torch
 from torch import nn
-from xpmir.letor.records import BaseRecords
-from xpmir.text import TokensEncoder
 
 # The code below is heavily borrowed from OpenNIR
 
@@ -60,15 +58,3 @@ def wrap_list(x):
         b_mask = (b_tok.reshape(BAT, 1, B) != self.padding).float()
         simmats.append(cos_simmat(a_emb, b_emb, a_mask, b_mask))
         return torch.stack(simmats, dim=1)
-
-    def encode_query_doc(
-        self, encoder: TokensEncoder, inputs: BaseRecords, d_maxlen=None, q_maxlen=None
-    ):
-        """Returns a (batch x ... x #q x #d) tensor"""
-        tokq, q, tokd, d = encoder.enc_query_doc(
-            [q.topic.get_text() for q in inputs.queries],
-            [d.document.get_text() for d in inputs.documents],
-            d_maxlen=d_maxlen,
-            q_maxlen=q_maxlen,
-        )
-        return self(q, d, tokq.ids, tokd.ids), tokq, tokd
diff --git a/src/xpmir/rankers/__init__.py b/src/xpmir/rankers/__init__.py
index ffa56596..1c669f1f 100644
--- a/src/xpmir/rankers/__init__.py
+++ b/src/xpmir/rankers/__init__.py
@@ -19,7 +19,12 @@
 import torch.nn as nn
 import attrs
 from experimaestro import Param, Config, Meta
-from datamaestro_text.data.ir import Documents, DocumentStore, SimpleTextTopicRecord
+from datamaestro_text.data.ir import (
+    Documents,
+    DocumentStore,
+    SimpleTextTopicRecord,
+    IDItem,
+)
 from datamaestro_text.data.ir.base import DocumentRecord
 from xpmir.utils.utils import Initializable
 from xpmir.letor import Device, Random
@@ -506,6 +511,6 @@ def initialize(self):
 
     def retrieve(self, record: TopicRecord) -> List[ScoredDocument]:
         return [
-            ScoredDocument(self.store.document_ext(sd.document.get_id()), sd.score)
+            ScoredDocument(self.store.document_ext(sd.document[IDItem].id), sd.score)
             for sd in self.retriever.retrieve(record)
         ]
diff --git a/src/xpmir/rankers/full.py b/src/xpmir/rankers/full.py
index cfae228a..2027fcf5 100644
--- a/src/xpmir/rankers/full.py
+++ b/src/xpmir/rankers/full.py
@@ -91,7 +91,7 @@ def score(
             per query)
         """
         # Encode documents
-        encoded = self.scorer.encode_documents(DocumentRecord(d) for d in documents)
+        encoded = self.scorer.encode_documents(documents)
 
         # Process query by query
         new_scores = [[] for _ in documents]
diff --git a/src/xpmir/test/index/test_faiss.py b/src/xpmir/test/index/test_faiss.py
index cb828088..143cd477 100644
--- a/src/xpmir/test/index/test_faiss.py
+++ b/src/xpmir/test/index/test_faiss.py
@@ -4,6 +4,7 @@
 import pytest
 from experimaestro import ObjectStore
 from experimaestro.xpmutils import DirectoryContext
+from datamaestro_text.data.ir import TextItem, IDItem
 from xpmir.documents.samplers import HeadDocumentSampler
 from xpmir.index.faiss import FaissRetriever, IndexBackedFaiss
 from xpmir.test.utils.utils import SampleDocumentStore, SparseRandomTextEncoder
@@ -43,15 +44,15 @@ def test_faiss_indexation(tmp_path: Path, indexspec):
     retriever.initialize()
 
     documents = builder_instance.documents.documents
-    x_docs = retriever.encoder([d.text for d in documents.values()])
+    x_docs = retriever.encoder([d[TextItem].text for d in documents.values()]).value
     scores = x_docs @ x_docs.T
 
     for ix, document in enumerate(documents.values()):
-        scoredDocuments = retriever.retrieve(document.text)
+        scoredDocuments = retriever.retrieve(document)
         scoredDocuments.sort(reverse=True)
 
         expected = list(scores[ix].sort(descending=True).indices[:topk].numpy())
-        logging.warning("%s vs %s", scores[ix], scoredDocuments)
-        observed = [int(sd.document.get_id()) for sd in scoredDocuments]
+        logging.debug("%s vs %s", scores[ix], scoredDocuments)
+        observed = [int(sd.document[IDItem].id) for sd in scoredDocuments]
 
         assert expected == observed
diff --git a/src/xpmir/test/index/test_sparse.py b/src/xpmir/test/index/test_sparse.py
index 4a161966..430253cf 100644
--- a/src/xpmir/test/index/test_sparse.py
+++ b/src/xpmir/test/index/test_sparse.py
@@ -1,9 +1,10 @@
-from experimaestro import ObjectStore
-from experimaestro.xpmutils import DirectoryContext
-from pathlib import Path
 import pytest
 import torch
 import numpy as np
+from pathlib import Path
+from datamaestro_text.data.ir import TextItem, IDItem
+from experimaestro import ObjectStore
+from experimaestro.xpmutils import DirectoryContext
 
 from xpmir.index.sparse import SparseRetriever, SparseRetrieverIndexBuilder
 from xpmir.test.utils.utils import SampleDocumentStore, SparseRandomTextEncoder
@@ -37,8 +38,8 @@ def __init__(self, context, ordered_index: bool = False):
 
         self.document_store = builder_instance.documents
         self.x_docs = builder_instance.encoder(
-            [d.text for d in self.document_store.documents.values()]
-        )
+            [d[TextItem].text for d in self.document_store.documents.values()]
+        ).value
 
         # Check index
         self.index = builder.task_outputs(lambda x: x)
@@ -99,7 +100,7 @@ def test_sparse_retrieve(sparse_index: SparseIndex, retriever):
         document = sparse_index.document_store.document_int(ix)
 
         # Use the retriever
-        scoredDocuments = retriever.retrieve(document.get_text())
+        scoredDocuments = retriever.retrieve(document[TextItem].text)
 
         # scoredDocuments.sort(reverse=True)
         # scoredDocuments = scoredDocuments[:retriever.topk]
@@ -109,7 +110,7 @@ def test_sparse_retrieve(sparse_index: SparseIndex, retriever):
 
         indices = sorted.indices[: retriever.topk]
         expected = list(indices.numpy())
-        observed = [int(sd.document.get_id()) for sd in scoredDocuments]
+        observed = [int(sd.document[IDItem].id) for sd in scoredDocuments]
 
         expected_scores = sorted.values[: retriever.topk].numpy()
         observed_scores = np.array([float(sd.score) for sd in scoredDocuments])
@@ -137,8 +138,8 @@ def test_sparse_retrieve_all(retriever):
     for key, query in queries.items():
         query_results = retriever.retrieve(query)
 
-        observed = [d.document.get_id() for d in all_results[key]]
-        expected = [d.document.get_id() for d in query_results]
+        observed = [d.document[IDItem].id for d in all_results[key]]
+        expected = [d.document[IDItem].id for d in query_results]
         assert observed == expected
 
         observed_scores = [d.score for d in all_results[key]]
diff --git a/src/xpmir/test/letor/test_samplers.py b/src/xpmir/test/letor/test_samplers.py
index 321ae493..7ae8855b 100644
--- a/src/xpmir/test/letor/test_samplers.py
+++ b/src/xpmir/test/letor/test_samplers.py
@@ -33,6 +33,9 @@ def iter(
                 2, f"doc-{count}"
             )
 
+    topic_recordtype = ir.SimpleTextTopicRecord
+    document_recordtype = ir.GenericDocumentRecord
+
 
 def test_serializing_tripletbasedsampler():
     """Serialized samplers should start back from the saved state"""
@@ -57,13 +60,9 @@ def test_serializing_tripletbasedsampler():
     iter = sampler.pairwise_iter()
     iter.load_state_dict(data)
     for _, record, expected in zip(range(10), iter, samples):
-        assert expected.query.topic.get_text() == record.query.topic.get_text()
-        assert (
-            expected.positive.document.get_text() == record.positive.document.get_text()
-        )
-        assert (
-            expected.negative.document.get_text() == record.negative.document.get_text()
-        )
+        assert expected.query[ir.TextItem].text == record.query[ir.TextItem].text
+        assert expected.positive[ir.TextItem].text == record.positive[ir.TextItem].text
+        assert expected.negative[ir.TextItem].text == record.negative[ir.TextItem].text
 
 
 class GeneratedDocuments(ir.Documents):
@@ -131,6 +130,6 @@ def test_pairwise_randomspansampler():
 
     for s1, s2, _ in zip(iter1, iter2, range(10)):
         # check that they are the same with same random state
-        assert s1.query.topic.get_text() == s2.query.topic.get_text()
-        assert s1.positive.document.get_text() == s2.positive.document.get_text()
-        assert s1.negative.document.get_text() == s2.negative.document.get_text()
+        assert s1.query[ir.TextItem].text == s2.query[ir.TextItem].text
+        assert s1.positive[ir.TextItem].text == s2.positive[ir.TextItem].text
+        assert s1.negative[ir.TextItem].text == s2.negative[ir.TextItem].text
diff --git a/src/xpmir/test/letor/test_samplers_hydrator.py b/src/xpmir/test/letor/test_samplers_hydrator.py
index 52020ec4..b1a6adb4 100644
--- a/src/xpmir/test/letor/test_samplers_hydrator.py
+++ b/src/xpmir/test/letor/test_samplers_hydrator.py
@@ -26,6 +26,14 @@ def iter(
             ), ir.IDDocumentRecord.from_id(str(2 * count + 1))
             count += 1
 
+    @property
+    def topic_recordtype(self):
+        return ir.IDTopicRecord
+
+    @property
+    def document_recordtype(self):
+        return ir.IDDocumentRecord
+
 
 class FakeTextStore(TextStore):
     def __getitem__(self, key: str) -> str:
@@ -50,12 +58,12 @@ def test_pairwise_hydrator():
     h_sampler.instance()
 
     for record, n in zip(h_sampler.pairwise_iter(), range(5)):
-        assert record.query.topic.get_text() == f"T{n}"
-        assert record.positive.document.get_text() == f"D{2*n}"
-        assert record.negative.document.get_text() == f"D{2*n+1}"
+        assert record.query[ir.TextItem].text == f"T{n}"
+        assert record.positive[ir.TextItem].text == f"D{2*n}"
+        assert record.negative[ir.TextItem].text == f"D{2*n+1}"
 
     batch_it = h_sampler.pairwise_batch_iter(3)
     for record, n in zip(itertools.chain(next(batch_it), next(batch_it)), range(5)):
-        assert record.query.topic.get_text() == f"T{n}"
-        assert record.positive.document.get_text() == f"D{2*n}"
-        assert record.negative.document.get_text() == f"D{2*n+1}"
+        assert record.query[ir.TextItem].text == f"T{n}"
+        assert record.positive[ir.TextItem].text == f"D{2*n}"
+        assert record.negative[ir.TextItem].text == f"D{2*n+1}"
diff --git a/src/xpmir/test/neural/test_forward.py b/src/xpmir/test/neural/test_forward.py
index 4fd95cac..f7bc0838 100644
--- a/src/xpmir/test/neural/test_forward.py
+++ b/src/xpmir/test/neural/test_forward.py
@@ -5,6 +5,7 @@
 import torch
 from collections import defaultdict
 from experimaestro import Constant
+from datamaestro_text.data.ir import TextItem
 from xpmir.index import Index
 from xpmir.learning import Random, ModuleInitMode
 from xpmir.neural.dual import CosineDense, DotDense
@@ -22,6 +23,7 @@
     DualTextEncoder,
     TokensRepresentationOutput,
     TokenizerOptions,
+    RepresentationOutput,
 )
 from xpmir.text.adapters import MeanTextEncoder
 
@@ -63,7 +65,7 @@ def forward(self, texts: List[str], options=None):
         tok_texts = self.tokenizer.batch_tokenize(
             texts, maxlen=options.max_length, mask=True
         )
-        return TokensRepresentationOutput(tok_texts, self.embed(tok_texts.ids))
+        return TokensRepresentationOutput(self.embed(tok_texts.ids), tok_texts)
 
     def static(self) -> bool:
         return False
@@ -150,12 +152,12 @@ def static(self):
         return False
 
     def forward(self, texts: List[Tuple[str, str]]):
-        return torch.cat([self.cache[text] for text in texts])
+        return RepresentationOutput(torch.cat([self.cache[text] for text in texts]))
 
 
 @registermodel
-def joint():
-    """Joint classifier factory"""
+def cross_scorer():
+    """Cross-scorer classifier factory"""
     from xpmir.neural.cross import CrossScorer
 
     return CrossScorer(encoder=DummyDualTextEncoder()).instance()
@@ -245,7 +247,7 @@ def test_forward_consistency(modelfactory, inputfactoriescouple):
             outputs.append(model(input, None))
             maps.append(
                 {
-                    (qr.topic.get_text(), dr.document.get_text()): ix
+                    (qr[TextItem].text, dr[TextItem].text): ix
                     for ix, (qr, dr) in enumerate(zip(input.queries, input.documents))
                 }
            )
diff --git a/src/xpmir/test/rankers/test_full.py b/src/xpmir/test/rankers/test_full.py
index 13a24800..88234f0f 100644
--- a/src/xpmir/test/rankers/test_full.py
+++ b/src/xpmir/test/rankers/test_full.py
@@ -6,6 +6,7 @@
 import torch
 
 from experimaestro.notifications import TaskEnv
+from datamaestro_text.data.ir import SimpleTextTopicRecord, TextItem, IDItem
 from xpmir.learning.context import TrainerContext
 from xpmir.letor.records import TopicRecord
@@ -65,7 +66,9 @@ class _FullRetrieverRescorer(FullRetrieverRescorer):
     def retrieve(self, record: TopicRecord):
         scored_documents = [
             # Randomly get a score (and cache it)
-            ScoredDocument(d, self.scorer.cache(record.topic.get_text(), d.get_text()))
+            ScoredDocument(
+                d, self.scorer.cache(record[TextItem].text, d[TextItem].text)
+            )
             for d in self.documents
         ]
         scored_documents.sort(reverse=True)
@@ -86,7 +89,10 @@ def test_fullretrieverescorer(tmp_path: Path):
 
     # Retrieve normally
     scoredDocuments = {}
-    queries = {qid: TopicRecord.from_text(f"Query {qid}") for qid in range(NUM_QUERIES)}
+    queries = {
+        qid: SimpleTextTopicRecord.from_text(f"Query {qid}")
+        for qid in range(NUM_QUERIES)
+    }
 
     # Retrieve query per query
     for qid, query in queries.items():
@@ -100,8 +106,8 @@ def test_fullretrieverescorer(tmp_path: Path):
         results.sort(reverse=True)
         expected.sort(reverse=True)
 
-        assert [d.document.get_id() for d in expected] == [
-            d.document.get_id() for d in results
+        assert [d.document[IDItem].id for d in expected] == [
+            d.document[IDItem].id for d in results
         ], "Document IDs do not match"
         assert [d.score for d in expected] == [
             d.score for d in results
diff --git a/src/xpmir/test/utils/utils.py b/src/xpmir/test/utils/utils.py
index 7510878e..85c7af93 100644
--- a/src/xpmir/test/utils/utils.py
+++ b/src/xpmir/test/utils/utils.py
@@ -1,23 +1,21 @@
 from collections import OrderedDict, defaultdict
-from typing import ClassVar, Dict, Iterator, List, Tuple
+from typing import ClassVar, Dict, Iterator, List, Tuple, Any
 import torch
-from attrs import define
 import numpy as np
+from datamaestro.record import recordtypes
 from datamaestro_text.data.ir import (
     DocumentStore,
     GenericDocumentRecord,
-    IDItem,
-    SimpleTextItem,
     InternalIDItem,
 )
 from experimaestro import Param
 
-from xpmir.text.encoders import TextEncoder
+from xpmir.text.encoders import TextEncoder, RepresentationOutput
 
 
-@define
+@recordtypes(InternalIDItem)
 class GenericDocumentWithIDRecord(GenericDocumentRecord):
-    internal_docid: int
+    ...
 
 
 class SampleDocumentStore(DocumentStore):
@@ -29,9 +27,9 @@ def __post_init__(self):
         self.documents = OrderedDict(
             (
                 str(ix),
-                GenericDocumentWithIDRecord(
-                    IDItem(str(ix)),
-                    SimpleTextItem(f"Document {ix}"),
+                GenericDocumentWithIDRecord.create(
+                    str(ix),
+                    f"Document {ix}",
                     InternalIDItem(ix),
                 ),
             )
@@ -80,6 +78,11 @@ def __call__(self) -> torch.Tensor:
         return x
 
 
+def check_str(x: Any):
+    assert isinstance(x, str)
+    return x
+
+
 class SparseRandomTextEncoder(TextEncoder):
     # A default dict to always return the same embeddings
     MAPS: ClassVar[Dict[Tuple[int, float], Dict[str, torch.Tensor]]] = {}
@@ -104,4 +107,6 @@ def dimension(self):
 
     def forward(self, texts: List[str]) -> torch.Tensor:
         """Returns a matrix encoding the provided texts"""
-        return torch.cat([self.map[text].unsqueeze(0) for text in texts])
+        return RepresentationOutput(
+            torch.cat([self.map[check_str(text)].unsqueeze(0) for text in texts])
+        )
diff --git a/src/xpmir/text/adapters.py b/src/xpmir/text/adapters.py
index 8bfae00a..342d0419 100644
--- a/src/xpmir/text/adapters.py
+++ b/src/xpmir/text/adapters.py
@@ -1,12 +1,12 @@
 from typing import List
 from experimaestro import Param
-from .encoders import TokenizedTextEncoderBase, InputType, EncoderOutput
+from .encoders import TokenizedTextEncoderBase, InputType, RepresentationOutput
 
 
-class MeanTextEncoder(TokenizedTextEncoderBase[InputType, EncoderOutput]):
+class MeanTextEncoder(TokenizedTextEncoderBase[InputType, RepresentationOutput]):
     """Returns the mean of the word embeddings"""
 
-    encoder: Param[TokenizedTextEncoderBase[InputType, EncoderOutput]]
+    encoder: Param[TokenizedTextEncoderBase[InputType, RepresentationOutput]]
 
     def __initialize__(self, options):
         self.encoder.__initialize__(options)
@@ -18,7 +18,7 @@ def static(self):
     def dimension(self):
         return self.encoder.dimension()
 
-    def forward(self, texts: List[InputType], options=None) -> EncoderOutput:
+    def forward(self, texts: List[InputType], options=None) -> RepresentationOutput:
         emb_texts = self.encoder(texts, options=options)
         # Computes the mean over the time dimension (vocab output is batch x time x dim)
         return emb_texts.value.mean(1)
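
Reviewer note: the bulk of this patch is two mechanical API migrations. The first replaces accessor methods (topic.get_text(), document.get_id()) with typed item lookups on records (topic[TextItem].text, document[IDItem].id). Below is a minimal, self-contained sketch of that idiom; the Item, TextItem, IDItem and Record classes here are simplified stand-ins for the datamaestro_text ones, not the library's actual implementation.

from dataclasses import dataclass
from typing import Dict, Type


@dataclass
class Item:
    pass


@dataclass
class TextItem(Item):
    text: str


@dataclass
class IDItem(Item):
    id: str


class Record:
    """A bag of typed items, indexed by item class."""

    def __init__(self, *items: Item):
        self._items: Dict[Type[Item], Item] = {type(i): i for i in items}

    def __getitem__(self, item_type: Type[Item]) -> Item:
        # Accessors such as get_text()/get_id() become typed lookups
        return self._items[item_type]


topic = Record(IDItem("q1"), TextItem("what is dense retrieval?"))
print(topic[IDItem].id)      # "q1"                 (was: topic.get_id())
print(topic[TextItem].text)  # the query text       (was: topic.get_text())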
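The second migration follows from encoders returning a wrapper object rather than a bare tensor, which is why call sites now read .value before any tensor operation (see the .value additions in index/faiss.py, index/sparse.py and the test utilities). A hedged sketch under that assumption — FakeEncoder is hypothetical, and RepresentationOutput below only mirrors the .value field visible in the diff:

import torch
from dataclasses import dataclass
from typing import List


@dataclass
class RepresentationOutput:
    """Wrapper around an encoded batch, exposing the tensor as .value."""

    value: torch.Tensor


class FakeEncoder:
    """Hypothetical encoder returning a wrapped (batch x dim) tensor."""

    def __call__(self, texts: List[str]) -> RepresentationOutput:
        # A real encoder would compute embeddings; random values suffice here
        return RepresentationOutput(torch.randn(len(texts), 8))


encoder = FakeEncoder()

# Before: x = encoder(batch)        -- a bare tensor
# After:  x = encoder(batch).value  -- unwrap first, then compute
x = encoder(["a query", "another query"]).value
x = x / x.norm(2, keepdim=True, dim=1)  # e.g. the normalization in index/faiss.py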