Skip to content

Commit

Permalink
feat: add basic draw implementation to pipline (#966)
Browse files Browse the repository at this point in the history
* feat: add basic draw implementation to pipline

* refactor: cleanup some code

* feat: add functionality to draw TD or LR

* refactor: remove step name from vis

* refactor: default to LR generation

* Add dag with mapping

* feat: add edge labels

* Remove images

* feat: add support for leaf node to argilla and distilabel

* refactor: order of functions

* test: Add tests

* fix: replace logger warning for `warning.warn` to avoid non-initialized logger

* fix: avoid potentially getting raised errors during `get_outputs` call relying on dynamic calls

* docs: Add visualizing pipelines section

* feat: Add a try-except around pipeline visualization in Notebook to ensure it will never be a blocking action

* feat: add a show method to the pipleines for visualizing in notebooks

* docs: add more context on pipeline.show

* Apply suggestions from code review

Co-authored-by: Agus <[email protected]>

* Update src/distilabel/steps/generators/huggingface.py

* feat: remove show to simplify flow

* refactor: mermaid URL at top as constant

* feat: improve flow for passing by info to a potential next step

* docs: update docstring

---------

Co-authored-by: Agus <[email protected]>
  • Loading branch information
davidberenstein1957 and plaguss authored Sep 20, 2024
1 parent ad231ab commit c7deafa
Show file tree
Hide file tree
Showing 6 changed files with 389 additions and 9 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
30 changes: 28 additions & 2 deletions docs/sections/how_to_guides/basic/pipeline/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ with Pipeline("pipe-name", description="My first pipe") as pipeline:
VertexAILLM(model="gemini-1.5-pro"),
):
task = TextGeneration(
name=f"text_generation_with_{llm.model_name}",
name=f"text_generation_with_{llm.model_name.replace('.', '-')}",
llm=llm,
input_batch_size=5,
)
Expand Down Expand Up @@ -459,6 +459,30 @@ To load the pipeline, we can use the `from_yaml` or `from_json` methods:

Serializing the pipeline is very useful when we want to share the pipeline with others, or when we want to store the pipeline for future use. It can even be hosted online, so the pipeline can be executed directly using the [CLI](../../advanced/cli/index.md).

## Visualizing the pipeline

We can visualize the pipeline using the `Pipeline.draw()` method. This will create a `mermaid` graph, and return the path to the image.

```python
path_to_image = pipeline.draw(
top_to_bottom=True,
show_edge_labels=True,
)
```

Within notebooks, we can simply call `pipeline` and the graph will be displayed. Alternatively, we can use the `Pipeline.draw()` method to have more control over the graph visualization and use `IPython` to display it.

```python
from IPython.display import Image, display

display(Image(path_to_image))
```

Let's now see how the pipeline of the [fully working example](#fully-working-example) looks like.

![Pipeline](../../../../assets/images/sections/how_to_guides/basic/pipeline.png)


## Fully working example

To sum up, here is the full code of the pipeline we have created in this section. Note that you will need to change the name of the Hugging Face repository where the resulting will be pushed, set `OPENAI_API_KEY` environment variable, set `MISTRAL_API_KEY` and have `gcloud` installed and configured:
Expand Down Expand Up @@ -487,7 +511,9 @@ To sum up, here is the full code of the pipeline we have created in this section
MistralLLM(model="mistral-large-2402"),
VertexAILLM(model="gemini-1.0-pro"),
):
task = TextGeneration(name=f"text_generation_with_{llm.model_name}", llm=llm)
task = TextGeneration(
name=f"text_generation_with_{llm.model_name.replace('.', '-')}", llm=llm
)
load_dataset.connect(task)
task.connect(combine_generations)

Expand Down
190 changes: 188 additions & 2 deletions src/distilabel/pipeline/_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import inspect
from collections import defaultdict
from functools import cached_property
Expand All @@ -29,6 +29,7 @@
)

import networkx as nx
import requests

from distilabel.constants import (
CONVERGENCE_STEP_ATTR_NAME,
Expand All @@ -48,6 +49,8 @@
from distilabel.mixins.runtime_parameters import RuntimeParametersNames
from distilabel.steps.base import GeneratorStep, Step, _Step

_MERMAID_URL = "https://mermaid.ink/img/"


class DAG(_Serializable):
"""A Directed Acyclic Graph (DAG) to represent the pipeline.
Expand Down Expand Up @@ -452,7 +455,7 @@ def _validate_convergence_step(

# Check if the `input_batch_size` of the step is equal or lower than the
for predecessor in predecessors:
prev_step: "Step" = self.get_step(predecessor)[STEP_ATTR_NAME]
prev_step: "Step" = self.get_step(predecessor)[STEP_ATTR_NAME] # type: ignore
if step.input_batch_size > prev_step.input_batch_size: # type: ignore
raise ValueError(
"A convergence step should have an `input_batch_size` equal or lower"
Expand Down Expand Up @@ -749,3 +752,186 @@ def from_dict(cls, data: Dict[str, Any]) -> "DAG":
)

return dag

def _get_graph_info_for_draw(
self,
) -> Tuple[
Set[str],
Dict[str, str],
List[Dict[str, Any]],
Dict[str, Dict[str, Any]],
Dict[str, Dict[str, Any]],
Dict[str, Dict[str, Any]],
]:
"""Returns the graph info.
Returns:
all_steps: The set of all steps in the graph.
step_name_to_class: The mapping of step names to their classes.
connections: The list of connections in the graph.
step_outputs: The mapping of step names to their outputs.
step_output_mappings: The mapping of step names to their output mappings.
step_input_mappings: The mapping of step names to their input mappings.
"""
dump = self.dump()
step_name_to_class = {
step["step"].get("name"): step["step"].get("type_info", {}).get("name")
for step in dump["steps"]
}
connections = dump["connections"]

step_outputs = {}
for step in dump["steps"]:
try:
step_outputs[step["name"]] = self.get_step(step["name"])[
STEP_ATTR_NAME
].get_outputs()
except AttributeError:
step_outputs[step["name"]] = {"dynamic": True}
step_inputs = {}
for step in dump["steps"]:
try:
step_inputs[step["name"]] = self.get_step(step["name"])[
STEP_ATTR_NAME
].get_inputs()
except AttributeError:
step_inputs[step["name"]] = {"dynamic": True}

# Add Argilla and Distiset steps to the graph
leaf_steps = self.leaf_steps
for idx, leaf_step in enumerate(leaf_steps):
if "to_argilla" in leaf_step:
connections.append({"from": leaf_step, "to": [f"to_argilla_{idx}"]})
step_name_to_class[f"to_argilla_{idx}"] = "Argilla"
step_outputs[leaf_step] = {"records": True}
else:
connections.append({"from": leaf_step, "to": [f"distiset_{idx}"]})
step_name_to_class[f"distiset_{idx}"] = "Distiset"

# Create a set of all steps in the graph
all_steps = {con["from"] for con in connections} | {
to_step for con in connections for to_step in con["to"]
}

# Create a mapping of step outputs
step_output_mappings = {
step["name"]: {
k: v
for k, v in {
**{output: output for output in step_outputs[step["name"]]},
**step["step"]["output_mappings"],
}.items()
if list(
dict(
{
**{output: output for output in step_outputs[step["name"]]},
**step["step"]["output_mappings"],
}.items()
).values()
).count(v)
== 1
or k != v
}
for step in dump["steps"]
}
step_input_mappings = {
step["name"]: dict(
{
**{input: input for input in step_inputs[step["name"]]},
**step["step"]["input_mappings"],
}.items()
)
for step in dump["steps"]
}

return (
all_steps,
step_name_to_class,
connections,
step_outputs,
step_output_mappings,
step_input_mappings,
)

def draw(self, top_to_bottom: bool = False, show_edge_labels: bool = True) -> str: # noqa: C901
"""Draws the DAG and returns the image content.
Parameters:
top_to_bottom: Whether to draw the DAG top to bottom. Defaults to `False`.
show_edge_labels: Whether to show the edge labels. Defaults to `True`.
Returns:
The image content.
"""
(
all_steps,
step_name_to_class,
connections,
step_outputs,
step_output_mappings,
step_input_mappings,
) = self._get_graph_info_for_draw()
graph = [f"flowchart {'TD' if top_to_bottom else 'LR'}"]
for step in all_steps:
graph.append(f' {step}["{step_name_to_class[step]}"]')

if show_edge_labels:
for connection in connections:
from_step = connection["from"]
from_mapping = step_output_mappings[from_step]
for to_step in connection["to"]:
for from_column in set(
list(step_outputs[from_step].keys())
+ list(step_output_mappings[from_step].keys())
):
if from_column not in from_mapping:
continue
to_column = from_mapping.get(from_column)

# walk through mappings
to_mapping = step_input_mappings.get(to_step, {})
edge_label = [from_column]
if from_column != to_column:
edge_label.append(to_column)
if edge_label[-1] in to_mapping:
edge_label.append(to_mapping[edge_label[-1]])

if (
edge_label[-1] not in to_mapping
and from_step not in self.leaf_steps
):
edge_label.append("**_pass_**")
edge_label = ":".join(list(dict.fromkeys(edge_label)))
graph.append(f" {from_step} --> |{edge_label}| {to_step}")

else:
for connection in connections:
from_step = connection["from"]
for to_step in connection["to"]:
graph.append(f" {from_step} --> {to_step}")

graph.append("classDef component text-align:center;")
graph_styled = "\n".join(graph)
return _to_mermaid_image(graph_styled)


def _to_mermaid_image(graph_styled: str) -> str:
"""Converts a Mermaid graph to an image using the Mermaid Ink service.
Parameters:
graph_styled: The Mermaid graph to convert to an image.
Returns:
The image content.
"""
base64_string = base64.b64encode(graph_styled.encode("ascii")).decode("ascii")
url = f"{_MERMAID_URL}{base64_string}?type=png"

try:
response = requests.get(url, timeout=10)
response.raise_for_status()
return response.content
except requests.RequestException as e:
raise ValueError(
"Error accessing https://mermaid.ink/. See stacktrace for details."
) from e
41 changes: 40 additions & 1 deletion src/distilabel/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
import logging
import os
Expand Down Expand Up @@ -50,6 +49,7 @@
from distilabel.steps.base import GeneratorStep
from distilabel.steps.generators.utils import make_generator_step
from distilabel.utils.logging import setup_logging, stop_logging
from distilabel.utils.notebook import in_notebook
from distilabel.utils.serialization import (
TYPE_INFO_KEY,
_Serializable,
Expand Down Expand Up @@ -638,6 +638,45 @@ def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]:
"""
return self.dag.dump()

def draw(
self,
path: Optional[Union[str, Path]] = "pipeline.png",
top_to_bottom: bool = False,
show_edge_labels: bool = True,
) -> str:
"""
Draws the pipeline.
Parameters:
path: The path to save the image to.
top_to_bottom: Whether to draw the DAG top to bottom. Defaults to `False`.
show_edge_labels: Whether to show the edge labels. Defaults to `True`.
Returns:
The path to the saved image.
"""
png = self.dag.draw(
top_to_bottom=top_to_bottom, show_edge_labels=show_edge_labels
)
with open(path, "wb") as f:
f.write(png)
return path

def __repr__(self) -> str:
"""
If running in a Jupyter notebook, display an image representing this `Pipeline`.
"""
if in_notebook():
try:
from IPython.display import Image, display

image_data = self.dag.draw()

display(Image(image_data))
except Exception:
pass
return super().__repr__()

def dump(self, **kwargs: Any) -> Dict[str, Any]:
return {
"distilabel": {"version": __version__},
Expand Down
9 changes: 5 additions & 4 deletions src/distilabel/steps/generators/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from collections import defaultdict
from functools import cached_property
from pathlib import Path
Expand Down Expand Up @@ -247,10 +248,10 @@ def _dataset_info(self) -> Dict[str, DatasetInfo]:
try:
return get_dataset_infos(self.repo_id)
except Exception as e:
# The previous could fail in case of a internet connection issues.
# Assuming the dataset is already loaded and we can get the info from the loaded dataset, otherwise it will fail anyway.
self._logger.warning(
f"Failed to get dataset info from Hugging Face Hub, trying to get it loading the dataset. Error: {e}"
warnings.warn(
f"Failed to get dataset info from Hugging Face Hub, trying to get it loading the dataset. Error: {e}",
UserWarning,
stacklevel=2,
)
ds = load_dataset(self.repo_id, config=self.config, split=self.split)
if self.config:
Expand Down
Loading

0 comments on commit c7deafa

Please sign in to comment.