-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #26 from experimaestro/features/generative
Basis for generative retrieval (among other things)
- Loading branch information
Showing
9 changed files
with
428 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,5 @@ repos: | |
hooks: | ||
- id: flake8 | ||
additional_dependencies: | ||
- flake8-black | ||
- flake8-print | ||
- flake8-fixme |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
from torch import nn | ||
import numpy as np | ||
from experimaestro import Param | ||
import logging | ||
|
||
from xpmir.letor.samplers import PairwiseSampler | ||
from xpmir.letor.records import BaseRecords, PairwiseRecords | ||
from xpmir.letor.trainers import TrainerContext, LossTrainer | ||
from xpmir.learning.context import Loss | ||
from xpmir.utils.iter import MultiprocessSerializableIterator | ||
from xpmir.utils.utils import foreach, easylog | ||
|
||
logger = easylog() | ||
|
||
|
||
class PairwiseGenerativeLoss(nn.Module):
    """Generic loss for generative models

    Subclasses set :attr:`NAME` and implement :meth:`compute`;
    :meth:`process` then registers the computed value on the trainer context.
    """

    NAME = "?"

    weight: Param[float] = 1.0
    """The weight :math:`w` with which the loss is multiplied (useful when
    combining with other ones)"""

    def initialize(self):
        """Prepare the loss before training (does nothing by default)

        ``GenerativeTrainer.initialize`` calls this method on its loss and on
        every hooked loss, so the base class provides a default that
        subclasses can override.
        """

    def compute(self, records, context):
        """Compute the loss value for a batch of records

        :param records: the batch of records
        :param context: the trainer context
        :returns: the loss value reported by :meth:`process`
        :raises NotImplementedError: if the subclass does not override it
        """
        # Fail loudly instead of letting `process` register a `None` loss
        raise NotImplementedError(
            f"{type(self).__name__} must implement compute(records, context)"
        )

    def process(self, records: BaseRecords, context: TrainerContext):
        """Compute the loss for a batch and add it to the trainer context

        :param records: the batch of records
        :param context: the trainer context that collects the losses
        """
        value = self.compute(records, context)  # tensor shape [bs]
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"Loss: {value}")
        context.add_loss(Loss(f"pair-{self.NAME}", value, self.weight))
|
||
|
||
class GenerativeTrainer(LossTrainer):
    """Trainer that feeds pairwise batches to a generative loss"""

    # Loss whose `process` method is called on every training batch
    loss: Param[PairwiseGenerativeLoss]

    sampler: Param[PairwiseSampler]
    """The pairwise sampler"""

    def initialize(self, random: np.random.RandomState, context: TrainerContext):
        """Initialize the losses, the sampler and the batch iterator

        :param random: the random state forwarded to the parent trainer and
            to the sampler
        :param context: the trainer context, also queried for loss hooks
        """
        super().initialize(random, context)

        # Main loss first, then the losses registered as hooks on the
        # context (the hooks may later be used to change the sampling target)
        self.loss.initialize()
        hooked_losses = context.hooks(PairwiseGenerativeLoss)
        foreach(hooked_losses, lambda hooked: hooked.initialize())

        # Batches are drawn from the sampler through a serializable iterator
        self.sampler.initialize(random)
        batches = self.sampler.pairwise_batch_iter(self.batch_size)
        self.sampler_iter = MultiprocessSerializableIterator(batches)

    def train_batch(self, records: PairwiseRecords):
        """Run the forward pass for one batch by delegating to the loss

        :param records: the pairwise records of the batch
        """
        self.loss.process(records, self.context)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from typing import List | ||
from abc import abstractmethod | ||
|
||
import torch | ||
from xpmir.learning.optim import Module | ||
from xpmir.utils.utils import easylog | ||
|
||
logger = easylog() | ||
|
||
|
||
class StepwiseGenerator:
    """Helper that produces a sequence one token at a time"""

    @abstractmethod
    def init(self, texts: List[str]) -> torch.Tensor:
        """Start the generation for a batch of texts

        :param texts: the input texts
        :returns: the distribution over the first generated tokens (BxV)
        """
        ...

    @abstractmethod
    def step(self, token_ids: torch.LongTensor) -> torch.Tensor:
        """Advance the generation by one token

        :param token_ids: the last generated tokens (B)
        :returns: the distribution over the next tokens (BxV)
        """
        ...
|
||
|
||
class IdentifierGenerator(Module):
    """Models that generate an identifier given a document or a query"""

    def __initialize__(self):
        """Initialization hook; nothing to set up at this level"""

    @abstractmethod
    def stepwise_iterator(self) -> StepwiseGenerator:
        """Return a :class:`StepwiseGenerator`; to be implemented by subclasses"""
Oops, something went wrong.