Skip to content

Commit

Permalink
Small fixes in Argilla steps
Browse files Browse the repository at this point in the history
Fix `ImportError` for `argilla` in `TextGenerationToArgilla`
  • Loading branch information
gabrielmbmb committed Apr 11, 2024
1 parent 4361fbd commit 6a26b5e
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 11 deletions.
14 changes: 11 additions & 3 deletions src/distilabel/steps/argilla/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,17 @@ class Argilla(Step, ABC):
- dynamic, based on the `inputs` value provided
"""

dataset_name: str
dataset_workspace: Optional[str] = None
dataset_name: RuntimeParameter[str] = Field(
default=None, description="The name of the dataset in Argilla."
)
dataset_workspace: Optional[RuntimeParameter[str]] = Field(
default=None,
description="The workspace where the dataset will be created in Argilla. Defaults"
"to `None` which means it will be created in the default workspace.",
)

api_url: Optional[RuntimeParameter[str]] = Field(
default_factory=lambda: os.getenv("ARGILLA_BASE_URL"),
default_factory=lambda: os.getenv("ARGILLA_API_URL"),
description="The base URL to use for the Argilla API requests.",
)
api_key: Optional[RuntimeParameter[SecretStr]] = Field(
Expand Down Expand Up @@ -122,6 +128,8 @@ def load(self) -> None:
"""
super().load()

self._rg_init()

@property
@abstractmethod
def inputs(self) -> List[str]:
Expand Down
2 changes: 0 additions & 2 deletions src/distilabel/steps/argilla/preference.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,6 @@ def load(self) -> None:
"""
super().load()

self._rg_init()

# Both `instruction` and `generations` will be used as the fields of the dataset
self._instruction = self.input_mappings.get("instruction", "instruction")
self._generations = self.input_mappings.get("generations", "generations")
Expand Down
8 changes: 2 additions & 6 deletions src/distilabel/steps/argilla/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,8 @@

try:
import argilla as rg
except ImportError as ie:
raise ImportError(
"Argilla is not installed. Please install it using `pip install argilla`."
) from ie
except ImportError:
pass

from distilabel.steps.argilla.base import Argilla
from distilabel.steps.base import StepInput
Expand Down Expand Up @@ -71,8 +69,6 @@ def load(self) -> None:
"""
super().load()

self._rg_init()

self._instruction = self.input_mappings.get("instruction", "instruction")
self._generation = self.input_mappings.get("generation", "generation")

Expand Down

0 comments on commit 6a26b5e

Please sign in to comment.