Merge pull request #27 from experimaestro/features/condenser_pretraining
Features/co-condenser pretraining
bpiwowar authored Dec 6, 2023
2 parents 3b37201 + f3a571c commit 2cb8911
Showing 2 changed files with 95 additions and 30 deletions.
88 changes: 59 additions & 29 deletions src/xpmir/documents/samplers.py
@@ -1,11 +1,12 @@
from typing import Optional, Tuple, Iterator, Any
-from experimaestro import Config, Param
+from experimaestro import Param, Config
import torch
import numpy as np
from datamaestro_text.data.ir import DocumentStore
from datamaestro_text.data.ir.base import TextTopic, TextDocument
from xpmir.letor import Random
-from xpmir.letor.records import Document, ProductRecords, Query
-from xpmir.letor.samplers import BatchwiseSampler
+from xpmir.letor.records import Document, PairwiseRecord, ProductRecords, Query
+from xpmir.letor.samplers import BatchwiseSampler, PairwiseSampler
from xpmir.utils.iter import RandomSerializableIterator, SerializableIterator


@@ -78,7 +79,7 @@ def iter(self, count) -> Iterator[str]:
yield self.documents.document(int(docid)).text


-class BatchwiseRandomSpanSampler(DocumentSampler, BatchwiseSampler):
+class RandomSpanSampler(DocumentSampler, BatchwiseSampler, PairwiseSampler):
"""This sampler uses positive samples coming from the same documents
and negative ones coming from others
@@ -92,6 +93,56 @@ class BatchwiseRandomSpanSampler(DocumentSampler, BatchwiseSampler):
max_spansize: Param[int] = 1000
"""Maximum span size in number of characters"""

    def get_text_span(self, text, random):
        # return the two spans of text
        spanlen = min(self.max_spansize, len(text) // 2)

        max_start1 = len(text) - spanlen * 2
        start1 = random.randint(0, max_start1) if max_start1 > 0 else 0
        end1 = start1 + spanlen
        if start1 > 0 and text[start1 - 1] != " ":
            start1 = text.find(" ", start1) + 1
        if text[end1] != " ":
            end1 = text.rfind(" ", 0, end1)

        max_start2 = len(text) - spanlen
        start2 = random.randint(end1, max_start2) if max_start2 > end1 else end1
        end2 = start2 + spanlen
        if text[start2 - 1] != " ":
            start2 = text.find(" ", start2) + 1
        if text[end2 - 1] != " ":
            end2 = text.rfind(" ", 0, end2)

        # Reject wrong samples
        if end2 <= start2 or end1 <= start1:
            return None

        return (text[start1:end1], text[start2:end2])

    def pairwise_iter(self) -> SerializableIterator[PairwiseRecord, Any]:
        def iter(random: np.random.RandomState):
            iter = self.documents.iter_sample(lambda m: random.randint(0, m))

            while True:
                record_pos_qry = next(iter)
                text_pos_qry = record_pos_qry.text
                spans_pos_qry = self.get_text_span(text_pos_qry, random)

                record_neg = next(iter)
                text_neg = record_neg.text
                spans_neg = self.get_text_span(text_neg, random)

                if not (spans_pos_qry and spans_neg):
                    continue

                yield PairwiseRecord(
                    Query(TextTopic(spans_pos_qry[0])),
                    Document(TextDocument(spans_pos_qry[1])),
                    Document(TextDocument(spans_neg[random.randint(0, 2)])),
                )

        return RandomSerializableIterator(self.random, iter)

def batchwise_iter(
self, batch_size: int
) -> SerializableIterator[ProductRecords, Any]:
@@ -106,32 +157,11 @@ def iterator(random: np.random.RandomState):
                while len(batch) < batch_size:
                    record = next(iter)
                    text = record.text
-                    spanlen = min(self.max_spansize, len(text) // 2)
-
-                    max_start1 = len(text) - spanlen * 2
-                    start1 = random.randint(0, max_start1) if max_start1 > 0 else 0
-                    end1 = start1 + spanlen
-                    if start1 > 0 and text[start1 - 1] != " ":
-                        start1 = text.find(" ", start1) + 1
-                    if text[end1] != " ":
-                        end1 = text.rfind(" ", 0, end1)
-
-                    max_start2 = len(text) - spanlen
-                    start2 = (
-                        random.randint(end1, max_start2) if max_start2 > end1 else end1
-                    )
-                    end2 = start2 + spanlen
-                    if text[start2 - 1] != " ":
-                        start2 = text.find(" ", start2) + 1
-                    if text[end2 - 1] != " ":
-                        end2 = text.rfind(" ", 0, end2)
-
-                    # Rejet wrong samples
-                    if end2 <= start2 or end1 <= start1:
+                    res = self.get_text_span(text, random)
+                    if not res:
                        continue

-                    batch.add_topics(Query(None, text[start1:end1]))
-                    batch.add_documents(Document(None, text[start2:end2], 0))
+                    batch.add_topics(Query(None, res[0]))
+                    batch.add_documents(Document(None, res[1], 0))
                batch.set_relevances(relevances)
                yield batch

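As a quick orientation (not part of the commit): the new pairwise path can be consumed much like the test added below does. The following is a minimal, hypothetical sketch; ToyStore is a stand-in that mirrors the FakeDocumentStore defined in the test file, and only the accessors exercised by that test are assumed.

import numpy as np
import datamaestro_text.data.ir as ir
from datamaestro_text.data.ir.base import GenericDocument
from experimaestro import Param
from xpmir.documents.samplers import RandomSpanSampler


class ToyStore(ir.DocumentStore):
    # Stand-in store, modelled on the FakeDocumentStore used in the tests below
    id: Param[str] = ""

    @property
    def documentcount(self):
        return 10

    def document_int(self, internal_docid: int):
        return GenericDocument(str(internal_docid), f"D{internal_docid} " * 10)


sampler = RandomSpanSampler(documents=ToyStore(), max_spansize=1000).instance()
sampler.initialize(np.random.RandomState(seed=0))

for record, _ in zip(sampler.pairwise_iter(), range(3)):
    # the pseudo-query and the positive are two spans of the same document;
    # the negative is a span drawn from another document
    print(record.query.topic.get_text())
    print(record.positive.document.get_text())
    print(record.negative.document.get_text())

The max_spansize value shown is simply the default declared in the diff above.
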
37 changes: 36 additions & 1 deletion src/xpmir/test/letor/test_samplers.py
@@ -1,13 +1,16 @@
import pytest
import numpy as np
from typing import Iterator, Tuple
from experimaestro import Param
import datamaestro_text.data.ir as ir
-from datamaestro_text.data.ir.base import GenericTopic, GenericDocument
+from datamaestro_text.data.ir.base import GenericTopic, GenericDocument, Document
from xpmir.rankers import Retriever
from xpmir.letor.samplers import (
TrainingTriplets,
TripletBasedSampler,
ModelBasedSampler,
)
from xpmir.documents.samplers import RandomSpanSampler

# ---- Serialization

@@ -88,3 +91,35 @@ def test_modelbasedsampler():

    for a in sampler._itertopics():
        pass


class FakeDocumentStore(ir.DocumentStore):
    id: Param[str] = ""

    @property
    def documentcount(self):
        return 10

    def document_int(self, internal_docid: int) -> Document:
        return GenericDocument(str(internal_docid), f"D{internal_docid} " * 10)


def test_pairwise_randomspansampler():
    documents = FakeDocumentStore()

    sampler1 = RandomSpanSampler(documents=documents).instance()

    sampler2 = RandomSpanSampler(documents=documents).instance()

    random1 = np.random.RandomState(seed=0)
    random2 = np.random.RandomState(seed=0)
    sampler1.initialize(random1)
    sampler2.initialize(random2)
    iter1 = sampler1.pairwise_iter()
    iter2 = sampler2.pairwise_iter()

    for s1, s2, _ in zip(iter1, iter2, range(10)):
        # check that the records are the same with the same random state
        assert s1.query.topic.get_text() == s2.query.topic.get_text()
        assert s1.positive.document.get_text() == s2.positive.document.get_text()
        assert s1.negative.document.get_text() == s2.negative.document.get_text()
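
The batchwise side is refactored rather than newly tested; a minimal sketch (again hypothetical, not part of the commit) of how it could be driven, reusing the FakeDocumentStore and imports from the test above and assuming batchwise_iter can be consumed the same way as pairwise_iter:

sampler = RandomSpanSampler(documents=FakeDocumentStore()).instance()
sampler.initialize(np.random.RandomState(seed=0))

for batch, _ in zip(sampler.batchwise_iter(batch_size=4), range(2)):
    # each batch is a ProductRecords built from document spans, with
    # relevances assigned via set_relevances inside batchwise_iter
    pass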
