
Commit

Merge pull request #26 from experimaestro/features/generative
Basis for generative retrieval (among other things)
bpiwowar authored Dec 4, 2023
2 parents 37a87e0 + 7eb7f5f commit 3c20161
Showing 9 changed files with 428 additions and 7 deletions.
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
@@ -14,6 +14,5 @@ repos:
hooks:
- id: flake8
additional_dependencies:
- flake8-black
- flake8-print
- flake8-fixme
10 changes: 10 additions & 0 deletions src/xpmir/learning/context.py
@@ -120,6 +120,16 @@ class TrainingHook(Hook):
pass


class ValidationHook(Hook):
"""Base class for all the validation hooks"""

def after(self, state: "TrainerContext"):
"""Called after a validation step"""

def before(self, state: "TrainerContext"):
"""Called before a validation step"""


class StepTrainingHook(TrainingHook):
"""Base class for hooks called at each step (before/after)"""

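Illustrative sketch (not part of this commit): a minimal ValidationHook subclass. Only the before/after interface and the import path come from the diff above; the class name and the timing logic are hypothetical.

import time

from xpmir.learning.context import TrainerContext, ValidationHook
from xpmir.utils.utils import easylog

logger = easylog()


class ValidationTimingHook(ValidationHook):
    """Hypothetical hook logging how long a validation step takes"""

    def before(self, state: TrainerContext):
        # called just before the validation step starts
        self._start = time.perf_counter()

    def after(self, state: TrainerContext):
        # called once the validation step is over
        logger.info("Validation took %.1fs", time.perf_counter() - self._start)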
15 changes: 15 additions & 0 deletions src/xpmir/learning/optim.py
@@ -30,6 +30,21 @@ def __call__(self, parameters) -> torch.optim.Optimizer:
raise NotImplementedError()


class SGD(Optimizer):
"""Wrapper for SGD optimizer in Pytorch"""

lr: Param[float] = 1e-5
"""Learning rate"""

weight_decay: Param[float] = 0.0
"""Weight decay (L2)"""

def __call__(self, parameters):
from torch.optim import SGD

return SGD(parameters, lr=self.lr, weight_decay=self.weight_decay)


class Adam(Optimizer):
"""Wrapper for Adam optimizer in PyTorch"""

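Illustrative sketch (not part of this commit): instantiating the new SGD wrapper. Only the lr and weight_decay parameters come from the diff; the toy model and the direct instantiation outside an experimaestro experiment are assumptions.

import torch.nn as nn

from xpmir.learning.optim import SGD

# any torch module stands in for a real ranker here
model = nn.Linear(10, 1)

# experimaestro configurations are built from their Param fields
optimizer_factory = SGD(lr=1e-3, weight_decay=1e-4)

# calling the wrapper returns a plain torch.optim.SGD over the given parameters
sgd = optimizer_factory(model.parameters())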
23 changes: 23 additions & 0 deletions src/xpmir/learning/parameters.py
@@ -1,3 +1,4 @@
import re
from typing import Optional
from abc import ABC, abstractmethod
import torch.nn as nn
@@ -48,6 +49,28 @@ def iter(self) -> Iterator[ParameterElement]:
...


class RegexParametersIterator(ParametersIterator):
"""Itertor over all the parameters which match the given regex"""

regex: Param[str]
"""The regex expression"""

model: Param[Module]
"""The model we want to select the parameters from"""

def __post_init__(self):
self._regex = re.compile(self.regex)

def should_pick(self, name: str) -> bool:
"""given the name of the str, return true if the regex expression
matches"""
return bool(self._regex.search(name))

def iter(self) -> Iterator[ParameterElement]:
for name, parameters in self.model.named_parameters():
yield ParameterElement(name, parameters, self.should_pick(name))


class InverseParametersIterator(ParametersIterator):
"""Inverse the selection of a parameter iterator"""

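Illustrative sketch (not part of this commit): selecting parameters by name with the new RegexParametersIterator. Only the regex and model parameters come from the diff; the toy torch module is a stand-in (in practice an xpmir Module configuration would be used).

import torch.nn as nn

from xpmir.learning.parameters import RegexParametersIterator

# a plain torch module stands in for a real xpmir Module here
model = nn.Linear(10, 1)

# flag every parameter whose name contains "bias"
selector = RegexParametersIterator(regex=r"bias", model=model)

# iter() yields one ParameterElement per named parameter, with the selection
# flag computed by should_pick(name)
elements = list(selector.iter())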
22 changes: 16 additions & 6 deletions src/xpmir/letor/learner.py
@@ -1,16 +1,13 @@
import logging
import json
from pathlib import Path
from typing import Dict, Iterator
from typing import Dict, Iterator, List
from datamaestro_text.data.ir import Adhoc
from experimaestro import Param, pathgenerator, Annotated
import numpy as np
from xpmir.utils.utils import easylog
from xpmir.utils.utils import easylog, foreach
from xpmir.evaluation import evaluate
from xpmir.learning.context import (
TrainState,
TrainerContext,
)
from xpmir.learning.context import TrainState, TrainerContext, ValidationHook
from xpmir.rankers import (
Retriever,
)
@@ -55,6 +52,9 @@ class ValidationListener(LearnerListener):
"""Number of epochs without improvement after which we stop learning.
Should be a multiple of validation_interval or 0 (no early stopping)"""

hooks: Param[List[ValidationHook]] = []
"""The list of the hooks during the validation"""

def __validate__(self):
assert (
self.early_stop % self.validation_interval == 0
@@ -123,6 +123,11 @@ def __call__(self, state: TrainState):
if self.should_stop(state.epoch - 1) == LearnerListenerStatus.STOP:
return LearnerListenerStatus.STOP

foreach(
self.hooks,
lambda hook: hook.before(self.context),
)

if state.epoch % self.validation_interval == 0:
# Compute validation metrics
means, details = evaluate(
@@ -161,5 +166,10 @@ def __call__(self, state: TrainState):
with self.info.open("wt") as fp:
json.dump(self.top, fp)

foreach(
self.hooks,
lambda hook: hook.after(self.context),
)

# Early stopping?
return self.should_stop()
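Illustrative sketch (not part of this commit): wiring validation hooks into the listener. Only the hooks parameter and the validation_interval reference come from the diff above; the remaining constructor arguments are elided, and the hook class is the hypothetical one sketched earlier.

from xpmir.letor.learner import ValidationListener

listener = ValidationListener(
    # ... dataset, retriever and the other usual listener parameters ...
    validation_interval=8,
    hooks=[ValidationTimingHook()],  # called before/after each validation step
)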
43 changes: 43 additions & 0 deletions src/xpmir/letor/samplers/__init__.py
@@ -1,3 +1,4 @@
import random
from pathlib import Path
from typing import Iterator, List, Tuple, Dict, Any
import numpy as np
@@ -343,6 +344,48 @@ def iter(random):
return RandomSerializableIterator(self.random, iter)


class TripletBasedInBatchNegativeSampler(PairwiseSampler):
"""An in-batch negative sampler which generate the triplets,
which use the postives of the other in batch as the negatives"""

sampler: Param[PairwiseSampler]
"""The base pairwise sampler"""

batch_size: Param[int]
"""How many triplets to be used for building the ibn"""

def initialize(self, random):
super().initialize(random)
self.sampler.initialize(random)

def pairwise_iter(self) -> SerializableIterator[PairwiseRecord]:
def iter(pair_iter):
while True:
topics = []
positives = []
for _, record in zip(range(self.batch_size), pair_iter):
topics.append(record.query)
positives.append(record.positive)
all_qry = [
topic for topic in topics for _ in range(self.batch_size - 1)
]
all_pos = [pos for pos in positives for _ in range(self.batch_size - 1)]
pos_as_neg = positives * self.batch_size
pos_index = [(self.batch_size + 1) * i for i in range(self.batch_size)]
all_neg = [
doc for i, doc in enumerate(pos_as_neg) if i not in pos_index
]

# shuffle so that occurrences of the same document do not end up too close together
mapping = list(zip(all_qry, all_pos, all_neg))
for _ in range(30000):
random.shuffle(mapping)
for (topic, positive, negative) in mapping:
yield PairwiseRecord(topic, positive, negative)

return SerializableIteratorAdapter(self.sampler.pairwise_iter(), iter)


class PairwiseInBatchNegativesSampler(BatchwiseSampler):
"""An in-batch negative sampler constructured from a pairwise one"""

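Worked example (not part of this commit) of the index arithmetic in TripletBasedInBatchNegativeSampler, run on toy strings with batch_size = 3:

batch_size = 3
topics = ["q0", "q1", "q2"]
positives = ["d0", "d1", "d2"]

# each query/positive pair is repeated once per other element of the batch
all_qry = [t for t in topics for _ in range(batch_size - 1)]      # q0 q0 q1 q1 q2 q2
all_pos = [p for p in positives for _ in range(batch_size - 1)]   # d0 d0 d1 d1 d2 d2

# candidate negatives: all positives of the batch, minus each query's own positive
pos_as_neg = positives * batch_size                               # d0 d1 d2 d0 d1 d2 d0 d1 d2
pos_index = [(batch_size + 1) * i for i in range(batch_size)]     # 0, 4, 8
all_neg = [d for i, d in enumerate(pos_as_neg) if i not in pos_index]  # d1 d2 d0 d2 d0 d1

# 6 triplets: each query is paired with its positive and another query's positive as negative
triplets = list(zip(all_qry, all_pos, all_neg))
# [('q0', 'd0', 'd1'), ('q0', 'd0', 'd2'), ('q1', 'd1', 'd0'),
#  ('q1', 'd1', 'd2'), ('q2', 'd2', 'd0'), ('q2', 'd2', 'd1')]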
55 changes: 55 additions & 0 deletions src/xpmir/letor/trainers/generative.py
@@ -0,0 +1,55 @@
from torch import nn
import numpy as np
from experimaestro import Param
import logging

from xpmir.letor.samplers import PairwiseSampler
from xpmir.letor.records import BaseRecords, PairwiseRecords
from xpmir.letor.trainers import TrainerContext, LossTrainer
from xpmir.learning.context import Loss
from xpmir.utils.iter import MultiprocessSerializableIterator
from xpmir.utils.utils import foreach, easylog

logger = easylog()


class PairwiseGenerativeLoss(nn.Module):
"""Generic loss for generative models"""

NAME = "?"

weight: Param[float] = 1.0
"""The weight :math:`w` with which the loss is multiplied (useful when
combining with other ones)"""

def compute(self, records, context):
pass

def process(self, records: BaseRecords, context: TrainerContext):
value = self.compute(records, context) # tensor shape [bs]
if logger.isEnabledFor(logging.DEBUG):
logger.debug(f"Loss: {value}")
context.add_loss(Loss(f"pair-{self.NAME}", value, self.weight))


class GenerativeTrainer(LossTrainer):
loss: Param[PairwiseGenerativeLoss]

sampler: Param[PairwiseSampler]
"""The pairwise sampler"""

def initialize(self, random: np.random.RandomState, context: TrainerContext):
super().initialize(random, context)
self.loss.initialize()
foreach(
context.hooks(PairwiseGenerativeLoss), lambda loss: loss.initialize()
)  # this hook can be used later if we need to change the sampling target

self.sampler.initialize(random)
self.sampler_iter = MultiprocessSerializableIterator(
self.sampler.pairwise_batch_iter(self.batch_size)
)

def train_batch(self, records: PairwiseRecords):
# forward pass: compute the loss and register it on the training context
self.loss.process(records, self.context)
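Illustrative sketch (not part of this commit): a concrete loss plugged into the new trainer. Only the NAME/weight/compute interface, the loss and sampler parameters, and the reference to batch_size come from the diff; the class name, the placeholder loss body and the sampler variable are assumptions.

import torch

from xpmir.letor.trainers.generative import GenerativeTrainer, PairwiseGenerativeLoss


class ZeroGenerativeLoss(PairwiseGenerativeLoss):
    """Hypothetical loss returning a constant, only to show the expected interface"""

    NAME = "zero"

    def compute(self, records, context):
        # a real loss returns one value per pairwise record (tensor of shape [bs]);
        # this placeholder just returns a single zero
        return torch.zeros(1)


trainer = GenerativeTrainer(
    loss=ZeroGenerativeLoss(),
    sampler=my_pairwise_sampler,  # any PairwiseSampler configuration (assumed to exist)
    batch_size=16,  # assumed to be a parameter inherited from LossTrainer
)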
35 changes: 35 additions & 0 deletions src/xpmir/neural/generative/__init__.py
@@ -0,0 +1,35 @@
from typing import List
from abc import abstractmethod

import torch
from xpmir.learning.optim import Module
from xpmir.utils.utils import easylog

logger = easylog()


class StepwiseGenerator:
"""Utility class for generating one token at a time"""

@abstractmethod
def init(self, texts: List[str]) -> torch.Tensor:
"""Returns the distribution over the first generated tokens (BxV)
given the texts"""
pass

@abstractmethod
def step(self, token_ids: torch.LongTensor) -> torch.Tensor:
"""Returns the distribution over next tokens (BxV), given the last
generates ones (B)"""
pass


class IdentifierGenerator(Module):
"""Models that generate an identifier given a document or a query"""

def __initialize__(self):
pass

@abstractmethod
def stepwise_iterator(self) -> StepwiseGenerator:
pass
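Illustrative sketch (not part of this commit): how a StepwiseGenerator could drive greedy decoding. Only init() and step() come from the interface above; the decoding loop and the function name are ours.

from typing import List

import torch

from xpmir.neural.generative import StepwiseGenerator


def greedy_decode(generator: StepwiseGenerator, texts: List[str], max_len: int = 8):
    # distribution over the first tokens, shape (B, V)
    distribution = generator.init(texts)
    tokens = [distribution.argmax(dim=-1)]      # most likely first token per text
    for _ in range(max_len - 1):
        # distribution over the next tokens given the last generated ones
        distribution = generator.step(tokens[-1])
        tokens.append(distribution.argmax(dim=-1))
    return torch.stack(tokens, dim=1)           # token ids, shape (B, max_len)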
