diff --git a/.github/workflows/dependent_prs.yaml b/.github/workflows/dependent_prs.yaml new file mode 100644 index 00000000..00724e60 --- /dev/null +++ b/.github/workflows/dependent_prs.yaml @@ -0,0 +1,14 @@ +name: Dependent/Blocking PRs + +on: + pull_request_target: + types: [opened, edited, closed, reopened] + +jobs: + check_dependencies: + runs-on: ubuntu-latest + name: Check Dependencies + steps: + - uses: gregsdennis/dependencies-action@main + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/pr_checks.yaml b/.github/workflows/pr_checks.yaml new file mode 100644 index 00000000..b6b5dcb1 --- /dev/null +++ b/.github/workflows/pr_checks.yaml @@ -0,0 +1,34 @@ +name: PR Checks + +on: [push] + +jobs: + pytest_ruff: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Set up Python 3.10 + uses: actions/setup-python@v4 + with: + python-version: "3.10" + + - name: Install test dependencies + run: | + pip install -r requirements/test.txt + continue-on-error: true + + - name: Lint with Ruff + run: | + ruff --output-format=github . + continue-on-error: false + + - name: Install full dependencies + run: | + pip install ".[test]" + continue-on-error: true + + - name: Run unit tests + run: | + pytest + continue-on-error: false diff --git a/.gitignore b/.gitignore index 68bc17f9..f845eb4a 100644 --- a/.gitignore +++ b/.gitignore @@ -158,3 +158,11 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ + +# Ruff +.ruff_cache + + +# ignore local wandb cache files. Not perfect +**/wandb/*.log +**/wandb/*run* diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..288c2d9b --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,12 @@ +{ + "python.analysis.importFormat": "absolute", + "[python]": { + "editor.defaultFormatter": "charliermarsh.ruff", + "editor.formatOnSave": true, + "editor.codeActionsOnSave": { + "source.fixAll": "never", + "source.organizeImports.ruff": "explicit" + } + }, + "python.testing.pytestEnabled": true +} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..e14aa8b4 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,101 @@ +[build-system] +requires = ["setuptools", "setuptools-scm"] +build-backend = "setuptools.build_meta" + +[project] +name = "flamingo" +version = "0.1.0" +description = "Ray-centric job library for training and evaluation" +readme = "README.md" +requires-python = ">=3.10,<3.11" + +dependencies = [ + "click==8.1.7", + "ray[default]==2.7.0", + "torch==2.1.0", + "scipy==1.10.1", + "wandb==0.16.1", + "pydantic-yaml==1.2.0", + "pydantic==1.10.8", +] + +[project.optional-dependencies] +finetune = [ + "datasets==2.15.0", + "transformers==4.36.2", + "accelerate==0.25.0", + "peft==0.7.1", + "trl==0.7.4", + "bitsandbytes==0.41.3", +] + +evaluate = ["lm-eval==0.4.0", "einops"] + +test = ["ruff==0.1.4", "pytest==7.4.3", "pytest-cov==4.1.0"] + +all = ["flamingo[finetune,evaluate,test]"] + + +[tool.pytest.ini_options] +addopts = "-v --cov src --no-cov-on-fail --disable-warnings" +testpaths = ["tests"] + +[tool.ruff] +target-version = "py310" +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", +] +line-length = 100 + + +[tool.ruff.lint] +select = [ + "E", # pycodestyle + "F", # pyflakes + "UP", # pyupgrade + "I", # import sorting + "N", # pep8 naming + "ISC", # flake8 implicit string concat + "PTH", # flake8-use-pathlib use Path library + "PD", # pandas-vet +] + +ignore = [ + "D417", # documentation for every function parameter. + "N806", # ignore uppercased variables + "N812", # import as uppercased + "N803", # lowercased args + "N817", # imported as acryonym + "B023", # doesn't bind loop var, we do this a lot in torch + "D100", # module-level docstrings + "N805", # first param needs to be self; pydantic breaks this sometimes +] + +# Avoid trying to fix some violations +unfixable = ["B", "SIM", "TRY", "RUF"] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +line-ending = "auto" +skip-magic-trailing-comma = false diff --git a/src/flamingo/__init__.py b/src/flamingo/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/flamingo/cli.py b/src/flamingo/cli.py new file mode 100644 index 00000000..80743421 --- /dev/null +++ b/src/flamingo/cli.py @@ -0,0 +1,14 @@ +import click + + +@click.group() +def main(): + pass + + +# need to add the group / command function itself, not the module +main.add_command(simple.driver) + + +if __name__ == "__main__": + main() diff --git a/src/flamingo/integrations/huggingface/__init__.py b/src/flamingo/integrations/huggingface/__init__.py new file mode 100644 index 00000000..f704e2c5 --- /dev/null +++ b/src/flamingo/integrations/huggingface/__init__.py @@ -0,0 +1,4 @@ +from .quantization_config import QuantizationConfig +from .utils import is_valid_huggingface_model_name + +__all__ = ["QuantizationConfig", "is_valid_huggingface_model_name"] diff --git a/src/flamingo/integrations/huggingface/quantization_config.py b/src/flamingo/integrations/huggingface/quantization_config.py new file mode 100644 index 00000000..8443c397 --- /dev/null +++ b/src/flamingo/integrations/huggingface/quantization_config.py @@ -0,0 +1,25 @@ +from flamingo.types import BaseFlamingoConfig, SerializableTorchDtype +from transformers import BitsAndBytesConfig + + +class QuantizationConfig(BaseFlamingoConfig): + """Basic quantization settings to pass to training and evaluation jobs. + + Note that in order to use BitsAndBytes quantization on Ray, + you must ensure that the runtime environment is installed with GPU support. + This can be configured by setting the `entrypoint_num_gpus > 0` when submitting a job + to the cluster, e.g., + """ + + load_in_8bit: bool | None = None + load_in_4bit: bool | None = None + bnb_4bit_quant_type: str = "fp4" + bnb_4bit_compute_dtype: SerializableTorchDtype = None + + def as_huggingface(self) -> BitsAndBytesConfig: + return BitsAndBytesConfig( + load_in_4bit=self.load_in_4bit, + load_in_8bit=self.load_in_8bit, + bnb_4bit_compute_dtype=self.bnb_4bit_compute_dtype, + bnb_4bit_quant_type=self.bnb_4bit_quant_type, + ) diff --git a/src/flamingo/integrations/huggingface/utils.py b/src/flamingo/integrations/huggingface/utils.py new file mode 100644 index 00000000..21d4ce54 --- /dev/null +++ b/src/flamingo/integrations/huggingface/utils.py @@ -0,0 +1,16 @@ +from huggingface_hub.utils import HFValidationError, validate_repo_id + + +def is_valid_huggingface_model_name(s: str): + """ + Simple test to check if an HF model is valid using HuggingFace's tools. + Sadly, theirs throws an exception and has no return. + + Args: + s: string to test. + """ + try: + validate_repo_id(s) + return True + except HFValidationError: + return False diff --git a/src/flamingo/integrations/wandb/__init__.py b/src/flamingo/integrations/wandb/__init__.py new file mode 100644 index 00000000..398b7666 --- /dev/null +++ b/src/flamingo/integrations/wandb/__init__.py @@ -0,0 +1,10 @@ +from .wandb_environment import WandbEnvironment # noqa: I001 +from .wandb_mixin import WandbEnvironmentMixin +from .utils import get_wandb_summary, update_wandb_summary + +__all__ = [ + "WandbEnvironment", + "WandbEnvironmentMixin", + "get_wandb_summary", + "update_wandb_summary", +] diff --git a/src/flamingo/integrations/wandb/utils.py b/src/flamingo/integrations/wandb/utils.py new file mode 100644 index 00000000..b7c270d4 --- /dev/null +++ b/src/flamingo/integrations/wandb/utils.py @@ -0,0 +1,41 @@ +from typing import Any + +from flamingo.integrations.wandb import WandbEnvironment + +import wandb +from wandb.apis.public import Run + + +def get_wandb_summary(env: WandbEnvironment) -> dict[str, Any]: + """Get the summary dictionary attached to a W&B run.""" + run = _resolve_wandb_run(env) + return dict(run.summary) + + +def update_wandb_summary(env: WandbEnvironment, metrics: dict[str, Any]) -> None: + """Update a run's summary with the provided metrics.""" + run = _resolve_wandb_run(env) + run.summary.update(metrics) + run.update() + + +def _resolve_wandb_run(env: WandbEnvironment) -> Run: + """Resolve a WandB run object from the provided environment settings. + + An exception is raised if a Run cannot be found, + or if multiple runs exist in scope with the same name. + """ + api = wandb.Api() + base_path = "/".join(x for x in (env.entity, env.project) if x) + if env.run_id is not None: + full_path = f"{base_path}/{env.run_id}" + return api.run(full_path) + else: + match [run for run in api.runs(base_path) if run.name == env.name]: + case []: + raise RuntimeError(f"No WandB runs found at {base_path}/{env.name}") + case [Run(), _]: + raise RuntimeError(f"Multiple WandB runs found at {base_path}/{env.name}") + case [Run()] as mr: + # we have a single one, hurray + return mr[0] diff --git a/src/flamingo/integrations/wandb/wandb_environment.py b/src/flamingo/integrations/wandb/wandb_environment.py new file mode 100644 index 00000000..3fcf2281 --- /dev/null +++ b/src/flamingo/integrations/wandb/wandb_environment.py @@ -0,0 +1,69 @@ +import os +import warnings + +from flamingo.types import BaseFlamingoConfig +from pydantic import Extra, root_validator + +from wandb.apis.public import Run + + +class WandbEnvironment(BaseFlamingoConfig): + """Settings required to log to a W&B run. + + The fields on this class map to the environment variables + that are used to control the W&B logging locations. + + The `name` and `project` are required as they are the minimum information + required to identify a run. The `name` is the human-readable name that appears in the W&B UI. + `name` is different than the `run_id` which must be unique within a project. + Although the `name` is not mandatorily unique, it is generally best practice to use a + unique and descriptive name to later identify the run. + """ + + class Config: + extra = Extra.forbid # Error on extra kwargs + + __match_args__ = ("name", "project", "run_id", "run_group", "entity") + + name: str + project: str + run_id: str | None = None + run_group: str | None = None + entity: str | None = None + + @root_validator(pre=True) + def warn_missing_api_key(cls, values): + if not os.environ.get("WANDB_API_KEY", None): + warnings.warn( + "Cannot find `WANDB_API_KEY` in your environment. " + "Tracking will fail if a default key does not exist on the Ray cluster." + ) + return values + + @property + def env_vars(self) -> dict[str, str]: + # WandB w/ HuggingFace is weird. You can specify the run name inline, + # but the rest must be injected as environment variables + env_vars = { + "WANDB_NAME": self.name, + "WANDB_PROJECT": self.project, + "WANDB_RUN_ID": self.run_id, + "WANDB_RUN_GROUP": self.run_group, + "WANDB_ENTITY": self.entity, + "WANDB_API_KEY": os.environ.get("WANDB_API_KEY", None), + } + return {k: v for k, v in env_vars.items() if v is not None} + + @classmethod + def from_run(cls, run: Run) -> "WandbEnvironment": + """Extract environment settings from a W&B Run object. + + Useful when listing runs from the W&B API and extracting their settings for a job. + """ + # TODO: Can we get the run group from this when it exists? + return cls( + name=run.name, + project=run.project, + entity=run.entity, + run_id=run.id, + ) diff --git a/src/flamingo/integrations/wandb/wandb_mixin.py b/src/flamingo/integrations/wandb/wandb_mixin.py new file mode 100644 index 00000000..56266b81 --- /dev/null +++ b/src/flamingo/integrations/wandb/wandb_mixin.py @@ -0,0 +1,22 @@ +from flamingo.integrations.wandb import WandbEnvironment +from flamingo.types import BaseFlamingoConfig + + +class WandbEnvironmentMixin(BaseFlamingoConfig): + """Mixin for a config that contains W&B environment settings.""" + + wandb_env: WandbEnvironment | None = None + + @property + def env_vars(self) -> dict[str, str]: + return self.wandb_env.env_vars if self.wandb_env else {} + + @property + def wandb_name(self) -> str | None: + """Return the W&B run name, if it exists.""" + return self.wandb_env.name if self.wandb_env else None + + @property + def wandb_project(self) -> str | None: + """Return the W&B project name, if it exists.""" + return self.wandb_env.project if self.wandb_env else None diff --git a/src/flamingo/jobs/__init__.py b/src/flamingo/jobs/__init__.py new file mode 100644 index 00000000..427b7422 --- /dev/null +++ b/src/flamingo/jobs/__init__.py @@ -0,0 +1,11 @@ +from .base_config import BaseJobConfig +from .evaluation_config import EvaluationJobConfig +from .finetuning_config import FinetuningJobConfig +from .simple_config import SimpleJobConfig + +__all__ = [ + "BaseJobConfig", + "SimpleJobConfig", + "FinetuningJobConfig", + "EvaluationJobConfig", +] diff --git a/src/flamingo/jobs/base_config.py b/src/flamingo/jobs/base_config.py new file mode 100644 index 00000000..3fec2edc --- /dev/null +++ b/src/flamingo/jobs/base_config.py @@ -0,0 +1,7 @@ +from flamingo.types import BaseFlamingoConfig + + +class BaseJobConfig(BaseFlamingoConfig): + """Configuration defining a job to submit to the Ray cluster.""" + + pass diff --git a/src/flamingo/jobs/evaluation_config.py b/src/flamingo/jobs/evaluation_config.py new file mode 100644 index 00000000..a2d8c85b --- /dev/null +++ b/src/flamingo/jobs/evaluation_config.py @@ -0,0 +1,118 @@ +import datetime +from dataclasses import InitVar +from pathlib import Path +from typing import Any + +from pydantic import root_validator, validator +from pydantic.dataclasses import dataclass + +from flamingo.integrations.huggingface import QuantizationConfig +from flamingo.integrations.huggingface.utils import is_valid_huggingface_model_name +from flamingo.integrations.wandb import WandbEnvironment, WandbEnvironmentMixin +from flamingo.jobs import BaseJobConfig +from flamingo.types import SerializableTorchDtype + + +@dataclass +class ModelNameOrCheckpointPath: + """ + This class is explicitly used to validate if a string is + a valid HuggingFace model or can be used as a checkpoint. + + Checkpoint will be automatically assigned if it's a valid checkpoint; + it will be None if it's not valid. + """ + + # explictly needed for matching + __match_args__ = ("name", "checkpoint") + + name: str + checkpoint: InitVar[str | None] = None + + def __post_init__(self, checkpoint): + if isinstance(self.name, Path): + self.name = str(self.name) + + if Path(self.name).is_absolute(): + self.checkpoint = self.name + else: + self.checkpoint = None + + if self.checkpoint is None and not is_valid_huggingface_model_name(self.name): + raise (ValueError(f"{self.name} is not a valid checkpoint path or HF model name")) + + +class EvaluationJobConfig(WandbEnvironmentMixin, BaseJobConfig): + """Configuration to run an lm-evaluation-harness evaluation job. + + This job loads an existing checkpoint path + from Ray storage to run evaluation against, OR a huggingface Model. + and logs the evaluation results to W&B. + When a W&B config is specified, the job will attempt to resolve a checkpoint path + associated with that run. + + This can be manually overwritten by specifying the `model_name_or_path` variable + which will take prescedence over the W&B checkpoint path. + """ + + class Config: + validate_assignment = True + + tasks: list[str] + model_name_or_path: str | Path | ModelNameOrCheckpointPath | None = None + batch_size: int | None = None + num_fewshot: int | None = None + limit: int | float | None = None + trust_remote_code: bool = False + torch_dtype: SerializableTorchDtype = None + quantization_config: QuantizationConfig | None = None + num_cpus: int = 1 + num_gpus: int = 1 + timeout: datetime.timedelta | None = None + + @validator("model_name_or_path", pre=True, always=True) + def _validate_modelname(cls, v): # noqa: N805 + """ + happens pre-validation and makes sure we correctly set a typed var for matching. + """ + if isinstance(v, dict): + return ModelNameOrCheckpointPath(**v) + elif v is None: + return None + else: + return ModelNameOrCheckpointPath(name=v) + + @root_validator(pre=False) + @classmethod + def _validate_modelname_or_checkpoint(cls, values) -> Any: + """ + Primarily logic to infer if a passed value is a HuggingFace model, + checkpoint for resuming, or not. + """ + mnp = values.get("model_name_or_path") + wandb_env = values.get("wandb_env") + + # fairly complex logic here: + # we're matching on the structure of the passed args + + match (mnp, wandb_env): + case (None, None): + raise (ValueError("Either `model_name_or_path` or `wandb_env` must be provided.")) + + case (None, WandbEnvironment()): + print( + "no model name or checkpoint passed; will attempt to run from passed wandb run." + ) + case (ModelNameOrCheckpointPath() as x, _): + print( + "will ignore passed information from a wandb run, " + f"if present, and prefer loading from: {x.name}" + ) + case _: + raise (ValueError(f"{mnp} is not a valid HuggingFaceModel or checkpoint path.")) + + return values + + @property + def entrypoint_command(self) -> str: + return f"python run_evaluation.py --config_json '{self.json()}'" diff --git a/src/flamingo/jobs/finetuning_config.py b/src/flamingo/jobs/finetuning_config.py new file mode 100644 index 00000000..51a733e4 --- /dev/null +++ b/src/flamingo/jobs/finetuning_config.py @@ -0,0 +1,51 @@ +from peft import LoraConfig +from pydantic import validator +from ray.train import ScalingConfig + +from flamingo.integrations.huggingface import QuantizationConfig +from flamingo.integrations.huggingface.utils import is_valid_huggingface_model_name +from flamingo.integrations.wandb import WandbEnvironmentMixin +from flamingo.jobs import BaseJobConfig +from flamingo.types import SerializableTorchDtype + + +class FinetuningJobConfig(WandbEnvironmentMixin, BaseJobConfig): + """Configuration to submit an LLM finetuning job.""" + + model: str + dataset: str + tokenizer: str | None = None + # Training + max_seq_length: int | None = None + num_train_epochs: int = 1 + batch_size: int = 16 + learning_rate: float = 1e-5 + weight_decay: float = 1e-3 + gradient_accumulation_steps: int = 1 + gradient_checkpointing: bool = False + trust_remote_code: bool = False + torch_dtype: SerializableTorchDtype = None + # Logging + evaluation_strategy: str = "epoch" + eval_steps: float | None = None + logging_strategy: str = "steps" + logging_steps: float = 100 + save_strategy: str = "steps" + save_steps: int = 500 + # Lora/quantization + lora_config: LoraConfig | None = None # TODO: Create our own config type + quantization_config: QuantizationConfig | None = None + # Cluster + storage_path: str | None = None + scaling_config: ScalingConfig | None = None # TODO: Create our own config type + + @validator("model") + def _validate_modelname(cls, v): # noqa: N805 + if is_valid_huggingface_model_name(v): + return v + else: + raise (ValueError(f"`{v}` is not a valid HuggingFace model name.")) + + @property + def entrypoint_command(self) -> str: + return f"python run_finetuning.py --config_json '{self.json()}'" diff --git a/src/flamingo/jobs/simple_config.py b/src/flamingo/jobs/simple_config.py new file mode 100644 index 00000000..04cb270b --- /dev/null +++ b/src/flamingo/jobs/simple_config.py @@ -0,0 +1,15 @@ +from flamingo.jobs import BaseJobConfig + + +class SimpleJobConfig(BaseJobConfig): + """A simple job to demonstrate the submission interface.""" + + magic_number: int + + @property + def env_vars(self) -> dict[str, str]: + return {} + + @property + def entrypoint_command(self) -> str: + return f"python simple.py --magic_number '{self.magic_number}'" diff --git a/src/flamingo/types.py b/src/flamingo/types.py new file mode 100644 index 00000000..9163a1e9 --- /dev/null +++ b/src/flamingo/types.py @@ -0,0 +1,46 @@ +from pathlib import Path +from typing import Any + +import torch +from pydantic import BaseModel, Extra, validator +from pydantic.fields import ModelField +from pydantic_yaml import parse_yaml_file_as, to_yaml_file + +SerializableTorchDtype = str | torch.dtype | None + + +class BaseFlamingoConfig(BaseModel): + """Base class for all Pydnatic configs in the library. + + Defines some common settings used by all subclasses. + """ + + class Config: + extra = Extra.forbid + arbitrary_types_allowed = True + validate_assignment = True + json_encoders = { + # Default JSON encoding of a torch.dtype object + # Defining here allows it to be inherited by all sub-classes of BaseFlamingoConfig + torch.dtype: lambda x: str(x).split(".")[1], + } + + @validator("*", pre=True) + def validate_serializable_dtype(cls, x: Any, field: ModelField) -> Any: # noqa: N805 + """Extract the torch.dtype corresponding to a string value, else return the value. + + This is a Pydantic-specific construct that is run on all fields + before other default validations. + + Inspired by the HuggingFace `BitsAndBytesConfig` logic. + """ + if field.type_ is SerializableTorchDtype and isinstance(x, str): + return getattr(torch, x) + return x + + @classmethod + def from_yaml_file(cls, path: Path | str): + return parse_yaml_file_as(cls, path) + + def to_yaml_file(self, path: Path | str): + to_yaml_file(path, self, exclude_none=True) diff --git a/src/flamingo/utils.py b/src/flamingo/utils.py new file mode 100644 index 00000000..089c8fd9 --- /dev/null +++ b/src/flamingo/utils.py @@ -0,0 +1,11 @@ +import os +from datetime import datetime + +__all__ = ["get_default_run_name"] + + +def get_default_run_name(base_name="dummy_run"): + if user := os.getenv("USER", None): + return f"{user}_{base_name}_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" + else: + return f"{base_name}_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}" diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..5beb3e73 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,155 @@ +""" +Tests for the LLM tuner. This file is used to provide fixtures for the test session +that are accessible to all submodules. +""" +import dataclasses +import os +from collections.abc import Generator +from unittest import mock + +import pytest +import tuner +import wandb +from tuner.integrations.wandb import WandbEnvironment +from tuner.jobs.evaluation_config import EvaluationJobConfig +from wandb.sdk.lib.runid import generate_id + + +@dataclasses.dataclass +class Project: + name: str + + +@pytest.fixture(autouse=True, scope="function") +def mock_wandb_api_call(): + with mock.patch( + "tuner.integrations.wandb.utils.wandb.Api.projects", return_value=[Project("my-project")] + ) as p: + yield p + + +### from wandb +def dict_factory(): + def helper(): + return dict() + + return helper + + +@pytest.fixture(scope="function") +def test_settings(): + """ + taken from wandb to generate a test Settings instance. + """ + + def update_test_settings( + extra_settings: dict | wandb.sdk.wandb_settings.Settings = dict_factory(), # noqa: B008 + ): + settings = wandb.Settings( + console="off", + save_code=False, + ) + if isinstance(extra_settings, dict): + settings.update(extra_settings, source=wandb.sdk.wandb_settings.Source.BASE) + elif isinstance(extra_settings, wandb.sdk.wandb_settings.Settings): + settings.update(extra_settings) + settings._set_run_start_time() + return settings + + yield update_test_settings + + +@pytest.fixture(scope="function") +def mock_run(test_settings): + """ + taken from wandb to generate a test Run instance. + """ + from wandb.sdk.lib.module import unset_globals + + def mock_run_fn(**kwargs) -> "wandb.sdk.wandb_run.Run": + kwargs_settings = kwargs.pop("settings", dict()) + kwargs_settings = { + **{ + "run_id": generate_id(), + }, + **kwargs_settings, + } + run = wandb.wandb_sdk.wandb_run.Run(settings=test_settings(kwargs_settings), **kwargs) + run._set_backend(mock.MagicMock()) + run._set_globals() + return run + + yield mock_run_fn + unset_globals() + + +@pytest.fixture(autouse=True, scope="function") +def mock_valid_run(mock_run): + """ + taken from wandb to generate a valid run. + """ + run = mock_run() + with mock.patch("tuner.integrations.wandb.utils._resolve_wandb_run", return_value=run) as r: + yield r + + +@pytest.fixture(autouse=True, scope="function") +def mock_environment_with_keys(): + """Mocks an API key-like mechanism for the environment.""" + with mock.patch.dict(os.environ, {"WANDB_API_KEY": "abcdefg123"}): + yield + + +@pytest.fixture(autouse=True, scope="function") +def mock_environment_without_keys(): + """Mocks an environment missing common API keys.""" + with mock.patch.dict(os.environ, clear=True): + yield + + +@pytest.fixture(scope="function") +def mock_wandb_env(mock_run, test_settings) -> Generator[WandbEnvironment, None, None]: + """ + Sets up a mock wandb_env object. + """ + + def mock_env_func(**kwargs) -> "tuner.integrations.wandb.WandbEnvironment": + mine = { + "name": "my-run", + "project": "my-project", + "entity": "mozilla-ai", + "run_id": "gabbagool-123", + } + kwargs = {**mine, **kwargs} + return WandbEnvironment(**kwargs) + + yield mock_env_func + + +@pytest.fixture(scope="session") +def checkpoint_path(tmp_path_factory): + "makes a mocked checkpoint path dir." + fn = tmp_path_factory.mktemp("data") / "model" + return fn + + +@pytest.fixture(scope="function") +def default_eval_config(mock_run, test_settings): + """ + Sets up a default + """ + + def default_eval_config(**kwargs) -> "tuner.integrations.wandb.WandbEnvironment": + mine = { + "tasks": ["task1", "task2"], + "model_name_or_path": None, + "wandb_env": None, + "num_fewshot": 5, + "batch_size": 16, + "torch_dtype": "bfloat16", + "quantization_config": None, + "timeout": 3600, + } + return EvaluationJobConfig(**{**mine, **kwargs}) + + yield default_eval_config diff --git a/tests/integrations/__init__.py b/tests/integrations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integrations/huggingface/__init__.py b/tests/integrations/huggingface/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integrations/huggingface/test_quantization_config.py b/tests/integrations/huggingface/test_quantization_config.py new file mode 100644 index 00000000..c52a4816 --- /dev/null +++ b/tests/integrations/huggingface/test_quantization_config.py @@ -0,0 +1,16 @@ +import pytest +import torch +from tuner.integrations.huggingface import QuantizationConfig + + +@pytest.fixture +def quantization_config() -> QuantizationConfig: + return QuantizationConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_quant_type="nf4", + ) + + +def test_serde_round_trip(quantization_config: QuantizationConfig): + assert QuantizationConfig.parse_raw(quantization_config.json()) == quantization_config diff --git a/tests/integrations/wandb/__init__.py b/tests/integrations/wandb/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integrations/wandb/test_wandb_environment.py b/tests/integrations/wandb/test_wandb_environment.py new file mode 100644 index 00000000..c347d955 --- /dev/null +++ b/tests/integrations/wandb/test_wandb_environment.py @@ -0,0 +1,26 @@ +import pytest +from pydantic import ValidationError +from tuner.integrations.wandb import WandbEnvironment + + +def test_env_vars(mock_wandb_env): + env_vars = mock_wandb_env().env_vars + expected = ["WANDB_NAME", "WANDB_PROJECT", "WANDB_ENTITY", "WANDB_RUN_ID"] + for key in expected: + assert key in env_vars + assert "WANDB_RUN_GROUP" not in env_vars + + +def test_serde_round_trip(mock_wandb_env): + assert WandbEnvironment.parse_raw(mock_wandb_env().json()) == mock_wandb_env() + + +def test_disallowed_kwargs(): + with pytest.raises(ValidationError): + WandbEnvironment(name="name", project="project", old_name="I will throw") + + +def test_missing_key_warning(mock_environment_without_keys): + with pytest.warns(UserWarning): + env = WandbEnvironment(name="I am missing an API key", project="I should warn the user") + assert "WANDB_API_KEY" not in env.env_vars diff --git a/tests/jobs/__init__.py b/tests/jobs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/jobs/test_evaluation_config.py b/tests/jobs/test_evaluation_config.py new file mode 100644 index 00000000..debd8eb7 --- /dev/null +++ b/tests/jobs/test_evaluation_config.py @@ -0,0 +1,81 @@ +from pathlib import Path + +import pytest +from pydantic import ValidationError +from tuner.jobs.evaluation_config import EvaluationJobConfig + + +def test_bad_hf_name(default_eval_config): + with pytest.raises(ValidationError): + default_eval_config( + model_name_or_path="dfa../invalid", + ) + + +def test_serde_round_trip_path_notmp(mock_wandb_env, default_eval_config): + config = default_eval_config( + model_name_or_path=Path("test"), + wandb_env=None, + ) + assert EvaluationJobConfig.parse_raw(config.json()) == config + + +def test_serde_round_trip_checkpointpath_only(mock_wandb_env, checkpoint_path, default_eval_config): + config = default_eval_config( + model_name_or_path=checkpoint_path, + wandb_env=None, + ) + assert EvaluationJobConfig.parse_raw(config.json()) == config + + +def test_serde_round_trip_wandb_and_path(mock_wandb_env, checkpoint_path, default_eval_config): + config = default_eval_config( + model_name_or_path=checkpoint_path, + wandb_env=mock_wandb_env(), + ) + assert EvaluationJobConfig.parse_raw(config.json()) == config + + +def test_serde_round_trip_wandb(mock_wandb_env, default_eval_config): + config = default_eval_config( + model_name_or_path=None, + wandb_env=mock_wandb_env(), + ) + assert EvaluationJobConfig.parse_raw(config.json()) == config + + +def test_serde_round_trip_with_checkpoint_path_no_wandb(default_eval_config): + config = default_eval_config( + model_name_or_path="/fake_path/to/a/file", + wandb_env=None, + ) + assert EvaluationJobConfig.parse_raw(config.json()) == config + + +def test_serde_round_trip_no_wandb(default_eval_config): + config = default_eval_config( + model_name_or_path="some/model", + wandb_env=None, + ) + assert EvaluationJobConfig.parse_raw(config.json()) == config + + +def test_model_validation(): + with pytest.raises(ValidationError): + # Neither checkpoint_path or wandb_env specified + EvaluationJobConfig(tasks=["task1", "task2"], num_fewshot=5, batch_size=16) + + +def test_serde_round_trip_default_config(default_eval_config): + config = default_eval_config( + model_name_or_path="fake_path", + wandb_env=None, + ) + assert EvaluationJobConfig.parse_raw(config.json()) == config + + +def test_parse_from_yaml(default_eval_config, tmp_path_factory): + config = default_eval_config(model_name_or_path="not_a_real_model") + p = tmp_path_factory.mktemp("test_yaml") / "eval.yaml" + config.to_yaml_file(p) + assert config == EvaluationJobConfig.from_yaml_file(p) diff --git a/tests/jobs/test_finetuning_config.py b/tests/jobs/test_finetuning_config.py new file mode 100644 index 00000000..bdbf68ba --- /dev/null +++ b/tests/jobs/test_finetuning_config.py @@ -0,0 +1,24 @@ +from peft import LoraConfig +from ray.train import ScalingConfig +from tuner.datasets.dataset_choice import DatasetChoice +from tuner.integrations.huggingface import QuantizationConfig +from tuner.integrations.wandb import WandbEnvironment +from tuner.jobs import FinetuningJobConfig + + +def test_serde_round_trip(): + wandb_env = WandbEnvironment(name="my-run", project="my-project") + lora_config = LoraConfig(r=16, lora_alpha=32, task_type="CAUSAL_LM") + quantization_config = QuantizationConfig(load_in_8bit=True) + scaling_config = ScalingConfig(num_workers=2, use_gpu=True) + config = FinetuningJobConfig( + model="test-model", + dataset=DatasetChoice.Dolly, + torch_dtype="bfloat16", + wandb_env=wandb_env, + lora_config=lora_config, + quantization_config=quantization_config, + scaling_config=scaling_config, + storage_path="/mnt/data/ray_results", + ) + assert FinetuningJobConfig.parse_raw(config.json()) == config