Skip to content
This repository has been archived by the owner on Sep 24, 2024. It is now read-only.

Commit

Permalink
rename to text dataset config
Browse files: browse the repository at this point in the history
  • Loading branch information
Sean Friedowitz committed Jan 19, 2024
1 parent 70aef27 commit 3708f5c
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 13 deletions.
4 changes: 2 additions & 2 deletions src/flamingo/integrations/huggingface/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from .dataset_config import DatasetConfig
from .model_config import AutoModelConfig
from .quantization_config import QuantizationConfig
from .text_dataset_config import TextDatasetConfig
from .tokenizer_config import AutoTokenizerConfig
from .trainer_config import TrainerConfig

__all__ = [
"AutoModelConfig",
"AutoTokenizerConfig",
"DatasetConfig",
"QuantizationConfig",
"TextDatasetConfig",
"TrainerConfig",
]
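src/flamingo/integrations/huggingface/text_dataset_config.py (path inferred from the `.text_dataset_config` import above; the file header was lost in extraction)
Original file line number Diff line number Diff line change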
Expand Up @@ -4,13 +4,18 @@
from flamingo.integrations.wandb import WandbArtifactConfig
from flamingo.types import BaseFlamingoConfig

DEFAULT_TEXT_FIELD: str = "text"

class DatasetConfig(BaseFlamingoConfig):
"""Settings passed to load a HuggingFace dataset."""

class TextDatasetConfig(BaseFlamingoConfig):
"""Settings passed to load a HuggingFace text dataset.
The dataset should contain a single text column named by the `text_field` parameter.
"""

path: str | WandbArtifactConfig
split: str | None = None
text_field: str = "text"
text_field: str = DEFAULT_TEXT_FIELD
test_size: float | None = None
seed: int | None = None

Expand Down
6 changes: 3 additions & 3 deletions src/flamingo/jobs/finetuning/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from flamingo.integrations.huggingface import (
AutoModelConfig,
AutoTokenizerConfig,
DatasetConfig,
QuantizationConfig,
TextDatasetConfig,
TrainerConfig,
)
from flamingo.integrations.wandb import WandbRunConfig
Expand All @@ -27,7 +27,7 @@ class FinetuningJobConfig(BaseFlamingoConfig):
"""Configuration to submit an LLM finetuning job."""

model: AutoModelConfig
dataset: DatasetConfig
dataset: TextDatasetConfig
tokenizer: AutoTokenizerConfig
quantization: QuantizationConfig | None = None
adapter: LoraConfig | None = None # TODO: Create own dataclass here
Expand Down Expand Up @@ -61,7 +61,7 @@ def validate_model_arg(cls, x):
def validate_dataset_arg(cls, x):
"""Allow for passing just a path string as the dataset argument."""
if isinstance(x, str):
return DatasetConfig(path=x)
return TextDatasetConfig(path=x, text_field="text")
return x

@validator("tokenizer", pre=True, always=True)
Expand Down
10 changes: 7 additions & 3 deletions tests/jobs/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from flamingo.integrations.huggingface import (
AutoModelConfig,
AutoTokenizerConfig,
DatasetConfig,
QuantizationConfig,
TextDatasetConfig,
)
from flamingo.integrations.wandb import WandbArtifactConfig, WandbRunConfig

Expand Down Expand Up @@ -34,13 +34,17 @@ def tokenizer_config_with_artifact():

@pytest.fixture
def dataset_config_with_path():
return DatasetConfig(path="databricks/dolly15k", split="train")
return TextDatasetConfig(
path="databricks/dolly15k",
text_field="text",
split="train",
)


@pytest.fixture
def dataset_config_with_artifact():
artifact = WandbArtifactConfig(name="dataset")
return DatasetConfig(path=artifact, split="train")
return TextDatasetConfig(path=artifact, split="train")


@pytest.fixture
Expand Down
8 changes: 6 additions & 2 deletions tests/jobs/finetuning/test_finetuning_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import pytest
from pydantic import ValidationError

from flamingo.integrations.huggingface import AutoModelConfig, AutoTokenizerConfig, DatasetConfig
from flamingo.integrations.huggingface import (
AutoModelConfig,
AutoTokenizerConfig,
TextDatasetConfig,
)
from flamingo.jobs.finetuning import FinetuningJobConfig, FinetuningRayConfig


Expand Down Expand Up @@ -60,7 +64,7 @@ def test_argument_validation():
)
assert allowed_config.model == AutoModelConfig(path="model_path")
assert allowed_config.tokenizer == AutoTokenizerConfig(path="tokenizer_path")
assert allowed_config.dataset == DatasetConfig(path="dataset_path")
assert allowed_config.dataset == TextDatasetConfig(path="dataset_path")

# Check passing invalid arguments is validated for each asset type
with pytest.raises(ValidationError):
Expand Down

0 comments on commit 3708f5c

Please sign in to comment.