Skip to content

Commit

Permalink
Add BlockRegistry and PromptRegistry
Browse files Browse the repository at this point in the history
The Block and Prompt registries are how we keep track of what our
supported Block types are and which Prompts map to which teacher
models.

Co-authored-by: shivchander <[email protected]>
Co-authored-by: abhi1092 <[email protected]>
Signed-off-by: Ben Browning <[email protected]>
  • Loading branch information
3 people committed Dec 10, 2024
1 parent dee4424 commit 5e0cc23
Show file tree
Hide file tree
Showing 18 changed files with 176 additions and 35 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ exclude = [
"^src/instructlab/sdg/generate_data\\.py$",
"^src/instructlab/sdg/utils/taxonomy\\.py$",
"^src/instructlab/sdg/default_flows\\.py$",
"^src/instructlab/sdg/llmblock\\.py$",
"^src/instructlab/sdg/utilblocks\\.py$",
"^src/instructlab/sdg/blocks/llmblock\\.py$",
"^src/instructlab/sdg/blocks/utilblocks\\.py$",
]
# honor excludes by not following there through imports
follow_imports = "silent"
24 changes: 12 additions & 12 deletions src/instructlab/sdg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,19 @@
)

# Local
from .block import Block
from .filterblock import FilterByValueBlock, FilterByValueBlockError
from .blocks.block import Block
from .blocks.filterblock import FilterByValueBlock, FilterByValueBlockError
from .blocks.llmblock import ConditionalLLMBlock, LLMBlock
from .blocks.utilblocks import (
CombineColumnsBlock,
DuplicateColumnsBlock,
FlattenColumnsBlock,
RenameColumnsBlock,
SamplePopulatorBlock,
SelectorBlock,
SetToMajorityValueBlock,
)
from .generate_data import generate_data
from .llmblock import ConditionalLLMBlock, LLMBlock
from .pipeline import (
FULL_PIPELINES_PACKAGE,
SIMPLE_PIPELINES_PACKAGE,
Expand All @@ -39,14 +48,5 @@
PipelineConfigParserError,
PipelineContext,
)
from .utilblocks import (
CombineColumnsBlock,
DuplicateColumnsBlock,
FlattenColumnsBlock,
RenameColumnsBlock,
SamplePopulatorBlock,
SelectorBlock,
SetToMajorityValueBlock,
)
from .utils import GenerateException
from .utils.taxonomy import TaxonomyReadingException
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,14 @@
from jinja2 import Template, UndefinedError
import yaml

# Local
from ..registry import BlockRegistry

logger = logging.getLogger(__name__)


# This is part of the public API.
@BlockRegistry.register("Block")
class Block(ABC):
def __init__(self, ctx, pipe, block_name: str) -> None:
self.ctx = ctx
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import openai

# Local
from ..registry import BlockRegistry
from .block import Block

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -78,6 +79,7 @@ def template_from_struct_and_config(struct, config):
return Template(struct.format(**filtered_config), undefined=StrictUndefined)

# This is part of the public API.
@BlockRegistry.register("LLMBlock")
# pylint: disable=dangerous-default-value
class LLMBlock(Block):
# pylint: disable=too-many-instance-attributes
Expand Down Expand Up @@ -273,6 +275,7 @@ def generate(self, samples: Dataset) -> Dataset:


# This is part of the public API.
@BlockRegistry.register("ConditionalLLMBlock")
class ConditionalLLMBlock(LLMBlock):
def __init__(
self,
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion src/instructlab/sdg/generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# pylint: disable=ungrouped-imports
from instructlab.sdg.datamixing import DataMixer, _get_question_hack, _get_response_hack
from instructlab.sdg.eval_data import generate_eval_task_data, mmlubench_pipe_init
from instructlab.sdg.llmblock import (
from instructlab.sdg.blocks.llmblock import (
DEFAULT_MAX_NUM_TOKENS,
MODEL_FAMILY_MERLINITE,
MODEL_FAMILY_MIXTRAL,
Expand Down
4 changes: 2 additions & 2 deletions src/instructlab/sdg/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from instructlab.sdg.utils import pandas

# Local
from . import filterblock, llmblock, utilblocks
from .block import Block
from .blocks import filterblock, llmblock, utilblocks
from .blocks.block import Block

logger = logging.getLogger(__name__)

Expand Down
120 changes: 120 additions & 0 deletions src/instructlab/sdg/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Standard
from typing import Dict, List, Union
import logging

# Third Party
from jinja2 import Template

logger = logging.getLogger(__name__)


class BlockRegistry:
"""Registry for block classes to avoid manual additions to block type map."""

_registry: Dict[str, type] = {}

@classmethod
def register(cls, block_name: str):
"""
Decorator to register a block class under a specified name.
:param block_name: Name under which to register the block.
"""

def decorator(block_class):
cls._registry[block_name] = block_class
logger.debug(
f"Registered block '{block_name}' with class '{block_class.__name__}'"
)
return block_class

return decorator

@classmethod
def get_registry(cls):
"""
Retrieve the current registry map of block types.
:return: Dictionary of registered block names and classes.
"""
logger.debug("Fetching the block registry map.")
return cls._registry


class PromptRegistry:
"""Registry for managing Jinja2 prompt templates."""

_registry: Dict[str, Template] = {}

@classmethod
def register(cls, name: str):
"""Decorator to register a Jinja2 template function by name.
:param name: Name of the template to register.
:return: A decorator that registers the Jinja2 template function.
"""

def decorator(func):
template_str = func()
cls._registry[name] = Template(template_str)
logger.debug(f"Registered prompt template '{name}'")
return func

return decorator

@classmethod
def get_template(cls, name: str) -> Template:
"""Retrieve a Jinja2 template by name.
:param name: Name of the template to retrieve.
:return: The Jinja2 template instance.
"""
if name not in cls._registry:
raise KeyError(f"Template '{name}' not found.")
logger.debug(f"Retrieving prompt template '{name}'")
return cls._registry[name]

@classmethod
def get_registry(cls):
"""
Retrieve the current registry map of block types.
:return: Dictionary of registered block names and classes.
"""
logger.debug("Fetching the block registry map.")
return cls._registry

@classmethod
def render_template(
cls,
name: str,
messages: Union[str, List[Dict[str, str]]],
add_generation_prompt: bool = True,
) -> str:
"""Render the template with the provided messages or query.
:param name: Name of the template to render.
:param messages: Either a single query string or a list of messages (each as a dict with 'role' and 'content').
:param add_generation_prompt: Whether to add a generation prompt at the end.
:return: The rendered prompt as a string.
"""

# Special handling for "blank" template
if name == "blank":
if not isinstance(messages, str):
raise ValueError(
"The 'blank' template can only be used with a single query string, not a list of messages."
)
return messages # Return the query as-is without templating

# Get the template
template = cls.get_template(name)

# If `messages` is a string, wrap it in a list with a default user role
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]

# Render the template with the `messages` list
return template.render(
messages=messages, add_generation_prompt=add_generation_prompt
)
12 changes: 7 additions & 5 deletions tests/test_default_pipeline_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@
from datasets import Dataset

# First Party
from instructlab.sdg.filterblock import FilterByValueBlock
from instructlab.sdg.llmblock import ConditionalLLMBlock, LLMBlock
from instructlab.sdg.pipeline import Pipeline, PipelineContext
from instructlab.sdg.utilblocks import (
from instructlab.sdg import (
CombineColumnsBlock,
ConditionalLLMBlock,
DuplicateColumnsBlock,
FilterByValueBlock,
FlattenColumnsBlock,
LLMBlock,
Pipeline,
PipelineContext,
RenameColumnsBlock,
SamplePopulatorBlock,
SelectorBlock,
Expand All @@ -35,7 +37,7 @@ def _noop_generate(self, samples):
@patch.object(RenameColumnsBlock, "generate", _noop_generate)
@patch.object(SamplePopulatorBlock, "generate", _noop_generate)
@patch.object(SelectorBlock, "generate", _noop_generate)
@patch("instructlab.sdg.llmblock.server_supports_batched", lambda c, m: True)
@patch("instructlab.sdg.blocks.llmblock.server_supports_batched", lambda c, m: True)
@patch.object(Pipeline, "_drop_duplicates", lambda self, dataset, cols: dataset)
class TestDefaultPipelineConfigs(unittest.TestCase):
def setUp(self):
Expand Down
3 changes: 1 addition & 2 deletions tests/test_filterblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from datasets import Dataset, Features, Value

# First Party
from instructlab.sdg.filterblock import FilterByValueBlock
from instructlab.sdg.pipeline import PipelineContext
from instructlab.sdg import FilterByValueBlock


class TestFilterByValueBlock(unittest.TestCase):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

# First Party
from instructlab.sdg.generate_data import _context_init, _sdg_init, generate_data
from instructlab.sdg.llmblock import LLMBlock
from instructlab.sdg.pipeline import PipelineContext
from instructlab.sdg import LLMBlock
from instructlab.sdg import PipelineContext

TEST_SYS_PROMPT = "I am, Red Hat® Instruct Model based on Granite 7B, an AI language model developed by Red Hat and IBM Research, based on the Granite-7b-base language model. My primary function is to be a chat assistant."

Expand Down
9 changes: 5 additions & 4 deletions tests/test_llmblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
from openai import InternalServerError, NotFoundError

# First Party
from src.instructlab.sdg.llmblock import ConditionalLLMBlock, LLMBlock, server_supports_batched
from src.instructlab.sdg import ConditionalLLMBlock, LLMBlock
from src.instructlab.sdg.blocks.llmblock import server_supports_batched


@patch("src.instructlab.sdg.block.Block._load_config")
@patch("src.instructlab.sdg.blocks.block.Block._load_config")
class TestLLMBlockModelPrompt(unittest.TestCase):
def setUp(self):
self.mock_ctx = MagicMock()
Expand Down Expand Up @@ -87,7 +88,7 @@ def test_model_prompt_custom(self, mock_load_config):
"model_prompt should be a non-empty string when set to None",
)

@patch("src.instructlab.sdg.block.Block._load_config")
@patch("src.instructlab.sdg.blocks.block.Block._load_config")
class TestLLMBlockOtherFunctions(unittest.TestCase):
def setUp(self):
self.mock_ctx = MagicMock()
Expand Down Expand Up @@ -186,7 +187,7 @@ def test_server_supports_batched_vllm(self):
supports_batched = server_supports_batched(self.mock_ctx.client, "my-model")
assert supports_batched

@patch("src.instructlab.sdg.block.Block._load_config")
@patch("src.instructlab.sdg.blocks.block.Block._load_config")
class TestConditionalLLMBlock(unittest.TestCase):
def setUp(self):
self.mock_ctx = MagicMock()
Expand Down
4 changes: 2 additions & 2 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
import pytest

# First Party
from instructlab.sdg.block import Block
from instructlab.sdg.pipeline import Pipeline, PipelineBlockError
from instructlab.sdg import Block
from instructlab.sdg import Pipeline, PipelineBlockError

## Helpers ##

Expand Down
12 changes: 12 additions & 0 deletions tests/test_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# SPDX-License-Identifier: Apache-2.0

# First Party
from src.instructlab.sdg.registry import BlockRegistry

def test_block_registry():
@BlockRegistry.register("TestFooClass")
class TestFooClass:
pass
registry = BlockRegistry.get_registry()
assert registry is not None
assert registry["TestFooClass"] is TestFooClass
4 changes: 2 additions & 2 deletions tests/test_sample_populator_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from datasets import Dataset, Features, Value

# First Party
from instructlab.sdg.utilblocks import SamplePopulatorBlock
from instructlab.sdg import SamplePopulatorBlock


class TestSamplePopulatorBlock(unittest.TestCase):
Expand All @@ -17,7 +17,7 @@ def setUp(self):
self.ctx.dataset_num_procs = 1
self.pipe = MagicMock()

@patch("instructlab.sdg.block.Block._load_config")
@patch("instructlab.sdg.blocks.block.Block._load_config")
def test_generate(self, mock_load_config):
def load_config(file_name):
if file_name == "coffee.yaml" or file_name == "tea.yaml":
Expand Down
2 changes: 1 addition & 1 deletion tests/test_utilblocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from datasets import Dataset, Features, Value

# First Party
from src.instructlab.sdg.utilblocks import (
from src.instructlab.sdg import (
DuplicateColumnsBlock,
FlattenColumnsBlock,
RenameColumnsBlock,
Expand Down

0 comments on commit 5e0cc23

Please sign in to comment.