Merge branch 'features/module_init'
bpiwowar committed Dec 6, 2023
2 parents 2cb8911 + 560883a commit c9e4208
Showing 26 changed files with 148 additions and 118 deletions.
2 changes: 2 additions & 0 deletions docs/source/learning/index.rst
@@ -46,3 +46,5 @@ Listeners can be used to monitor the learning process

.. autoxpmconfig:: xpmir.learning.learner.LearnerListener
:members: __call__

.. autoxpmconfig:: xpmir.learning.context.ValidationHook
3 changes: 2 additions & 1 deletion docs/source/learning/optimization.rst
@@ -23,6 +23,7 @@ Optimizers


.. autoxpmconfig:: xpmir.learning.optim.Optimizer
.. autoxpmconfig:: xpmir.learning.optim.SGD
.. autoxpmconfig:: xpmir.learning.optim.Adam
.. autoxpmconfig:: xpmir.learning.optim.AdamW

@@ -45,6 +46,7 @@ The classes below allow to select a subset of parameters.
.. autoxpmconfig:: xpmir.learning.parameters.InverseParametersIterator
.. autoxpmconfig:: xpmir.learning.parameters.ParametersIterator
.. autoxpmconfig:: xpmir.learning.parameters.SubParametersIterator
.. autoxpmconfig:: xpmir.learning.parameters.RegexParametersIterator

Freezing
********
@@ -59,7 +61,6 @@ Loading
.. autoxpmconfig:: xpmir.learning.parameters.PartialModuleLoader
.. autoxpmconfig:: xpmir.learning.parameters.SubModuleLoader


Batching
--------

3 changes: 3 additions & 0 deletions docs/source/letor/pairwise.rst
@@ -17,6 +17,8 @@ Trainer

.. autoxpmconfig:: xpmir.letor.trainers.pairwise.PairwiseTrainer
.. autoxpmconfig:: xpmir.letor.trainers.pairwise.DuoPairwiseTrainer
.. autoxpmconfig:: xpmir.letor.trainers.generative.GenerativeTrainer


Losses
------
@@ -43,5 +45,6 @@ Pairwise
.. autoxpmconfig:: PairwiseSampleDatasetFromTSV
.. autoxpmconfig:: PairwiseSamplerFromTSV
.. autoxpmconfig:: ModelBasedHardNegativeSampler
.. autoxpmconfig:: TripletBasedInBatchNegativeSampler

.. autoxpmconfig:: xpmir.letor.samplers.hydrators.PairwiseTransformAdapter
10 changes: 10 additions & 0 deletions docs/source/neural.rst
@@ -89,6 +89,16 @@ Sparse Models
.. autoxpmconfig:: xpmir.neural.splade.MaxAggregation
.. autoxpmconfig:: xpmir.neural.splade.SumAggregation

Generative Models
=================

.. autoxpmconfig:: xpmir.neural.generative.IdentifierGenerator

.. autoxpmconfig:: xpmir.neural.generative.hf.LoadFromT5

.. autoxpmconfig:: xpmir.neural.generative.hf.T5IdentifierGenerator


From Huggingface
================

3 changes: 2 additions & 1 deletion src/xpmir/index/faiss.py
@@ -13,6 +13,7 @@
from datamaestro_text.data.ir import DocumentStore
from xpmir.rankers import Retriever, ScoredDocument
from xpmir.learning.batchers import Batcher
from xpmir.learning import ModuleInitMode
from xpmir.text.encoders import TextEncoder
from xpmir.letor import (
Device,
@@ -130,7 +131,7 @@ def _execute(self, device_information: DeviceInformation):
step_iter = tqdm(total=2, desc="Building the FAISS index")

# Initializations
self.encoder.initialize()
self.encoder.initialize(ModuleInitMode.DEFAULT.to_options())
index = faiss.index_factory(
self.encoder.dimension, self.indexspec, faiss.METRIC_INNER_PRODUCT
)
5 changes: 3 additions & 2 deletions src/xpmir/index/sparse.py
@@ -16,6 +16,7 @@
Constant,
)
from datamaestro_text.data.ir import Document, DocumentStore
from xpmir.learning import ModuleInitMode
from xpmir.learning.batchers import Batcher
from xpmir.utils.utils import batchiter, easylog
from xpmir.letor import Device, DEFAULT_DEVICE
@@ -71,7 +72,7 @@ class SparseRetriever(Retriever):

def initialize(self):
super().initialize()
self.encoder.initialize()
self.encoder.initialize(ModuleInitMode.RANDOM.to_options(None))
self.index.initialize(self.in_memory)

def retrieve_all(self, queries: Dict[str, str]) -> Dict[str, List[ScoredDocument]]:
@@ -173,7 +174,7 @@ def execute(self):
f"Load the encoder and transfer to the target device {self.device.value}"
)

self.encoder.initialize()
self.encoder.initialize(ModuleInitMode.RANDOM.to_options(None))
self.encoder.to(self.device.value).eval()

batcher = self.batcher.initialize(self.batch_size)
2 changes: 1 addition & 1 deletion src/xpmir/learning/__init__.py
@@ -1,3 +1,3 @@
# flake8: noqa: F401
from .base import Random, Sampler
from .optim import Module
from .optim import Module, ModuleInitMode, ModuleInitOptions
8 changes: 4 additions & 4 deletions src/xpmir/learning/learner.py
@@ -16,7 +16,7 @@
from xpmir.context import Hook, InitializationHook
from xpmir.utils.utils import EasyLogger, easylog, foreach
from xpmir.learning.devices import DEFAULT_DEVICE, Device, DeviceInformation
from xpmir.learning import Random
from xpmir.learning import Random, ModuleInitMode
from xpmir.learning.trainers import Trainer
from xpmir.learning.context import (
StepTrainingHook,
@@ -223,15 +223,15 @@ def device_execute(self, device_information: DeviceInformation):
torch.cuda.manual_seed_all(seed)

# Initialize the scorer and trainer
self.logger.info("Scorer initialization")
self.model.initialize()
self.logger.info("model initialization")
self.model.initialize(ModuleInitMode.DEFAULT.to_options(self.random.state))

# Initialize the context and the listeners
self.trainer.initialize(self.random.state, self.context)
for listener in self.listeners:
listener.initialize(self, self.context)

self.logger.info("Moving to device %s", device_information.device)
self.logger.info("Moving model to device %s", device_information.device)
self.model.to(device_information.device)
self.trainer.to(device_information.device)
num_training_steps = self.max_epochs * self.steps_per_epoch
44 changes: 40 additions & 4 deletions src/xpmir/learning/optim.py
@@ -1,6 +1,9 @@
from dataclasses import dataclass
from enum import Enum
import threading
from typing import Any, Callable, List, Optional, TYPE_CHECKING, Union
from pathlib import Path
import numpy as np
import torch
import logging
import re
@@ -83,13 +86,46 @@ def __call__(self, parameters):
)


class ModuleInitMode(Enum):
"""Initialization mode"""

#: Default initialization (i.e. can load default parameters or initialize randomly)
DEFAULT = 0

#: No parameter initialization (just initialize the structure of the model)
NONE = 1

#: Random initialization (initialize the structure, then use the random
#: number generator to initialize the values)
RANDOM = 2

def to_options(self, random: Optional[np.random.RandomState] = None):
return ModuleInitOptions(self, random)


@dataclass
class ModuleInitOptions:
#: Initialization mode
mode: ModuleInitMode

#: Random generator (only defined when mode is RANDOM)
random: Optional[np.random.RandomState] = None


class Module(Config, Initializable, torch.nn.Module):
"""A module contains parameters"""

def __init__(self):
Initializable.__init__(self)
torch.nn.Module.__init__(self)

def __initialize__(self, options: ModuleInitOptions):
"""Initialize a module
:param options: The initialization options
"""
pass

def __call__(self, *args, **kwargs):
return torch.nn.Module.__call__(self, *args, **kwargs)

@@ -107,9 +143,9 @@ def __post_init__(self):
for ix, sub_module in enumerate(self.sub_modules):
self.add_module(str(ix), sub_module)

def __initialize__(self, *args, **kwargs):
def __initialize__(self, options: ModuleInitOptions):
for module in self.sub_modules:
module.initialize(*args, **kwargs)
module.initialize(options)

def __call__(self, *args, **kwargs):
raise AssertionError("This module cannot be used as such")
@@ -122,7 +158,7 @@ class ModuleLoader(PathSerializationLWTask):
def execute(self):
"""Loads the model from disk using the given serialization path"""
logging.info("Loading model from disk: %s", self.path)
self.value.initialize(None)
self.value.initialize(ModuleInitMode.NONE.to_options())
data = torch.load(self.path)
self.value.load_state_dict(data)

@@ -235,7 +271,7 @@ def initialize(
try:
next(module.parameters())
except StopIteration:
raise RuntimeError("No parameters to optimize in the module")
raise RuntimeError(f"No parameters to optimize in the module {module}")

filter = DuplicateParameterFilter()
for param_optimizer in param_optimizers:
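The hunks above define the whole initialization protocol added by this merge: a `Module` is now initialized with a `ModuleInitOptions` built from one of the three `ModuleInitMode` values. A minimal sketch of how a downstream module might consume it (assuming xpmir is installed; `MyEncoder` and its body are illustrative, not part of this commit):

```python
import numpy as np
import torch

# Names re-exported by this commit in src/xpmir/learning/__init__.py
from xpmir.learning import Module, ModuleInitMode, ModuleInitOptions


class MyEncoder(Module):
    """Hypothetical module, used only to illustrate the new API"""

    def __initialize__(self, options: ModuleInitOptions):
        # Always build the structure of the model
        self.linear = torch.nn.Linear(768, 128)

        if options.mode == ModuleInitMode.NONE:
            # Structure only: parameters are loaded from disk afterwards
            # (this is what ModuleLoader does before load_state_dict)
            return

        if options.mode == ModuleInitMode.RANDOM:
            # options.random is only defined in RANDOM mode
            torch.manual_seed(options.random.randint(0, 2**31))
            torch.nn.init.xavier_uniform_(self.linear.weight)


encoder = MyEncoder()

# DEFAULT mode: load default parameters or initialize randomly,
# seeded from a shared random state as the Learner now does
encoder.initialize(ModuleInitMode.DEFAULT.to_options(np.random.RandomState(0)))
```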
4 changes: 2 additions & 2 deletions src/xpmir/learning/parameters.py
@@ -5,7 +5,7 @@
from typing import Iterator, NamedTuple
from experimaestro import Param, Config, PathSerializationLWTask
import torch
from xpmir.learning.optim import Module, ModuleLoader
from xpmir.learning.optim import Module, ModuleLoader, ModuleInitMode
import logging

logger = logging.getLogger("xpmir.learning")
@@ -152,7 +152,7 @@ class PartialModuleLoader(PathSerializationLWTask):

def execute(self):
"""Combine the model in the selectors"""
self.value.initialize(None)
self.value.initialize(ModuleInitMode.NONE.to_options())
data = torch.load(self.path)
logger.info(
"(partial module loader) Loading parameters from %s into %s",
4 changes: 2 additions & 2 deletions src/xpmir/neural/colbert.py
@@ -57,8 +57,8 @@ def __validate__(self):
assert not self.querytoken, "Not implemented"
assert not self.doctoken, "Not implemented"

def _initialize(self, random):
super()._initialize(random)
def __initialize__(self, options):
super().__initialize__(options)

self.linear = nn.Linear(self.vocab.dim(), self.linear_dim, bias=False)

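The colbert.py change above is the migration pattern repeated in the files that follow: the old `_initialize(self, random)` hook becomes `__initialize__(self, options)`, and the options object is propagated to sub-encoders. Schematically (a sketch based on the cross.py hunks below; names as in that file):

```python
# Before: only a random state was passed around
def _initialize(self, random):
    self.encoder.initialize()

# After: the ModuleInitOptions carry the mode and, optionally, the RNG
def __initialize__(self, options):
    super().__initialize__(options)
    self.encoder.initialize(options)  # sub-modules receive the same options
```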
10 changes: 6 additions & 4 deletions src/xpmir/neural/cross.py
@@ -35,8 +35,9 @@ def __validate__(self):
super().__validate__()
assert not self.encoder.static(), "The vocabulary should be learnable"

def _initialize(self, random):
self.encoder.initialize()
def __initialize__(self, options):
super().__initialize__(options)
self.encoder.initialize(options)
self.classifier = torch.nn.Linear(self.encoder.dimension, 1)

def forward(self, inputs: BaseRecords, info: TrainerContext = None):
@@ -65,8 +66,9 @@ class DuoCrossScorer(DuoLearnableScorer, DistributableModel):
def __validate__(self):
assert not self.encoder.static(), "The vocabulary should be learnable"

def _initialize(self, random):
self.encoder.initialize()
def __initialize__(self, options):
super().__initialize__(options)
self.encoder.initialize(options)
self.classifier = torch.nn.Linear(self.encoder.dimension, 1)

def forward(self, inputs: PairwiseRecords, info: TrainerContext = None):
6 changes: 3 additions & 3 deletions src/xpmir/neural/dual.py
@@ -59,10 +59,10 @@ def __validate__(self):
super().__validate__()
assert not self.encoder.static(), "The vocabulary should be learnable"

def _initialize(self, random):
self.encoder.initialize()
def __initialize__(self, options):
self.encoder.initialize(options)
if self.query_encoder:
self.query_encoder.initialize()
self.query_encoder.initialize(options)

def score_product(self, queries, documents, info: Optional[TrainerContext]):
return queries @ documents.T
3 changes: 0 additions & 3 deletions src/xpmir/neural/generative/__init__.py
@@ -27,9 +27,6 @@ def step(self, token_ids: torch.LongTensor) -> torch.Tensor:
class IdentifierGenerator(Module):
"""Models that generate an identifier given a document or a query"""

def __initialize__(self):
pass

@abstractmethod
def stepwise_iterator(self) -> StepwiseGenerator:
pass
11 changes: 7 additions & 4 deletions src/xpmir/neural/generative/hf.py
@@ -6,7 +6,7 @@

import torch
from torch import nn
import numpy as np
from xpmir.learning import ModuleInitOptions, ModuleInitMode
from xpmir.letor.records import TokenizedTexts
from xpmir.distributed import DistributableModel
from . import IdentifierGenerator, StepwiseGenerator
@@ -99,13 +99,16 @@ class T5IdentifierGenerator(IdentifierGenerator, DistributableModel):
def stepwise_iterator(self) -> StepwiseGenerator:
return T5StepwiseGenerator(self)

def __initialize__(self, random: Optional[np.random.RandomState] = None):
super().__initialize__()
def __initialize__(self, options: ModuleInitOptions):
assert options.mode != ModuleInitMode.RANDOM, "Random mode not handled (yet)"

super().__initialize__(options)

# Easy and hacky way to get the device
self._dummy_params = nn.Parameter(torch.Tensor())
self.config = AutoConfig.from_pretrained(self.hf_id)

self.tokenizer = AutoTokenizer.from_pretrained(self.hf_id, use_fast=True)
self.config = AutoConfig.from_pretrained(self.hf_id)

self.model = CustomOutputT5(self.config, self.decoder_outdim)
self.pad_token_id = self.model.config.pad_token_id
3 changes: 0 additions & 3 deletions src/xpmir/neural/huggingface.py
@@ -31,9 +31,6 @@ def __post_init__(self):
)
self.tokenizer = AutoTokenizer.from_pretrained(self.hf_id)

def _initialize(self, random):
pass

def batch_tokenize(
self,
input_records: BaseRecords,
5 changes: 2 additions & 3 deletions src/xpmir/neural/interaction/__init__.py
@@ -23,9 +23,8 @@ class InteractionScorer(LearnableScorer):
qlen: Param[int] = 20
dlen: Param[int] = 2000

def _initialize(self, random):
self.random = random
self.vocab.initialize()
def __initialize__(self, options):
self.vocab.initialize(options)

def __validate__(self):
assert (
4 changes: 2 additions & 2 deletions src/xpmir/neural/interaction/drmm.py
@@ -111,14 +111,14 @@ def __validate__(self):
self.index is not None
), "index must be provided if using IDF"

def _initialize(self, random):
def __initialize__(self, options):
if not self.vocab.static():
self.logger.warning(
"In most cases, using vocab.train=True will not have an effect on DRMM "
"because the histogram is not differentiable. An exception might be if "
"the gradient is proped back by another means, e.g. BERT [CLS] token."
)
super()._initialize(random)
super().__initialize__(options)
self.simmat = modules.InteractionMatrix(self.vocab.pad_tokenid)
channels = self.vocab.emb_views()
self.hidden_1 = nn.Linear(self.hist.nbins * channels, self.hidden)
5 changes: 3 additions & 2 deletions src/xpmir/neural/splade.py
@@ -3,6 +3,7 @@
import torch.nn.functional as F
import torch.nn as nn
import torch
from xpmir.learning import ModuleInitOptions
from xpmir.distributed import DistributableModel
from xpmir.text.huggingface import (
OneHotHuggingFaceEncoder,
@@ -89,8 +90,8 @@ class SpladeTextEncoder(TextEncoder, DistributableModel):
maxlen: Param[Optional[int]] = None
"""Max length for texts"""

def __initialize__(self, random=None):
self.encoder.initialize()
def __initialize__(self, options: ModuleInitOptions):
self.encoder.initialize(options)
self.model = SpladeTextEncoderModel(self.encoder, self.aggregation)

def forward(self, texts: List[str]) -> torch.Tensor: