improvements for the alignment-based learning
bpiwowar committed May 15, 2024
1 parent f0c2a7f commit ea3ce9d
Showing 11 changed files with 297 additions and 56 deletions.
56 changes: 30 additions & 26 deletions src/xpmir/conversation/learning/__init__.py
@@ -1,48 +1,52 @@
 from functools import cached_property
-from datamaestro.record import Record
+from typing import Iterator, List
 
 import numpy as np
+from datamaestro.record import Record
 from datamaestro_text.data.conversation import (
     ConversationDataset,
     ConversationHistoryItem,
     EntryType,
 )
-from experimaestro import Param
+from experimaestro import Config, Param
 
-from xpmir.learning.base import BaseSampler
+from xpmir.learning.base import BaseSampler, SampleIterator
 from xpmir.utils.iter import RandomSerializableIterator
 
 
-class DatasetConversationEntrySampler(BaseSampler):
-    """Uses a conversation dataset and topic records entries"""
-
-    dataset: Param[ConversationDataset]
-    """The conversation dataset"""
+class DatasetConversationBase(Config):
+    datasets: Param[List[ConversationDataset]]
+    """The conversation datasets"""
 
     @cached_property
-    def conversations(self):
-        return list(self.dataset.__iter__())
-
-    def __post_init__(self):
-        super().__post_init__()
-
-    def __iter__(self) -> RandomSerializableIterator[Record]:
-        def generator(random: np.random.RandomState):
-            while True:
-                # Pick a random conversation
-                conversation_ix = random.randint(0, len(self.conversations))
-                conversation = self.conversations[conversation_ix]
-
-                # Pick a random topic record entry
+    def records(self):
+        records = []
+        for dataset in self.datasets:
+            for conversation in dataset.__iter__():
                 nodes = [
                     node
                     for node in conversation
                     if node.entry()[EntryType] == EntryType.USER_QUERY
                 ]
-                node_ix = random.randint(len(nodes))
-                node = nodes[node_ix]
-
-                node = node.entry().update(ConversationHistoryItem(node.history()))
-
-                yield node
+                for node in nodes:
+                    records.append(
+                        node.entry().update(ConversationHistoryItem(node.history()))
+                    )
+
+        return records
+
+
+class DatasetConversationIterator(SampleIterator, DatasetConversationBase):
+    def __iter__(self) -> Iterator[Record]:
+        yield from self.records
+
+
+class DatasetConversationEntrySampler(BaseSampler, DatasetConversationBase):
+    """Uses a conversation dataset and topic records entries"""
+
+    def __iter__(self) -> RandomSerializableIterator[Record]:
+        def generator(random: np.random.RandomState):
+            while True:
+                yield self.records[random.randint(0, len(self.records))]
 
         return RandomSerializableIterator(self.random, generator)
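
The refactoring above replaces per-conversation sampling (pick a conversation, then a user turn) with a single flattened list: every user query, together with its conversation history, becomes one record that both the sequential iterator and the random sampler draw from. A rough standalone sketch of that pattern (plain Python with made-up toy data, not the xpmir classes):

import numpy as np

# Toy conversations: lists of (speaker, text) turns -- illustrative data only
conversations = [
    [("user", "what is BM25?"), ("system", "..."), ("user", "how is k1 chosen?")],
    [("user", "define nDCG")],
]

# Flatten once, mirroring DatasetConversationBase.records:
# every user turn becomes a record carrying the history that precedes it
records = []
for conversation in conversations:
    history = []
    for speaker, text in conversation:
        if speaker == "user":
            records.append({"query": text, "history": list(history)})
        history.append((speaker, text))

# DatasetConversationIterator-style sequential pass
for record in records:
    print(record["query"], "| history turns:", len(record["history"]))

# DatasetConversationEntrySampler-style uniform sampling
random = np.random.RandomState(0)
print("sampled:", records[random.randint(0, len(records))]["query"])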
6 changes: 5 additions & 1 deletion src/xpmir/evaluation.py
@@ -188,7 +188,7 @@ def evaluate_retriever(
         self.add(evaluation)
 
         # Use retriever tags
-        retriever_tags = tags(retriever)
+        retriever_tags = tags(evaluation)
         if retriever_tags:
            self.per_tags[retriever_tags] = evaluation
 
@@ -215,6 +215,10 @@ def to_dataframe(self) -> pd.DataFrame:
             metrics.update(results.keys())
             to_process.append((tags_dict, results))
 
+        # Sort metrics
+        metrics = list(metrics)
+        metrics.sort()
+
         # Table header
         columns = []
         for tag in tags:
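
The sort added to to_dataframe makes the metric columns deterministic: metrics is accumulated as a set (judging by the update call above), so its iteration order would otherwise vary between runs. A tiny illustration (pandas; the metric names and values are made up):

import pandas as pd

metrics = {"nDCG@10", "AP", "RR@10"}  # set iteration order is arbitrary

columns = sorted(metrics)  # stable column order, as in the diff above
df = pd.DataFrame([{"AP": 0.31, "RR@10": 0.52, "nDCG@10": 0.44}], columns=columns)
print(df)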
16 changes: 11 additions & 5 deletions src/xpmir/index/sparse.py
@@ -271,11 +271,11 @@ def execute(self):
         if mp.get_start_method(allow_none=True) is None:
             mp.set_start_method("spawn")
 
-        max_docs = (
-            self.documents.documentcount
-            if self.max_docs == 0
-            else min(self.max_docs, self.documents.documentcount)
-        )
+        max_docs = 0
+        if self.max_docs:
+            max_docs = min(self.max_docs, self.documents.documentcount or sys.maxsize)
+            logger.warning("Limited indexing to %d documents", max_docs)
+
         iter_batches = MultiprocessIterator(
             DocumentIterator(self.documents, max_docs, self.batch_size)
         ).detach()
@@ -321,6 +321,8 @@ def execute(self):
         finally:
             logger.info("Waiting for the index process to stop")
             index_thread.join()
+            if not self.index_done:
+                raise RuntimeError("Indexing thread did not complete")
 
     def index(
         self,
@@ -331,6 +333,7 @@ def index(
         :param queues: Queues are used to send tensors
         """
+        self.index_done = False
         with tqdm(
             total=max_docs,
             unit="documents",
@@ -377,6 +380,9 @@ def index(
 
             logger.info("Building the index")
             indexer.build(self.in_memory)
+
+            logger.info("Index built")
+            self.index_done = True
         except Empty:
             logger.warning("One encoder got a problem... stopping")
             raise
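
In the new execute code, a max_docs of 0 now means "no limit" and a missing document count no longer breaks the min computation; together with the index_done flag, the join in the finally block can detect an indexing thread that died before the index was built. A quick standalone check of the clamping cases (mirrors the expression above; limited_docs is a hypothetical helper name):

import sys

def limited_docs(max_docs: int, documentcount) -> int:
    # 0 means: index the whole collection
    if max_docs:
        return min(max_docs, documentcount or sys.maxsize)
    return 0

print(limited_docs(0, 1_000_000))    # 0 -> no explicit limit
print(limited_docs(500, 1_000_000))  # 500 -> capped by the request
print(limited_docs(500, None))       # 500 -> unknown collection size, keep the request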
2 changes: 1 addition & 1 deletion src/xpmir/learning/__init__.py
@@ -1,3 +1,3 @@
 # flake8: noqa: F401
-from .base import Random, Sampler
+from .base import Random, Sampler, SampleIterator
 from .optim import Module, ModuleInitMode, ModuleInitOptions
29 changes: 27 additions & 2 deletions src/xpmir/learning/base.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Optional, Generic, TypeVar
+from typing import Optional, TypeVar, Sequence, Iterator, Iterable
 import numpy as np
 from functools import cached_property
 from experimaestro import Config, Param
@@ -31,7 +31,32 @@ def initialize(self, random: Optional[np.random.RandomState]):
 T = TypeVar("T")
 
 
-class BaseSampler(Sampler, Generic[T], ABC):
+class SampleIterator(Config, Iterable[T], ABC):
+    """Generic class to iterate over items or batch of items"""
+
+    @abstractmethod
+    def __iter__() -> Iterator[T]:
+        pass
+
+    def __batch_iter__(self, batch_size: int) -> Iterator[Sequence[T]]:
+        """Batch iterations"""
+        iterator = self.__iter__()
+        data = []
+        try:
+            while True:
+                data.append(next(iterator))
+                if len(data) == batch_size:
+                    yield data
+        except StopIteration:
+            pass
+
+        if data:
+            yield data
+
+
+class BaseSampler(Sampler, SampleIterator[T], ABC):
     """A serializable sampler iterator"""
 
     @abstractmethod
     def __iter__() -> SerializableIterator[T]:
         pass
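
The new __batch_iter__ groups any iterator into fixed-size lists plus a final partial batch. A standalone sketch of the same batching pattern, with an explicit reset after each yielded batch (the reset is not visible in the excerpt above, so treat it as an assumption):

from typing import Iterable, Iterator, Sequence, TypeVar

T = TypeVar("T")

def batch_iter(items: Iterable[T], batch_size: int) -> Iterator[Sequence[T]]:
    # Same shape as SampleIterator.__batch_iter__
    iterator = iter(items)
    data = []
    try:
        while True:
            data.append(next(iterator))
            if len(data) == batch_size:
                yield data
                data = []  # start a fresh batch (assumed, see note above)
    except StopIteration:
        pass

    if data:
        yield data  # trailing partial batch

print(list(batch_iter(range(7), 3)))  # [[0, 1, 2], [3, 4, 5], [6]]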
2 changes: 1 addition & 1 deletion src/xpmir/learning/context.py
@@ -316,6 +316,6 @@ def add_metric(self, metric: Metric):
     def scope(self, name: str):
         try:
             self._scope.append(name)
-            yield
+            yield self
         finally:
             self._scope.pop()
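
Yielding self (instead of a bare yield) lets callers bind the context object in a with statement. A minimal illustration of the pattern with a standalone class (not the xpmir one):

from contextlib import contextmanager

class Scoped:
    def __init__(self):
        self._scope = []

    @contextmanager
    def scope(self, name: str):
        try:
            self._scope.append(name)
            yield self  # a bare "yield" would bind None in "with ... as"
        finally:
            self._scope.pop()

ctx = Scoped()
with ctx.scope("train") as bound:
    assert bound is ctx
    print(bound._scope)  # ['train']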
3 changes: 3 additions & 0 deletions src/xpmir/learning/metrics.py
@@ -40,6 +40,9 @@ def _merge(self, other: "ScalarMetric"):
         self.sum += other.sum
         self.count += other.count
 
+    def compute(self):
+        return self.sum / self.count
+
     def report(self, step: int, writer: SummaryWriter, prefix: str):
         if self.count == 0:
             logging.warning("Count is 0 when reporting metrics")
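
The new compute method returns the mean of the accumulated values. A toy version of the accumulate / merge / compute cycle (standalone class; only the sum, count, _merge and compute names follow the diff):

class ScalarMean:
    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def add(self, value: float):
        self.sum += value
        self.count += 1

    def _merge(self, other: "ScalarMean"):
        self.sum += other.sum
        self.count += other.count

    def compute(self) -> float:
        return self.sum / self.count  # mean of everything seen so far

a, b = ScalarMean(), ScalarMean()
a.add(0.5)
a.add(0.7)
b.add(0.9)
a._merge(b)
print(a.compute())  # (0.5 + 0.7 + 0.9) / 3 = 0.7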
51 changes: 31 additions & 20 deletions src/xpmir/learning/optim.py
@@ -14,6 +14,8 @@
     tagspath,
     Task,
     PathSerializationLWTask,
+    experiment,
+    RunMode,
 )
 from experimaestro.scheduler import Job, Listener
 from experimaestro.utils import cleanupdir
@@ -488,43 +490,52 @@ def job_state(self, job: Job):
 class TensorboardService(WebService):
     id = "tensorboard"
 
-    def __init__(self, path: Path):
+    def __init__(self, xp: experiment, path: Path):
         super().__init__()
 
         self.path = path
-        cleanupdir(self.path)
-        self.path.mkdir(exist_ok=True, parents=True)
-        logger.info("You can monitor learning with:")
-        logger.info("tensorboard --logdir=%s", self.path)
         self.url = None
+        self.run_mode = xp.run_mode
+
+        if self.run_mode == RunMode.NORMAL:
+            cleanupdir(self.path)
+            self.path.mkdir(exist_ok=True, parents=True)
+            logger.info("You can monitor learning with:")
+            logger.info("tensorboard --logdir=%s", self.path)
 
     def add(self, task: Task, path: Path):
         # Wait until config has started
-        if job := task.__xpm__.job:
-            if job.scheduler is not None:
-                tag_path = tagspath(task)
-                if tag_path:
-                    job.scheduler.addlistener(
-                        TensorboardServiceListener(self.path / tag_path, path)
-                    )
+        if self.run_mode == RunMode.NORMAL:
+            if job := task.__xpm__.job:
+                if job.scheduler is not None:
+                    tag_path = tagspath(task)
+                    if tag_path:
+                        job.scheduler.addlistener(
+                            TensorboardServiceListener(self.path / tag_path, path)
+                        )
+                    else:
+                        logger.error(
+                            "The task is not associated with tags: "
+                            "cannot link to tensorboard data"
+                        )
                 else:
-                    logger.error(
-                        "The task is not associated with tags: "
-                        "cannot link to tensorboard data"
-                    )
+                    logger.debug("No scheduler: not adding the tensorboard data")
             else:
-                logger.debug("No scheduler: not adding the tensorboard data")
-        else:
-            logger.error("Task was not started: cannot link to tensorboard job path")
+                logger.error(
+                    "Task was not started: cannot link to tensorboard job path"
+                )
 
     def description(self):
         return "Tensorboard service"
 
     def close(self):
-        if self.server:
+        if self.server and self.run_mode == RunMode.NORMAL:
             self.server.shutdown()
 
     def _serve(self, running: threading.Event):
+        if self.run_mode != RunMode.NORMAL:
+            return
+
         import tensorboard as tb
 
         try:
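
With these changes, TensorboardService only touches the log directory and registers scheduler listeners when the experiment runs in RunMode.NORMAL; other run modes get a silent no-op service. A standalone sketch of that gating pattern (stand-in classes, not the experimaestro API):

from enum import Enum
from pathlib import Path

class RunMode(Enum):  # stand-in for experimaestro's RunMode
    NORMAL = "normal"
    DRY_RUN = "dry-run"

class FakeExperiment:  # stand-in exposing only run_mode, as used in __init__ above
    def __init__(self, run_mode: RunMode):
        self.run_mode = run_mode

class GatedService:
    def __init__(self, xp: FakeExperiment, path: Path):
        self.path = path
        self.run_mode = xp.run_mode
        if self.run_mode == RunMode.NORMAL:
            self.path.mkdir(exist_ok=True, parents=True)

    def add(self, name: str):
        if self.run_mode != RunMode.NORMAL:
            return  # dry runs register nothing
        print(f"would watch {name} under {self.path}")

GatedService(FakeExperiment(RunMode.DRY_RUN), Path("/tmp/tb-logs")).add("learner")  # no side effects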
