
fix all tests
bpiwowar committed Feb 27, 2024
1 parent ccc529e commit df5c7c1
Showing 35 changed files with 186 additions and 138 deletions.
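At a glance, the commit migrates xpmir to a newer datamaestro_text record API: accessor methods such as get_text() and get_id() are replaced by typed item lookups (record[TextItem].text, record[IDItem].id), encoder outputs are unwrapped through a .value attribute, and the docs gain pages for the new conversation, alignment, and hook configurations. A minimal sketch of the item-lookup pattern, using stand-in classes rather than the real datamaestro_text API:

    from dataclasses import dataclass

    @dataclass
    class TextItem:  # stand-in for datamaestro_text's TextItem
        text: str

    @dataclass
    class IDItem:  # stand-in for datamaestro_text's IDItem
        id: str

    class Record:
        """Toy record: maps item types to item instances."""

        def __init__(self, *items):
            self._items = {type(item): item for item in items}

        def __getitem__(self, item_type):
            return self._items[item_type]

    topic = Record(IDItem("q1"), TextItem("neural ranking"))

    # Before this commit: topic.get_id(), topic[TextItem].get_text()
    # After this commit:
    print(topic[IDItem].id, topic[TextItem].text)  # -> q1 neural ranking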
19 changes: 19 additions & 0 deletions docs/source/conversation/index.rst
@@ -0,0 +1,19 @@
Conversation
============


Learning
--------

.. autoxpmconfig:: xpmir.conversation.learning.DatasetConversationEntrySampler
.. autoxpmconfig:: xpmir.conversation.learning.reformulation.ConversationRepresentationEncoder
.. autoxpmconfig:: xpmir.conversation.learning.reformulation.DecontextualizedQueryConverter




CoSPLADE
--------

.. autoxpmconfig:: xpmir.conversation.models.cosplade.AsymetricMSEContextualizedRepresentationLoss
.. autoxpmconfig:: xpmir.conversation.models.cosplade.CoSPLADE
2 changes: 2 additions & 0 deletions docs/source/index.rst
@@ -50,9 +50,11 @@ Table of Contents
evaluation
learning/index
letor/index
conversation/index
neural
hooks
text/index
misc
experiments
papers/index
pretrained
7 changes: 7 additions & 0 deletions docs/source/learning/optimization.rst
@@ -33,8 +33,14 @@ Optimizers
.. autoxpmconfig:: xpmir.learning.optim.RegexParameterFilter

.. autoxpmconfig:: xpmir.learning.optim.OptimizationHook

Hooks
*****

.. autoxpmconfig:: xpmir.learning.optim.GradientHook
.. autoxpmconfig:: xpmir.learning.optim.GradientClippingHook
.. autoxpmconfig:: xpmir.learning.optim.GradientLogHook


Parameters
----------
@@ -94,4 +100,5 @@ Base classes

.. autoxpmconfig:: xpmir.learning.base.Random
.. autoxpmconfig:: xpmir.learning.base.Sampler
.. autoxpmconfig:: xpmir.learning.base.BaseSampler
.. autoxpmconfig:: xpmir.learning.trainers.Trainer
6 changes: 6 additions & 0 deletions docs/source/letor/alignment.rst
@@ -0,0 +1,6 @@
Alignment
*********

.. autoxpmconfig:: xpmir.letor.trainers.alignment.AlignmentLoss
.. autoxpmconfig:: xpmir.letor.trainers.alignment.AlignmentTrainer
.. autoxpmconfig:: xpmir.letor.trainers.alignment.MSEAlignmentLoss
2 changes: 2 additions & 0 deletions docs/source/letor/index.rst
@@ -11,6 +11,7 @@ Learning to rank
generative
mlm
generation
alignment


Learning to rank is handled by various classes. Some are located
@@ -85,3 +86,4 @@ Adapters
.. autoxpmconfig:: xpmir.letor.samplers.hydrators.SampleTransform
.. autoxpmconfig:: xpmir.letor.samplers.hydrators.SampleHydrator
.. autoxpmconfig:: xpmir.letor.samplers.hydrators.SamplePrefixAdding
.. autoxpmconfig:: xpmir.letor.samplers.hydrators.SampleTransformList
1 change: 1 addition & 0 deletions docs/source/misc.rst
@@ -0,0 +1 @@
.. autoxpmconfig:: xpmir.utils.convert.Converter
1 change: 1 addition & 0 deletions docs/source/text/huggingface.rst
@@ -26,6 +26,7 @@ Tokenizers

.. autoxpmconfig:: xpmir.text.huggingface.tokenizers.HFListTokenizer
.. autoxpmconfig:: xpmir.text.huggingface.tokenizers.HFStringTokenizer
.. autoxpmconfig:: xpmir.text.huggingface.tokenizers.HFTokenizerAdapter

Encoders
--------
2 changes: 1 addition & 1 deletion requirements.txt
@@ -2,7 +2,7 @@

experimaestro>=1.5.0
datamaestro>=1.0.0
datamaestro_text>=2024.2.27.1
datamaestro_text>=2024.2.27.2
ir_datasets
docstring_parser
xpmir_rust == 0.20.*
6 changes: 3 additions & 3 deletions src/xpmir/conversation/models/cosplade.py
@@ -89,9 +89,9 @@ def forward(self, records: List[TopicConversationRecord]):
for ix, c_record in enumerate(records):
# Adds q_n, q_1, ..., q_{n-1}
queries.append(
[c_record[TextItem].get_text()]
[c_record[TextItem].text]
+ [
entry[TextItem].get_text()
entry[TextItem].text
for entry in c_record[ConversationHistoryItem].history
if isinstance(entry, TopicRecord)
]
@@ -105,7 +105,7 @@
):
if isinstance(item, TopicRecord) and answer is not None:
query_answer_pairs.append(
(item[TextItem].get_text(), answer[AnswerEntry].answer)
(item[TextItem].text, answer[AnswerEntry].answer)
)
pair_origins.append(ix)
elif isinstance(item, AnswerConversationRecord):
6 changes: 3 additions & 3 deletions src/xpmir/datasets/adapters.py
@@ -229,7 +229,7 @@ def execute(self):
with self.topics.open("wt") as fp:
for topic in topics:
ids.add(topic[IDItem].id)
fp.write(f"""{topic[IDItem].id}\t{topic[TextItem].get_text()}\n""")
fp.write(f"""{topic[IDItem].id}\t{topic[TextItem].text}\n""")

with self.assessments.open("wt") as fp:
for qrels in self.dataset.assessments.iter():
@@ -367,7 +367,7 @@ def execute(self):
logger.warning(
"Skipping topic %s [%s], (no assessment)",
topic[IDItem].id,
topic[TextItem].get_text(),
topic[TextItem].text,
)
continue

@@ -391,7 +391,7 @@
# don't need to worry about the threshold here
for retriever in self.retrievers:
docids.update(
sd.document.get_id() for sd in retriever.retrieve(topic.text)
sd.document[IDItem].id for sd in retriever.retrieve(topic.text)
)

# Write the document IDs
6 changes: 3 additions & 3 deletions src/xpmir/documents/samplers.py
@@ -3,7 +3,7 @@
from experimaestro import Param, Config
import torch
import numpy as np
from datamaestro_text.data.ir import DocumentStore
from datamaestro_text.data.ir import DocumentStore, TextItem
from datamaestro_text.data.ir.base import (
SimpleTextTopicRecord,
SimpleTextDocumentRecord,
@@ -139,11 +139,11 @@ def iter(random: np.random.RandomState):

while True:
record_pos_qry = next(iter)
text_pos_qry = record_pos_qry.text
text_pos_qry = record_pos_qry[TextItem].text
spans_pos_qry = self.get_text_span(text_pos_qry, random)

record_neg = next(iter)
text_neg = record_neg.text
text_neg = record_neg[TextItem].text
spans_neg = self.get_text_span(text_neg, random)

if not (spans_pos_qry and spans_neg):
5 changes: 1 addition & 4 deletions src/xpmir/evaluation.py
@@ -73,10 +73,7 @@ def print_line(fp, measure, scope, value):
def get_run(retriever: Retriever, dataset: Adhoc):
"""Returns the scored documents for each topic in a dataset"""
results = retriever.retrieve_all(
{
topic[IDItem].id: topic[TextItem].get_text()
for topic in dataset.topics.iter()
}
{topic[IDItem].id: topic[TextItem].text for topic in dataset.topics.iter()}
)
return {
qid: {sd.document[IDItem].id: sd.score for sd in scoredocs}
10 changes: 6 additions & 4 deletions src/xpmir/index/faiss.py
@@ -11,7 +11,7 @@
import numpy as np
from experimaestro import Annotated, Meta, Task, pathgenerator, Param, tqdm
import logging
from datamaestro_text.data.ir import DocumentStore
from datamaestro_text.data.ir import DocumentStore, TextItem
from xpmir.rankers import Retriever, ScoredDocument
from xpmir.learning.batchers import Batcher
from xpmir.learning import ModuleInitMode
@@ -175,7 +175,9 @@ def batch_encoder(doc_iter: Iterator[str]):
with torch.no_grad():
for batch in batchiter(self.batchsize, doc_iter):
batcher.process(
[document.text for document in batch], self.index_documents, index
[document[TextItem].text for document in batch],
self.index_documents,
index,
)

logging.info("Writing FAISS index (%d documents)", index.ntotal)
@@ -192,7 +194,7 @@ def encode(self, batch: List[str], data: List):
data.append(x)

def index_documents(self, batch: List[str], index):
x = self.encoder(batch)
x = self.encoder(batch).value
if self.normalize:
x /= x.norm(2, keepdim=True, dim=1)
index.add(np.ascontiguousarray(x.cpu().numpy()))
@@ -222,7 +224,7 @@ def retrieve(self, query: TopicRecord) -> List[ScoredDocument]:
"""Retrieves a documents, returning a list sorted by decreasing score"""
with torch.no_grad():
self.encoder.eval() # pass the model to the evaluation model
encoded_query = self.encoder([query])
encoded_query = self.encoder([query[TextItem].text]).value
if self.index.normalize:
encoded_query /= encoded_query.norm(2)

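The faiss.py hunks also show encoder call sites picking up a .value attribute: the encoder appears to now return a wrapped output rather than a raw tensor, so callers unwrap it before any tensor math. A minimal sketch under that assumption (EncoderOutput and encode are stand-ins, not xpmir's actual API):

    from dataclasses import dataclass

    import torch

    @dataclass
    class EncoderOutput:  # stand-in wrapper (assumption, not xpmir's class)
        value: torch.Tensor  # dense representations, one row per input text

    def encode(texts):  # hypothetical encoder returning a wrapped output
        return EncoderOutput(value=torch.randn(len(texts), 8))

    x = encode(["doc one", "doc two"]).value  # unwrap before tensor operations
    x = x / x.norm(2, keepdim=True, dim=1)  # L2-normalize rows, as when index.normalize is set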
7 changes: 4 additions & 3 deletions src/xpmir/index/sparse.py
@@ -86,7 +86,8 @@ def reducer(
progress,
):
for (key, _), vector in zip(
batch, self.encoder([text for _, text in batch]).cpu().detach().numpy()
batch,
self.encoder([text for _, text in batch]).value.cpu().detach().numpy(),
):
(ix,) = vector.nonzero()
query = {ix: float(v) for ix, v in zip(ix, vector[ix])}
@@ -114,7 +115,7 @@ def retrieve(self, query: TopicRecord, top_k=None) -> List[ScoredDocument]:
"""

# Build up iterators
vector = self.encoder([query])[0].cpu().detach().numpy()
vector = self.encoder([query]).value[0].cpu().detach().numpy()
(ix,) = vector.nonzero() # ix represents the position without 0 in the vector
query = {
ix: float(v) for ix, v in zip(ix, vector[ix])
@@ -215,7 +216,7 @@ def execute(self):
def encode_documents(self, batch: List[Tuple[int, DocumentRecord]]):
# Assumes for now dense vectors
vectors = (
self.encoder([d[TextItem].get_text() for _, d in batch]).value.cpu().numpy()
self.encoder([d[TextItem].text for _, d in batch]).value.cpu().numpy()
) # bs * vocab
for vector, (docid, _) in zip(vectors, batch):
(nonzero_ix,) = vector.nonzero()
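Both retrieve() and encode_documents() in sparse.py turn a dense, vocabulary-sized vector into a sparse mapping from term index to weight via nonzero(); a small worked example of that conversion:

    import numpy as np

    vector = np.array([0.0, 1.5, 0.0, 0.25], dtype=np.float32)
    (ix,) = vector.nonzero()  # indices of the nonzero entries
    query = {int(i): float(v) for i, v in zip(ix, vector[ix])}  # term index -> weight
    assert query == {1: 1.5, 3: 0.25}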
2 changes: 1 addition & 1 deletion src/xpmir/interfaces/anserini.py
@@ -132,7 +132,7 @@ def _generator(out):
json.dump(
{
"id": document[IDItem].id,
"contents": document[TextItem].get_text(),
"contents": document[TextItem].text,
},
out,
)
2 changes: 1 addition & 1 deletion src/xpmir/letor/records.py
@@ -426,6 +426,6 @@ def from_texts(
def to_texts(self) -> List[str]:
texts = []
for doc in self.documents:
texts.append(doc.document[TextItem].get_text())
texts.append(doc.document[TextItem].text)

return texts
28 changes: 15 additions & 13 deletions src/xpmir/letor/samplers/__init__.py
@@ -140,7 +140,7 @@ def document(self, doc_id):
return self._store.document_ext(doc_id)

def document_text(self, doc_id):
return self.document(doc_id).get_text()
return self.document(doc_id).text

@cache("run")
def _itertopics(
@@ -175,11 +175,11 @@ def _itertopics(
# Retrieve documents
skipped = 0
for query in tqdm(queries):
qassessments = assessments.get(query.get_id(), None)
qassessments = assessments.get(query[IDItem].id, None)
if not qassessments:
skipped += 1
self.logger.warning(
"Skipping topic %s (no assessments)", query.get_id()
"Skipping topic %s (no assessments)", query[IDItem].id
)
continue

@@ -188,41 +188,43 @@
for docno, rel in qassessments.items():
if rel > 0:
fp.write(
f"{query.get_text() if not positives else ''}"
f"{query.text if not positives else ''}"
f"\t{docno}\t0.\t{rel}\n"
)
positives.append((docno, rel, 0))

if not positives:
self.logger.warning(
"Skipping topic %s (no relevant documents)", query.get_id()
"Skipping topic %s (no relevant documents)",
query[IDItem].id,
)
skipped += 1
continue

scoreddocuments: List[ScoredDocument] = self.retriever.retrieve(
query.get_text()
query.text
)

negatives = []
for rank, sd in enumerate(scoreddocuments):
# Get the assessment (assumes not relevant)
rel = qassessments.get(sd.document.get_id(), 0)
rel = qassessments.get(sd.document[IDItem].id, 0)
if rel > 0:
continue

negatives.append((sd.document.get_id(), rel, sd.score))
fp.write(f"\t{sd.document.get_id()}\t{sd.score}\t{rel}\n")
negatives.append((sd.document[IDItem].id, rel, sd.score))
fp.write(f"\t{sd.document[IDItem].id}\t{sd.score}\t{rel}\n")

if not negatives:
self.logger.warning(
"Skipping topic %s (no negatives documents)", query.get_id()
"Skipping topic %s (no negatives documents)",
query[IDItem].id,
)
skipped += 1
continue

assert len(positives) > 0 and len(negatives) > 0
yield query.get_text(), positives, negatives
yield query.text, positives, negatives

# Finally, move the cache file in place...
self.logger.info(
@@ -332,7 +334,7 @@ def sample(self, samples: List[Tuple[str, int, float]]):
while text is None:
docid, rel, score = samples[self.random.randint(0, len(samples))]
document = self.document(docid).add(ScoredItem(score))
text = document[TextItem].get_text()
text = document[TextItem].text
return document

def pairwise_iter(self) -> SerializableIterator[PairwiseRecord, Any]:
@@ -616,7 +618,7 @@ def execute(self):
positives = []
negatives = []
scoreddocuments: List[ScoredDocument] = self.retriever.retrieve(
query.get_text()
query.text
)

for rank, sd in enumerate(scoreddocuments):
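The _itertopics hunks above follow a common sampling pattern: positives are assessed documents with rel > 0, negatives are retrieved documents that are unassessed or judged non-relevant, and queries lacking either side are skipped. A simplified sketch with toy data (not the actual sampler code):

    qassessments = {"d1": 2, "d2": 0}  # docno -> relevance judgment for one query
    retrieved = [("d2", 13.1), ("d3", 9.7)]  # (docno, retriever score)

    positives = [(docno, rel, 0.0) for docno, rel in qassessments.items() if rel > 0]
    negatives = [
        (docno, qassessments.get(docno, 0), score)
        for docno, score in retrieved
        if qassessments.get(docno, 0) <= 0  # unassessed documents count as negatives
    ]
    assert positives and negatives  # queries missing either side are skipped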
(Diff truncated: the remaining changed files are not shown.)