Features/module init #28

Merged (7 commits, Dec 6, 2023)
2 changes: 2 additions & 0 deletions docs/source/learning/index.rst
@@ -46,3 +46,5 @@ Listeners can be used to monitor the learning process

.. autoxpmconfig:: xpmir.learning.learner.LearnerListener
:members: __call__

.. autoxpmconfig:: xpmir.learning.context.ValidationHook
3 changes: 2 additions & 1 deletion docs/source/learning/optimization.rst
@@ -23,6 +23,7 @@ Optimizers


.. autoxpmconfig:: xpmir.learning.optim.Optimizer
.. autoxpmconfig:: xpmir.learning.optim.SGD
.. autoxpmconfig:: xpmir.learning.optim.Adam
.. autoxpmconfig:: xpmir.learning.optim.AdamW

@@ -45,6 +46,7 @@ The classes below allow to select a subset of parameters.
.. autoxpmconfig:: xpmir.learning.parameters.InverseParametersIterator
.. autoxpmconfig:: xpmir.learning.parameters.ParametersIterator
.. autoxpmconfig:: xpmir.learning.parameters.SubParametersIterator
.. autoxpmconfig:: xpmir.learning.parameters.RegexParametersIterator

Freezing
********
@@ -59,7 +61,6 @@ Loading
.. autoxpmconfig:: xpmir.learning.parameters.PartialModuleLoader
.. autoxpmconfig:: xpmir.learning.parameters.SubModuleLoader


Batching
--------

3 changes: 3 additions & 0 deletions docs/source/letor/pairwise.rst
@@ -17,6 +17,8 @@ Trainer

.. autoxpmconfig:: xpmir.letor.trainers.pairwise.PairwiseTrainer
.. autoxpmconfig:: xpmir.letor.trainers.pairwise.DuoPairwiseTrainer
.. autoxpmconfig:: xpmir.letor.trainers.generative.GenerativeTrainer


Losses
------
@@ -43,5 +45,6 @@ Pairwise
.. autoxpmconfig:: PairwiseSampleDatasetFromTSV
.. autoxpmconfig:: PairwiseSamplerFromTSV
.. autoxpmconfig:: ModelBasedHardNegativeSampler
.. autoxpmconfig:: TripletBasedInBatchNegativeSampler

.. autoxpmconfig:: xpmir.letor.samplers.hydrators.PairwiseTransformAdapter
10 changes: 10 additions & 0 deletions docs/source/neural.rst
@@ -89,6 +89,16 @@ Sparse Models
.. autoxpmconfig:: xpmir.neural.splade.MaxAggregation
.. autoxpmconfig:: xpmir.neural.splade.SumAggregation

Generative Models
=================

.. autoxpmconfig:: xpmir.neural.generative.IdentifierGenerator

.. autoxpmconfig:: xpmir.neural.generative.hf.LoadFromT5

.. autoxpmconfig:: xpmir.neural.generative.hf.T5IdentifierGenerator


From Huggingface
================

3 changes: 2 additions & 1 deletion src/xpmir/index/faiss.py
@@ -13,6 +13,7 @@
from datamaestro_text.data.ir import DocumentStore
from xpmir.rankers import Retriever, ScoredDocument
from xpmir.learning.batchers import Batcher
from xpmir.learning import ModuleInitMode
from xpmir.text.encoders import TextEncoder
from xpmir.letor import (
Device,
@@ -130,7 +131,7 @@ def _execute(self, device_information: DeviceInformation):
step_iter = tqdm(total=2, desc="Building the FAISS index")

# Initializations
self.encoder.initialize()
self.encoder.initialize(ModuleInitMode.DEFAULT.to_options())
index = faiss.index_factory(
self.encoder.dimension, self.indexspec, faiss.METRIC_INNER_PRODUCT
)
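A note for reading this and the later call sites: ModuleInitMode and its to_options() helper are added further down in this PR (src/xpmir/learning/optim.py). As a quick orientation, a minimal sketch of what such a call builds; the assert line is illustrative only:

from xpmir.learning import ModuleInitMode

# to_options() wraps the chosen mode (plus an optional numpy RandomState)
# into a ModuleInitOptions dataclass that __initialize__ implementations receive
options = ModuleInitMode.DEFAULT.to_options()
assert options.mode is ModuleInitMode.DEFAULT and options.random is None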
5 changes: 3 additions & 2 deletions src/xpmir/index/sparse.py
@@ -16,6 +16,7 @@
Constant,
)
from datamaestro_text.data.ir import Document, DocumentStore
from xpmir.learning import ModuleInitMode
from xpmir.learning.batchers import Batcher
from xpmir.utils.utils import batchiter, easylog
from xpmir.letor import Device, DEFAULT_DEVICE
@@ -71,7 +72,7 @@ class SparseRetriever(Retriever):

def initialize(self):
super().initialize()
self.encoder.initialize()
self.encoder.initialize(ModuleInitMode.RANDOM.to_options(None))
self.index.initialize(self.in_memory)

def retrieve_all(self, queries: Dict[str, str]) -> Dict[str, List[ScoredDocument]]:
@@ -173,7 +174,7 @@ def execute(self):
f"Load the encoder and transfer to the target device {self.device.value}"
)

self.encoder.initialize()
self.encoder.initialize(ModuleInitMode.RANDOM.to_options(None))
self.encoder.to(self.device.value).eval()

batcher = self.batcher.initialize(self.batch_size)
2 changes: 1 addition & 1 deletion src/xpmir/learning/__init__.py
@@ -1,3 +1,3 @@
# flake8: noqa: F401
from .base import Random, Sampler
from .optim import Module
from .optim import Module, ModuleInitMode, ModuleInitOptions
8 changes: 4 additions & 4 deletions src/xpmir/learning/learner.py
@@ -16,7 +16,7 @@
from xpmir.context import Hook, InitializationHook
from xpmir.utils.utils import EasyLogger, easylog, foreach
from xpmir.learning.devices import DEFAULT_DEVICE, Device, DeviceInformation
from xpmir.learning import Random
from xpmir.learning import Random, ModuleInitMode
from xpmir.learning.trainers import Trainer
from xpmir.learning.context import (
StepTrainingHook,
@@ -223,15 +223,15 @@ def device_execute(self, device_information: DeviceInformation):
torch.cuda.manual_seed_all(seed)

# Initialize the scorer and trainer
self.logger.info("Scorer initialization")
self.model.initialize()
self.logger.info("model initialization")
self.model.initialize(ModuleInitMode.DEFAULT.to_options(self.random.state))

# Initialize the context and the listeners
self.trainer.initialize(self.random.state, self.context)
for listener in self.listeners:
listener.initialize(self, self.context)

self.logger.info("Moving to device %s", device_information.device)
self.logger.info("Moving model to device %s", device_information.device)
self.model.to(device_information.device)
self.trainer.to(device_information.device)
num_training_steps = self.max_epochs * self.steps_per_epoch
44 changes: 40 additions & 4 deletions src/xpmir/learning/optim.py
@@ -1,6 +1,9 @@
from dataclasses import dataclass
from enum import Enum
import threading
from typing import Any, Callable, List, Optional, TYPE_CHECKING, Union
from pathlib import Path
import numpy as np
import torch
import logging
import re
@@ -83,13 +86,46 @@ def __call__(self, parameters):
)


class ModuleInitMode(Enum):
"""Initialization mode"""

#: Default initialization (i.e. can load default parameters or initialize randomly)
DEFAULT = 0

#: No parameter initialization (just initialize the structure of the model)
NONE = 1

#: Random initialization (initialize the structure, then use the random
#: number generator to initialize the values)
RANDOM = 2

def to_options(self, random: Optional[np.random.RandomState] = None):
return ModuleInitOptions(self, random)


@dataclass
class ModuleInitOptions:
#: Initialization mode
mode: ModuleInitMode

#: Random generator (only defined when mode is RANDOM)
random: Optional[np.random.RandomState] = None


class Module(Config, Initializable, torch.nn.Module):
"""A module contains parameters"""

def __init__(self):
Initializable.__init__(self)
torch.nn.Module.__init__(self)

def __initialize__(self, options: ModuleInitOptions):
"""Initialize a module

:param options: The initialization options
"""
pass

def __call__(self, *args, **kwargs):
return torch.nn.Module.__call__(self, *args, **kwargs)

@@ -107,9 +143,9 @@ def __post_init__(self):
for ix, sub_module in enumerate(self.sub_modules):
self.add_module(str(ix), sub_module)

def __initialize__(self, *args, **kwargs):
def __initialize__(self, options: ModuleInitOptions):
for module in self.sub_modules:
module.initialize(*args, **kwargs)
module.initialize(options)

def __call__(self, *args, **kwargs):
raise AssertionError("This module cannot be used as such")
@@ -122,7 +158,7 @@ class ModuleLoader(PathSerializationLWTask):
def execute(self):
"""Loads the model from disk using the given serialization path"""
logging.info("Loading model from disk: %s", self.path)
self.value.initialize(None)
self.value.initialize(ModuleInitMode.NONE.to_options())
data = torch.load(self.path)
self.value.load_state_dict(data)

@@ -235,7 +271,7 @@ def initialize(
try:
next(module.parameters())
except StopIteration:
raise RuntimeError("No parameters to optimize in the module")
raise RuntimeError(f"No parameters to optimize in the module {module}")

filter = DuplicateParameterFilter()
for param_optimizer in param_optimizers:
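To make the new contract concrete, here is a minimal, hedged sketch of how a Module subclass is expected to implement __initialize__ and how callers build the options, mirroring the call sites in this PR (learner.py uses DEFAULT with the learner's random state, ModuleLoader uses NONE before load_state_dict). ToyEncoder, its linear layer, and the direct instantiation outside an experimaestro experiment are hypothetical, for illustration only:

import numpy as np
import torch

from xpmir.learning import Module, ModuleInitMode, ModuleInitOptions


class ToyEncoder(Module):
    """Hypothetical module: builds its parameters in __initialize__"""

    def __initialize__(self, options: ModuleInitOptions):
        super().__initialize__(options)
        # Build the structure whatever the mode
        self.projection = torch.nn.Linear(16, 8)
        # In RANDOM mode, use the provided generator to seed the weight initialization
        if options.mode == ModuleInitMode.RANDOM:
            torch.manual_seed(int(options.random.randint(2**31)))
            torch.nn.init.xavier_uniform_(self.projection.weight)


encoder = ToyEncoder()

# Default initialization, seeded by a numpy RandomState (as in Learner.device_execute)
encoder.initialize(ModuleInitMode.DEFAULT.to_options(np.random.RandomState(0)))

# Structure-only initialization before loading weights from disk (as in ModuleLoader.execute)
# encoder.initialize(ModuleInitMode.NONE.to_options())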
4 changes: 2 additions & 2 deletions src/xpmir/learning/parameters.py
@@ -5,7 +5,7 @@
from typing import Iterator, NamedTuple
from experimaestro import Param, Config, PathSerializationLWTask
import torch
from xpmir.learning.optim import Module, ModuleLoader
from xpmir.learning.optim import Module, ModuleLoader, ModuleInitMode
import logging

logger = logging.getLogger("xpmir.learning")
@@ -152,7 +152,7 @@ class PartialModuleLoader(PathSerializationLWTask):

def execute(self):
"""Combine the model in the selectors"""
self.value.initialize(None)
self.value.initialize(ModuleInitMode.NONE.to_options())
data = torch.load(self.path)
logger.info(
"(partial module loader) Loading parameters from %s into %s",
4 changes: 2 additions & 2 deletions src/xpmir/neural/colbert.py
@@ -57,8 +57,8 @@ def __validate__(self):
assert not self.querytoken, "Not implemented"
assert not self.doctoken, "Not implemented"

def _initialize(self, random):
super()._initialize(random)
def __initialize__(self, options):
super().__initialize__(options)

self.linear = nn.Linear(self.vocab.dim(), self.linear_dim, bias=False)

10 changes: 6 additions & 4 deletions src/xpmir/neural/cross.py
@@ -35,8 +35,9 @@ def __validate__(self):
super().__validate__()
assert not self.encoder.static(), "The vocabulary should be learnable"

def _initialize(self, random):
self.encoder.initialize()
def __initialize__(self, options):
super().__initialize__(options)
self.encoder.initialize(options)
self.classifier = torch.nn.Linear(self.encoder.dimension, 1)

def forward(self, inputs: BaseRecords, info: TrainerContext = None):
@@ -65,8 +66,9 @@ class DuoCrossScorer(DuoLearnableScorer, DistributableModel):
def __validate__(self):
assert not self.encoder.static(), "The vocabulary should be learnable"

def _initialize(self, random):
self.encoder.initialize()
def __initialize__(self, options):
super().__initialize__(options)
self.encoder.initialize(options)
self.classifier = torch.nn.Linear(self.encoder.dimension, 1)

def forward(self, inputs: PairwiseRecords, info: TrainerContext = None):
6 changes: 3 additions & 3 deletions src/xpmir/neural/dual.py
@@ -59,10 +59,10 @@ def __validate__(self):
super().__validate__()
assert not self.encoder.static(), "The vocabulary should be learnable"

def _initialize(self, random):
self.encoder.initialize()
def __initialize__(self, options):
self.encoder.initialize(options)
if self.query_encoder:
self.query_encoder.initialize()
self.query_encoder.initialize(options)

def score_product(self, queries, documents, info: Optional[TrainerContext]):
return queries @ documents.T
3 changes: 0 additions & 3 deletions src/xpmir/neural/generative/__init__.py
@@ -27,9 +27,6 @@ def step(self, token_ids: torch.LongTensor) -> torch.Tensor:
class IdentifierGenerator(Module):
"""Models that generate an identifier given a document or a query"""

def __initialize__(self):
pass

@abstractmethod
def stepwise_iterator(self) -> StepwiseGenerator:
pass
11 changes: 7 additions & 4 deletions src/xpmir/neural/generative/hf.py
@@ -6,7 +6,7 @@

import torch
from torch import nn
import numpy as np
from xpmir.learning import ModuleInitOptions, ModuleInitMode
from xpmir.letor.records import TokenizedTexts
from xpmir.distributed import DistributableModel
from . import IdentifierGenerator, StepwiseGenerator
@@ -99,13 +99,16 @@ class T5IdentifierGenerator(IdentifierGenerator, DistributableModel):
def stepwise_iterator(self) -> StepwiseGenerator:
return T5StepwiseGenerator(self)

def __initialize__(self, random: Optional[np.random.RandomState] = None):
super().__initialize__()
def __initialize__(self, options: ModuleInitOptions):
assert options.mode != ModuleInitMode.RANDOM, "Random mode not handled (yet)"

super().__initialize__(options)

# Easy and hacky way to get the device
self._dummy_params = nn.Parameter(torch.Tensor())
self.config = AutoConfig.from_pretrained(self.hf_id)

self.tokenizer = AutoTokenizer.from_pretrained(self.hf_id, use_fast=True)
self.config = AutoConfig.from_pretrained(self.hf_id)

self.model = CustomOutputT5(self.config, self.decoder_outdim)
self.pad_token_id = self.model.config.pad_token_id
3 changes: 0 additions & 3 deletions src/xpmir/neural/huggingface.py
@@ -31,9 +31,6 @@ def __post_init__(self):
)
self.tokenizer = AutoTokenizer.from_pretrained(self.hf_id)

def _initialize(self, random):
pass

def batch_tokenize(
self,
input_records: BaseRecords,
5 changes: 2 additions & 3 deletions src/xpmir/neural/interaction/__init__.py
@@ -23,9 +23,8 @@ class InteractionScorer(LearnableScorer):
qlen: Param[int] = 20
dlen: Param[int] = 2000

def _initialize(self, random):
self.random = random
self.vocab.initialize()
def __initialize__(self, options):
self.vocab.initialize(options)

def __validate__(self):
assert (
4 changes: 2 additions & 2 deletions src/xpmir/neural/interaction/drmm.py
@@ -111,14 +111,14 @@ def __validate__(self):
self.index is not None
), "index must be provided if using IDF"

def _initialize(self, random):
def __initialize__(self, options):
if not self.vocab.static():
self.logger.warning(
"In most cases, using vocab.train=True will not have an effect on DRMM "
"because the histogram is not differentiable. An exception might be if "
"the gradient is proped back by another means, e.g. BERT [CLS] token."
)
super()._initialize(random)
super().__initialize__(options)
self.simmat = modules.InteractionMatrix(self.vocab.pad_tokenid)
channels = self.vocab.emb_views()
self.hidden_1 = nn.Linear(self.hist.nbins * channels, self.hidden)
5 changes: 3 additions & 2 deletions src/xpmir/neural/splade.py
@@ -3,6 +3,7 @@
import torch.nn.functional as F
import torch.nn as nn
import torch
from xpmir.learning import ModuleInitOptions
from xpmir.distributed import DistributableModel
from xpmir.text.huggingface import (
OneHotHuggingFaceEncoder,
Expand Down Expand Up @@ -89,8 +90,8 @@ class SpladeTextEncoder(TextEncoder, DistributableModel):
maxlen: Param[Optional[int]] = None
"""Max length for texts"""

def __initialize__(self, random=None):
self.encoder.initialize()
def __initialize__(self, options: ModuleInitOptions):
self.encoder.initialize(options)
self.model = SpladeTextEncoderModel(self.encoder, self.aggregation)

def forward(self, texts: List[str]) -> torch.Tensor: