Skip to content

Commit

Permalink
Adapt method arg/return comments to style used elsewhere
Browse files Browse the repository at this point in the history
This cleans up the new code comments for each method to match the
style we've used elsewhere in the project for consistency's sake.

Signed-off-by: Ben Browning <[email protected]>
  • Loading branch information
bbrowning committed Dec 5, 2024
1 parent 8b0d52a commit 7abde97
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 15 deletions.
17 changes: 12 additions & 5 deletions src/instructlab/sdg/blocks/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,13 @@ def _validate(self, prompt_template: Template, input_dict: Dict[str, Any]) -> bo
Validate the input data for this block. This method validates whether all required
variables in the Jinja template are provided in the input_dict.
:param prompt_template: The Jinja2 template object.
:param input_dict: A dictionary of input values to check against the template.
:return: True if the input data is valid (i.e., no missing variables), False otherwise.
Args:
prompt_template (Template): The Jinja2 template object.
input_dict (Dict[str, Any]): A dictionary of input values to check against
the template.
Returns:
True if the input data is valid (i.e., no missing variables), False otherwise.
"""

class Default(dict):
Expand All @@ -54,8 +58,11 @@ def _load_config(self, config_path: str) -> Union[Dict[str, Any], None]:
If the supplied configuration file is a relative path, it is assumed
to be part of this Python package.
:param config_path: The path to the configuration file.
:return: The loaded configuration.
Args:
config_path (str): The path to the configuration file.
Returns:
The loaded configuration.
"""
if not os.path.isabs(config_path):
config_path = os.path.join(
Expand Down
9 changes: 7 additions & 2 deletions src/instructlab/sdg/blocks/llmblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,11 @@ def generate(self, samples: Dataset) -> Dataset:
Generate the output from the block. This method should first validate the input data,
then generate the output, and finally parse the generated output before returning it.
:return: The parsed output after generation.
Args:
samples (Dataset): The samples used as input data
Returns:
The parsed output after generation.
"""
num_samples = self.batch_params.get("num_samples", None)
logger.debug("Generating outputs for {} samples".format(len(samples)))
Expand Down Expand Up @@ -407,7 +411,8 @@ def __init__(
# Generate the output from the block. This method should first validate the input data,
# then generate the output, and finally parse the generated output before returning it.

# :return: The parsed output after generation.
# Returns:
# The parsed output after generation.
# """
# num_samples = self.block_config.get("num_samples", None)
# logger.debug("Generating outputs for {} samples".format(len(samples)))
Expand Down
25 changes: 17 additions & 8 deletions src/instructlab/sdg/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ def register(cls, block_name: str):
"""
Decorator to register a block class under a specified name.
:param block_name: Name under which to register the block.
Args:
block_name (str): Name under which to register the block.
"""

def decorator(block_class):
Expand All @@ -35,7 +36,8 @@ def get_registry(cls):
"""
Retrieve the current registry map of block types.
:return: Dictionary of registered block names and classes.
Returns:
Dictionary of registered block names and classes.
"""
return cls._registry

Expand All @@ -47,10 +49,13 @@ class PromptRegistry:

@classmethod
def register(cls, *names: str):
"""Decorator to register a Jinja2 template function by name.
"""Decorator to register Jinja2 template functions by name.
:param name: Name of the template to register.
:return: A decorator that registers the Jinja2 template function.
Args:
names (str): Names of the templates to register.
Returns:
A decorator that registers the Jinja2 template functions.
"""

def decorator(func):
Expand All @@ -67,8 +72,11 @@ def decorator(func):
def get_template(cls, name: str) -> Template:
"""Retrieve a Jinja2 template by name.
:param name: Name of the template to retrieve.
:return: The Jinja2 template instance.
Args:
name (str): Name of the template to retrieve.
Returns:
The Jinja2 template instance.
"""
if name not in cls._registry:
raise KeyError(f"Prompt template '{name}' not found.")
Expand All @@ -79,6 +87,7 @@ def get_registry(cls):
"""
Retrieve the current registry map of block types.
:return: Dictionary of registered block names and classes.
Returns:
Dictionary of registered block names and classes.
"""
return cls._registry

0 comments on commit 7abde97

Please sign in to comment.