-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
improvements for the alignement-based learning
- Loading branch information
Showing
11 changed files
with
297 additions
and
56 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,48 +1,52 @@ | ||
from functools import cached_property | ||
from datamaestro.record import Record | ||
from typing import Iterator, List | ||
|
||
import numpy as np | ||
from datamaestro.record import Record | ||
from datamaestro_text.data.conversation import ( | ||
ConversationDataset, | ||
ConversationHistoryItem, | ||
EntryType, | ||
) | ||
from experimaestro import Param | ||
from experimaestro import Config, Param | ||
|
||
from xpmir.learning.base import BaseSampler | ||
from xpmir.learning.base import BaseSampler, SampleIterator | ||
from xpmir.utils.iter import RandomSerializableIterator | ||
|
||
|
||
class DatasetConversationEntrySampler(BaseSampler): | ||
"""Uses a conversation dataset and topic records entries""" | ||
|
||
dataset: Param[ConversationDataset] | ||
"""The conversation dataset""" | ||
class DatasetConversationBase(Config): | ||
datasets: Param[List[ConversationDataset]] | ||
"""The conversation datasets""" | ||
|
||
@cached_property | ||
def conversations(self): | ||
return list(self.dataset.__iter__()) | ||
|
||
def __post_init__(self): | ||
super().__post_init__() | ||
|
||
def __iter__(self) -> RandomSerializableIterator[Record]: | ||
def generator(random: np.random.RandomState): | ||
while True: | ||
# Pick a random conversation | ||
conversation_ix = random.randint(0, len(self.conversations)) | ||
conversation = self.conversations[conversation_ix] | ||
|
||
# Pick a random topic record entry | ||
def records(self): | ||
records = [] | ||
for dataset in self.datasets: | ||
for conversation in dataset.__iter__(): | ||
nodes = [ | ||
node | ||
for node in conversation | ||
if node.entry()[EntryType] == EntryType.USER_QUERY | ||
] | ||
node_ix = random.randint(len(nodes)) | ||
node = nodes[node_ix] | ||
for node in nodes: | ||
records.append( | ||
node.entry().update(ConversationHistoryItem(node.history())) | ||
) | ||
|
||
return records | ||
|
||
|
||
node = node.entry().update(ConversationHistoryItem(node.history())) | ||
class DatasetConversationIterator(SampleIterator, DatasetConversationBase): | ||
def __iter__(self) -> Iterator[Record]: | ||
yield from self.records | ||
|
||
yield node | ||
|
||
class DatasetConversationEntrySampler(BaseSampler, DatasetConversationBase): | ||
"""Uses a conversation dataset and topic records entries""" | ||
|
||
def __iter__(self) -> RandomSerializableIterator[Record]: | ||
def generator(random: np.random.RandomState): | ||
while True: | ||
yield self.records[random.randint(0, len(self.records))] | ||
|
||
return RandomSerializableIterator(self.random, generator) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
# flake8: noqa: F401 | ||
from .base import Random, Sampler | ||
from .base import Random, Sampler, SampleIterator | ||
from .optim import Module, ModuleInitMode, ModuleInitOptions |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.