Skip to content

Commit

Permalink
Fix RuntimeParameter validation when provided as _Step attr (#564)
Browse files Browse the repository at this point in the history
* Fix `RuntimeParameter` validation when provided as `_Step` attr

* Add `test_validate_step_process_runtime_parameters`

Also included `offset` in `GeneratorStep.process` subclasses, to ensure that the correct exception is captured

* Revert adding `offset` to `test_generator_step_process_without_offset_parameter`

* Use `is not None` over `not`

Co-authored-by: Gabriel Martin <[email protected]>

* Fix `RuntimeParameters` assignment when not provided via `_Step` attr

Co-authored-by: Gabriel Martin <[email protected]>

* Fix `_get_attribute_default` check when reverting `not` over `is not None`

It should have been `is None` instead

* Bump version to `1.0.2`

* Fix `llm_blender` installation from `argilla-io` fork (#557)

* Add `test_process_with_runtime_parameters`

---------

Co-authored-by: Gabriel Martin <[email protected]>
  • Loading branch information
alvarobartt and gabrielmbmb authored Apr 24, 2024
1 parent b870f39 commit c0520d8
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ jobs:
else
pip install -e .[dev,tests,cohere,hf-transformers,hf-inference-endpoints,vertexai,ollama,openai,together,argilla,llama-cpp,anthropic,litellm]
fi
pip install git+https://github.com/yuchenlin/LLM-Blender.git@3c2d71f
pip install git+https://github.com/argilla-io/LLM-Blender.git
- name: Lint
run: make lint
Expand Down
2 changes: 1 addition & 1 deletion src/distilabel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@

from rich import traceback as rich_traceback

__version__ = "1.0.1"
__version__ = "1.0.2"

rich_traceback.install(show_locals=True)
1 change: 1 addition & 0 deletions src/distilabel/mixins/runtime_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def set_runtime_parameters(self, runtime_parameters: Dict[str, Any]) -> None:
attr = getattr(self, name)
if isinstance(attr, RuntimeParametersMixin):
attr.set_runtime_parameters(value)
self._runtime_parameters[name] = value
continue

# Handle settings values for `_SecretField`
Expand Down
30 changes: 27 additions & 3 deletions src/distilabel/pipeline/_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,9 @@ def _validate_process_step_input_parameter(
f" to receive outputs from previous steps."
)

def _validate_step_process_runtime_parameters(self, step: "_Step") -> None:
def _validate_step_process_runtime_parameters( # noqa: C901
self, step: "_Step"
) -> None:
"""Validates that the required runtime parameters of the step are provided. A
runtime parameter is considered required if it doesn't have a default value. The
name of the runtime parameters are separated by dots to represent nested parameters.
Expand All @@ -367,6 +369,18 @@ def _get_pipeline_aux_code(step_name: str, param_name: str) -> str:
result += nested_dict + "})"
return result

def _get_attribute_default(
step: "_Step", composed_param_name: str
) -> Union[Any, None]:
parts = composed_param_name.split(".")
attr = step
for part in parts:
if isinstance(attr, dict):
attr = attr.get(part, None)
elif isinstance(attr, object):
attr = getattr(attr, part)
return attr

def _check_required_parameter(
param_name: str,
composed_param_name: str,
Expand All @@ -381,12 +395,22 @@ def _check_required_parameter(
param_name=subparam,
composed_param_name=f"{composed_param_name}.{subparam}",
is_optional_or_nested=value,
runtime_parameters=runtime_parameters.get(subparam, {}),
# NOTE: `runtime_parameters` get is for the specific case of `LLM` in `Task`
runtime_parameters=runtime_parameters.get(
param_name, runtime_parameters
),
runtime_parameters_names=runtime_parameters_names,
)
return

if not is_optional_or_nested and param_name not in runtime_parameters:
if (
not is_optional_or_nested
and param_name not in runtime_parameters
and _get_attribute_default(
step=step, composed_param_name=composed_param_name
)
is None
):
aux_code = _get_pipeline_aux_code(step.name, composed_param_name)
raise ValueError(
f"Step '{step.name}' is missing required runtime parameter '{param_name}'."
Expand Down
34 changes: 32 additions & 2 deletions tests/unit/pipeline/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def inputs(self) -> List[str]:
def outputs(self) -> List[str]:
return ["response"]

def process(self) -> "GeneratorStepOutput":
def process(self, offset: int = 0) -> "GeneratorStepOutput":
yield [{"response": "response1"}], False

step = DummyGeneratorStep(name="dummy_generator_step", pipeline=pipeline) # type: ignore
Expand All @@ -334,6 +334,34 @@ def process(self) -> "GeneratorStepOutput":
):
dag.validate()

def test_validate_step_process_runtime_parameters(
    self, pipeline: "Pipeline"
) -> None:
    """DAG validation passes when the runtime parameter without a default is
    given as a step attribute and no runtime parameters are set explicitly.
    """

    class DummyGeneratorStep(GeneratorStep):
        # Declared without a default value.
        runtime_param1: RuntimeParameter[int]
        # Declared with a default value.
        runtime_param2: RuntimeParameter[int] = 5

        @property
        def inputs(self) -> List[str]:
            return ["instruction"]

        @property
        def outputs(self) -> List[str]:
            return ["response"]

        def process(self, offset: int = 0) -> "GeneratorStepOutput":  # type: ignore
            yield [{"response": "response1"}], False

    # `runtime_param1` is supplied here, as a constructor attribute.
    generator_step = DummyGeneratorStep(
        name="dummy_generator_step", runtime_param1=2, pipeline=pipeline
    )
    # No runtime parameters are passed through `set_runtime_parameters`.
    generator_step.set_runtime_parameters({})

    graph = DAG()
    graph.add_step(generator_step)

    # The test succeeds as long as validation does not raise.
    graph.validate()

def test_step_invalid_input_mappings(self, pipeline: "Pipeline") -> None:
class DummyStep(Step):
@property
Expand Down Expand Up @@ -402,7 +430,9 @@ def inputs(self) -> List[str]:
def outputs(self) -> List[str]:
return ["response"]

def process(self, *inputs: StepInput) -> "GeneratorStepOutput": # type: ignore
def process(
self, *inputs: StepInput, offset: int = 0
) -> "GeneratorStepOutput": # type: ignore
yield [{"response": "response1"}], False

step = DummyGeneratorStep(name="dummy_generator_step", pipeline=pipeline)
Expand Down
57 changes: 56 additions & 1 deletion tests/unit/steps/tasks/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any, Dict, List, Union
from dataclasses import field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import pytest
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.base import Task
from pydantic import ValidationError
Expand All @@ -40,6 +42,11 @@ def format_output(self, output: Union[str, None], input: Dict[str, Any]) -> dict
return {"output": output}


class DummyRuntimeLLM(DummyLLM):
    """`DummyLLM` variant declaring two runtime parameters, used by the `Task`
    tests that exercise runtime-parameter handling.
    """

    # Declared without a default value.
    runtime_parameter: RuntimeParameter[int]
    # Declared `Optional` with `field(default=None)`; `field` is the one imported
    # from `dataclasses` in this file.
    # NOTE(review): confirm `dataclasses.field` is the intended helper for the
    # base class' model machinery.
    runtime_parameter_optional: Optional[RuntimeParameter[int]] = field(default=None)


class TestTask:
def test_passing_pipeline(self) -> None:
pipeline = Pipeline(name="unit-test-pipeline")
Expand Down Expand Up @@ -112,6 +119,54 @@ def test_process(
result = next(task.process([{"instruction": "test"}]))
assert result == expected

def test_process_with_runtime_parameters(self) -> None:
    """Runtime parameters of the task's LLM can be given at runtime, at init,
    or at init and then overridden at runtime.
    """
    # The declared runtime parameter names are the same in every scenario.
    expected_names = {
        "runtime_parameter": False,
        "runtime_parameter_optional": True,
        "generation_kwargs": {"kwargs": False},
    }

    def _make_task(llm):
        # Every scenario wraps its LLM in an identically configured task.
        return DummyTask(
            name="task",
            llm=llm,
            pipeline=Pipeline(name="unit-test-pipeline"),
        )

    # 1. Runtime parameters provided
    task = _make_task(DummyRuntimeLLM())  # type: ignore
    task.set_runtime_parameters({"llm": {"runtime_parameter": 1}})
    assert task.load() is None
    assert task.llm.runtime_parameter == 1  # type: ignore
    assert task.llm.runtime_parameters_names == expected_names

    # 2. Runtime parameters in init
    task = _make_task(DummyRuntimeLLM(runtime_parameter=1))
    assert task.load() is None
    assert task.llm.runtime_parameter == 1  # type: ignore
    assert task.llm.runtime_parameters_names == expected_names

    # 3. Runtime parameters in init superseded by runtime parameters
    task = _make_task(DummyRuntimeLLM(runtime_parameter=1))
    task.set_runtime_parameters({"llm": {"runtime_parameter": 2}})
    assert task.load() is None
    assert task.llm.runtime_parameter == 2  # type: ignore
    assert task.llm.runtime_parameters_names == expected_names

def test_serialization(self) -> None:
pipeline = Pipeline(name="unit-test-pipeline")
llm = DummyLLM()
Expand Down

0 comments on commit c0520d8

Please sign in to comment.