From d726bfcb1507b97075e82a2cfc805fb6cb4f3063 Mon Sep 17 00:00:00 2001
From: Sean Friedowitz
Date: Tue, 16 Jan 2024 18:56:15 -0800
Subject: [PATCH] copy over new config classes

---
 .../huggingface/dataset_config.py            | 16 ++++
 .../integrations/huggingface/model_config.py | 19 +++++
 .../huggingface/model_name_or_path.py        | 34 ---------
 .../huggingface/tokenizer_config.py          | 25 +++++++
 .../huggingface/trainer_config.py            | 54 +++++++++----
 .../integrations/huggingface/utils.py        | 26 ++++++-
 src/flamingo/integrations/wandb/__init__.py  |  6 +-
 .../integrations/wandb/artifact_link.py      | 17 +++++
 .../{wandb_environment.py => run_link.py}    | 63 +++++++++-------
 src/flamingo/integrations/wandb/utils.py     | 38 +++-------
 src/flamingo/jobs/base_config.py             |  7 --
 src/flamingo/jobs/finetuning_config.py       | 75 ++++++++++++-----
 src/flamingo/jobs/lm_harness_config.py       | 49 +++++-------
 src/flamingo/jobs/simple_config.py           |  4 +-
 tests/conftest.py                            |  6 +-
 .../wandb/test_wandb_environment.py          |  8 +-
 tests/{configs => jobs}/__init__.py          |  0
 .../test_finetuning_config.py                |  0
 .../test_lm_harness_config.py                |  0
 19 files changed, 275 insertions(+), 172 deletions(-)
 create mode 100644 src/flamingo/integrations/huggingface/dataset_config.py
 create mode 100644 src/flamingo/integrations/huggingface/model_config.py
 delete mode 100644 src/flamingo/integrations/huggingface/model_name_or_path.py
 create mode 100644 src/flamingo/integrations/huggingface/tokenizer_config.py
 create mode 100644 src/flamingo/integrations/wandb/artifact_link.py
 rename src/flamingo/integrations/wandb/{wandb_environment.py => run_link.py} (54%)
 delete mode 100644 src/flamingo/jobs/base_config.py
 rename tests/{configs => jobs}/__init__.py (100%)
 rename tests/{configs => jobs}/test_finetuning_config.py (100%)
 rename tests/{configs => jobs}/test_lm_harness_config.py (100%)

diff --git a/src/flamingo/integrations/huggingface/dataset_config.py b/src/flamingo/integrations/huggingface/dataset_config.py
new file mode 100644
index 00000000..e31014eb
--- /dev/null
+++ b/src/flamingo/integrations/huggingface/dataset_config.py
@@ -0,0 +1,16 @@
+from pydantic import validator
+
+from flamingo.integrations.huggingface.utils import repo_id_validator
+from flamingo.integrations.wandb import WandbArtifactLink
+from flamingo.types import BaseFlamingoConfig
+
+
+class DatasetConfig(BaseFlamingoConfig):
+    """Settings passed to load a HuggingFace dataset."""
+
+    path: str | WandbArtifactLink
+    split: str | None = None
+    test_size: float | None = None
+    seed: int | None = None
+
+    _path_validator = validator("path", allow_reuse=True, pre=True)(repo_id_validator)
diff --git a/src/flamingo/integrations/huggingface/model_config.py b/src/flamingo/integrations/huggingface/model_config.py
new file mode 100644
index 00000000..60f54077
--- /dev/null
+++ b/src/flamingo/integrations/huggingface/model_config.py
@@ -0,0 +1,19 @@
+from pydantic import validator
+
+from flamingo.integrations.huggingface.utils import repo_id_validator
+from flamingo.integrations.wandb import WandbArtifactLink
+from flamingo.types import BaseFlamingoConfig, SerializableTorchDtype
+
+
+class AutoModelConfig(BaseFlamingoConfig):
+    """Settings passed to a HuggingFace AutoModel instantiation.
+
+    The model path can either be a string corresponding to a HuggingFace repo ID,
+    or an artifact link to a reference artifact on W&B.
+    """
+
+    path: str | WandbArtifactLink
+    trust_remote_code: bool = False
+    torch_dtype: SerializableTorchDtype = None
+
+    _path_validator = validator("path", allow_reuse=True, pre=True)(repo_id_validator)
diff --git a/src/flamingo/integrations/huggingface/model_name_or_path.py b/src/flamingo/integrations/huggingface/model_name_or_path.py
deleted file mode 100644
index f51a69dd..00000000
--- a/src/flamingo/integrations/huggingface/model_name_or_path.py
+++ /dev/null
@@ -1,34 +0,0 @@
-from pathlib import Path
-
-from pydantic.dataclasses import dataclass
-
-from flamingo.integrations.huggingface.utils import is_valid_huggingface_model_name
-
-
-@dataclass
-class ModelNameOrCheckpointPath:
-    """
-    This class is explicitly used to validate if a string is
-    a valid HuggingFace model or can be used as a checkpoint.
-
-    Checkpoint will be automatically assigned if it's a valid checkpoint;
-    it will be None if it's not valid.
-    """
-
-    # explictly needed for matching
-    __match_args__ = ("name", "checkpoint")
-
-    name: str
-    checkpoint: str | None = None
-
-    def __post_init__(self):
-        if isinstance(self.name, Path):
-            self.name = str(self.name)
-
-        if Path(self.name).is_absolute():
-            self.checkpoint = self.name
-        else:
-            self.checkpoint = None
-
-        if self.checkpoint is None and not is_valid_huggingface_model_name(self.name):
-            raise ValueError(f"{self.name} is not a valid checkpoint path or HF model name.")
diff --git a/src/flamingo/integrations/huggingface/tokenizer_config.py b/src/flamingo/integrations/huggingface/tokenizer_config.py
new file mode 100644
index 00000000..3f16b342
--- /dev/null
+++ b/src/flamingo/integrations/huggingface/tokenizer_config.py
@@ -0,0 +1,25 @@
+from typing import Any
+
+from pydantic import validator
+
+from flamingo.integrations.huggingface.utils import repo_id_validator
+from flamingo.integrations.wandb import WandbArtifactLink
+from flamingo.types import BaseFlamingoConfig
+
+
+class AutoTokenizerConfig(BaseFlamingoConfig):
+    """Settings passed to a HuggingFace AutoTokenizer instantiation."""
+
+    path: str | WandbArtifactLink
+    trust_remote_code: bool | None = None
+    use_fast: bool | None = None
+
+    _path_validator = validator("path", allow_reuse=True, pre=True)(repo_id_validator)
+
+    def get_tokenizer_args(self) -> dict[str, Any]:
+        args = dict(
+            trust_remote_code=self.trust_remote_code,
+            use_fast=self.use_fast,
+        )
+        # Only return non-None values so we get HuggingFace defaults when not specified
+        return {k: v for k, v in args.items() if v is not None}
diff --git a/src/flamingo/integrations/huggingface/trainer_config.py b/src/flamingo/integrations/huggingface/trainer_config.py
index e78e1391..e3cf4b64 100644
--- a/src/flamingo/integrations/huggingface/trainer_config.py
+++ b/src/flamingo/integrations/huggingface/trainer_config.py
@@ -1,21 +1,45 @@
-from flamingo.types import BaseFlamingoConfig, SerializableTorchDtype
+from typing import Any
+
+from flamingo.types import BaseFlamingoConfig


 class TrainerConfig(BaseFlamingoConfig):
-    """Configuration for a HuggingFace trainer/training arguments."""
+    """Configuration for a HuggingFace trainer/training arguments.
+
+    This mainly encompasses arguments passed to the HuggingFace `TrainingArguments` class,
+    but also contains some additional parameters for the `Trainer` or `SFTTrainer` classes.
+    """

     max_seq_length: int | None = None
-    num_train_epochs: int = 1
-    batch_size: int = 16
-    learning_rate: float = 1e-5
-    weight_decay: float = 1e-3
-    gradient_accumulation_steps: int = 1
-    gradient_checkpointing: bool = False
-    trust_remote_code: bool = False
-    torch_dtype: SerializableTorchDtype = None
-    evaluation_strategy: str = "epoch"
+    num_train_epochs: int | None = None
+    per_device_train_batch_size: int | None = None
+    per_device_eval_batch_size: int | None = None
+    learning_rate: float | None = None
+    weight_decay: float | None = None
+    gradient_accumulation_steps: int | None = None
+    gradient_checkpointing: bool | None = None
+    evaluation_strategy: str | None = None
     eval_steps: float | None = None
-    logging_strategy: str = "steps"
-    logging_steps: float = 100
-    save_strategy: str = "steps"
-    save_steps: int = 500
+    logging_strategy: str | None = None
+    logging_steps: float | None = None
+    save_strategy: str | None = None
+    save_steps: int | None = None
+
+    def get_training_args(self) -> dict[str, Any]:
+        args = dict(
+            num_train_epochs=self.num_train_epochs,
+            learning_rate=self.learning_rate,
+            per_device_train_batch_size=self.per_device_train_batch_size,
+            per_device_eval_batch_size=self.per_device_eval_batch_size,
+            gradient_accumulation_steps=self.gradient_accumulation_steps,
+            gradient_checkpointing=self.gradient_checkpointing,
+            weight_decay=self.weight_decay,
+            evaluation_strategy=self.evaluation_strategy,
+            eval_steps=self.eval_steps,
+            logging_strategy=self.logging_strategy,
+            logging_steps=self.logging_steps,
+            save_strategy=self.save_strategy,
+            save_steps=self.save_steps,
+        )
+        # Only return non-None values so we use the HuggingFace defaults when not specified
+        return {k: v for k, v in args.items() if v is not None}
diff --git a/src/flamingo/integrations/huggingface/utils.py b/src/flamingo/integrations/huggingface/utils.py
index 21d4ce54..7622e3c8 100644
--- a/src/flamingo/integrations/huggingface/utils.py
+++ b/src/flamingo/integrations/huggingface/utils.py
@@ -1,7 +1,16 @@
+from typing import Any
+
+from datasets import DatasetDict, load_dataset
 from huggingface_hub.utils import HFValidationError, validate_repo_id


-def is_valid_huggingface_model_name(s: str):
+def repo_id_validator(x: Any):
+    if isinstance(x, str) and not is_valid_huggingface_repo_id(x):
+        raise ValueError(f"{x} is not a valid HuggingFace repo ID.")
+    return x
+
+
+def is_valid_huggingface_repo_id(s: str):
     """
-    Simple test to check if an HF model is valid using HuggingFace's tools.
+    Simple test to check if an HF repo ID is valid using HuggingFace's tools.
     Sadly, theirs throws an exception and has no return.
@@ -14,3 +23,18 @@ def is_valid_huggingface_model_name(s: str):
         return True
     except HFValidationError:
         return False
+
+
+def load_and_split_dataset(
+    path: str,
+    *,
+    split: str | None = None,
+    test_size: float | None = None,
+    seed: int | None = None,
+) -> DatasetDict:
+    dataset = load_dataset(path, split=split)
+    if test_size is not None:
+        datasets = dataset.train_test_split(test_size=test_size, seed=seed)
+    else:
+        datasets = DatasetDict({"train": dataset})
+    return datasets
diff --git a/src/flamingo/integrations/wandb/__init__.py b/src/flamingo/integrations/wandb/__init__.py
index 8e7c8a52..44b455d0 100644
--- a/src/flamingo/integrations/wandb/__init__.py
+++ b/src/flamingo/integrations/wandb/__init__.py
@@ -1,8 +1,10 @@
-from .wandb_environment import WandbEnvironment  # noqa: I001
+from .artifact_link import WandbArtifactLink
+from .run_link import WandbRunLink
 from .utils import get_wandb_summary, update_wandb_summary

 __all__ = [
-    "WandbEnvironment",
+    "WandbArtifactLink",
+    "WandbRunLink",
     "get_wandb_summary",
     "update_wandb_summary",
 ]
diff --git a/src/flamingo/integrations/wandb/artifact_link.py b/src/flamingo/integrations/wandb/artifact_link.py
new file mode 100644
index 00000000..1f3c9d53
--- /dev/null
+++ b/src/flamingo/integrations/wandb/artifact_link.py
@@ -0,0 +1,17 @@
+from flamingo.types import BaseFlamingoConfig
+
+
+class WandbArtifactLink(BaseFlamingoConfig):
+    """Data required to retrieve an artifact from W&B."""
+
+    name: str
+    version: str = "latest"
+    project: str | None = None
+    entity: str | None = None
+
+    @property
+    def wandb_path(self) -> str:
+        """String identifier for the asset on the W&B platform."""
+        path = "/".join(x for x in [self.entity, self.project, self.name] if x is not None)
+        path = f"{path}:{self.version}"
+        return path
diff --git a/src/flamingo/integrations/wandb/wandb_environment.py b/src/flamingo/integrations/wandb/run_link.py
similarity index 54%
rename from src/flamingo/integrations/wandb/wandb_environment.py
rename to src/flamingo/integrations/wandb/run_link.py
index 24d3eedc..c6d168c7 100644
--- a/src/flamingo/integrations/wandb/wandb_environment.py
+++ b/src/flamingo/integrations/wandb/run_link.py
@@ -1,33 +1,30 @@
 import os
 import warnings

-from pydantic import Extra, root_validator
+from pydantic import root_validator
 from wandb.apis.public import Run
+from wandb.util import random_string

 from flamingo.types import BaseFlamingoConfig


-class WandbEnvironment(BaseFlamingoConfig):
+class WandbRunLink(BaseFlamingoConfig):
     """Settings required to log to a W&B run.

-    The fields on this class map to the environment variables
-    that are used to control the W&B logging locations.
+    A W&B Run is uniquely identified by the combination of `entity/project/run_id`.
+    The W&B platform will auto-generate values for these variables if they are not provided.

-    The `name` and `project` are required as they are the minimum information
-    required to identify a run. The `name` is the human-readable name that appears in the W&B UI.
-    `name` is different than the `run_id` which must be unique within a project.
-    Although the `name` is not mandatorily unique, it is generally best practice to use a
-    unique and descriptive name to later identify the run.
+    However, because these attributes are passed between jobs, it is often necessary
+    to know the run ID before a run is initialized.
+    For this reason, the `run_id` field is required, and is auto-generated locally
+    when one is not provided.
""" - class Config: - extra = Extra.forbid # Error on extra kwargs - - __match_args__ = ("name", "project", "run_id", "run_group", "entity") + __match_args__ = ("run_id", "name", "project", "run_group", "entity") + run_id: str name: str | None = None project: str | None = None - run_id: str | None = None run_group: str | None = None entity: str | None = None @@ -40,22 +37,15 @@ def warn_missing_api_key(cls, values): ) return values - @property - def env_vars(self) -> dict[str, str]: - # WandB w/ HuggingFace is weird. You can specify the run name inline, - # but the rest must be injected as environment variables - env_vars = { - "WANDB_NAME": self.name, - "WANDB_PROJECT": self.project, - "WANDB_RUN_ID": self.run_id, - "WANDB_RUN_GROUP": self.run_group, - "WANDB_ENTITY": self.entity, - "WANDB_API_KEY": os.environ.get("WANDB_API_KEY", None), - } - return {k: v for k, v in env_vars.items() if v is not None} + @root_validator(pre=True) + def ensure_run_id(cls, values): + if values.get("run_id", None) is None: + # Generate an random 8-digit alphanumeric string, analogous to W&B platform + values["run_id"] = random_string(length=8) + return values @classmethod - def from_run(cls, run: Run) -> "WandbEnvironment": + def from_run(cls, run: Run) -> "WandbRunLink": """Extract environment settings from a W&B Run object. Useful when listing runs from the W&B API and extracting their settings for a job. @@ -67,3 +57,20 @@ def from_run(cls, run: Run) -> "WandbEnvironment": entity=run.entity, run_id=run.id, ) + + @property + def wandb_path(self) -> str: + """String identifier for the asset on the W&B platform.""" + path = "/".join(x for x in [self.entity, self.project, self.run_id] if x is not None) + return path + + def get_env_vars(self) -> dict[str, str]: + env_vars = { + "WANDB_RUN_ID": self.run_id, + "WANDB_NAME": self.name, + "WANDB_PROJECT": self.project, + "WANDB_RUN_GROUP": self.run_group, + "WANDB_ENTITY": self.entity, + "WANDB_API_KEY": os.environ.get("WANDB_API_KEY", None), + } + return {k: v for k, v in env_vars.items() if v is not None} diff --git a/src/flamingo/integrations/wandb/utils.py b/src/flamingo/integrations/wandb/utils.py index 1b00d9ea..6ea2afc4 100644 --- a/src/flamingo/integrations/wandb/utils.py +++ b/src/flamingo/integrations/wandb/utils.py @@ -3,39 +3,23 @@ import wandb from wandb.apis.public import Run -from flamingo.integrations.wandb import WandbEnvironment +from flamingo.integrations.wandb import WandbRunLink -def get_wandb_summary(env: WandbEnvironment) -> dict[str, Any]: +def get_wandb_run(env: WandbRunLink) -> Run: + """Retrieve a run from the W&B API.""" + api = wandb.Api() + return api.run(env.wandb_path) + + +def get_wandb_summary(env: WandbRunLink) -> dict[str, Any]: """Get the summary dictionary attached to a W&B run.""" - run = _resolve_wandb_run(env) + run = get_wandb_run(env) return dict(run.summary) -def update_wandb_summary(env: WandbEnvironment, metrics: dict[str, Any]) -> None: +def update_wandb_summary(env: WandbRunLink, metrics: dict[str, Any]) -> None: """Update a run's summary with the provided metrics.""" - run = _resolve_wandb_run(env) + run = get_wandb_run(env) run.summary.update(metrics) run.update() - - -def _resolve_wandb_run(env: WandbEnvironment) -> Run: - """Resolve a WandB run object from the provided environment settings. - - An exception is raised if a Run cannot be found, - or if multiple runs exist in scope with the same name. 
- """ - api = wandb.Api() - base_path = "/".join(x for x in (env.entity, env.project) if x) - if env.run_id is not None: - full_path = f"{base_path}/{env.run_id}" - return api.run(full_path) - else: - match [run for run in api.runs(base_path) if run.name == env.name]: - case []: - raise RuntimeError(f"No WandB runs found at {base_path}/{env.name}") - case [Run(), _]: - raise RuntimeError(f"Multiple WandB runs found at {base_path}/{env.name}") - case [Run()] as mr: - # we have a single one, hurray - return mr[0] diff --git a/src/flamingo/jobs/base_config.py b/src/flamingo/jobs/base_config.py deleted file mode 100644 index 3fec2edc..00000000 --- a/src/flamingo/jobs/base_config.py +++ /dev/null @@ -1,7 +0,0 @@ -from flamingo.types import BaseFlamingoConfig - - -class BaseJobConfig(BaseFlamingoConfig): - """Configuration defining a job to submit to the Ray cluster.""" - - pass diff --git a/src/flamingo/jobs/finetuning_config.py b/src/flamingo/jobs/finetuning_config.py index 1356d95a..40906020 100644 --- a/src/flamingo/jobs/finetuning_config.py +++ b/src/flamingo/jobs/finetuning_config.py @@ -1,28 +1,63 @@ +from typing import Any + from peft import LoraConfig -from pydantic import validator -from ray.train import ScalingConfig +from pydantic import Field, validator + +from flamingo.integrations.huggingface import ( + AutoModelConfig, + AutoTokenizerConfig, + DatasetConfig, + QuantizationConfig, + TrainerConfig, +) +from flamingo.integrations.wandb import WandbRunLink +from flamingo.types import BaseFlamingoConfig + + +class RayTrainConfig(BaseFlamingoConfig): + """Misc settings passed to Ray train. + + Includes information for scaling, checkpointing, and runtime storage. + """ -from flamingo.integrations.huggingface import QuantizationConfig -from flamingo.integrations.huggingface.trainer_config import TrainerConfig -from flamingo.integrations.huggingface.utils import is_valid_huggingface_model_name -from flamingo.jobs import BaseJobConfig + use_gpu: bool = True + num_workers: int | None = None + storage_path: str | None = None + + def get_scaling_args(self) -> dict[str, Any]: + args = dict(use_gpu=self.use_gpu, num_workers=self.num_workers) + return {k: v for k, v in args.items() if v is not None} -class FinetuningJobConfig(BaseJobConfig): +class FinetuningJobConfig(BaseFlamingoConfig): """Configuration to submit an LLM finetuning job.""" - model: str - dataset: str - tokenizer: str | None = None - trainer: TrainerConfig | None = None - lora: LoraConfig | None = None # TODO: Create our own config type + model: AutoModelConfig + dataset: DatasetConfig + tokenizer: AutoTokenizerConfig | None = None quantization: QuantizationConfig | None = None - scaling: ScalingConfig | None = None # TODO: Create our own config type - storage_path: str | None = None + adapter: LoraConfig | None = None # TODO: Create own dataclass here + tracking: WandbRunLink | None = None + trainer: TrainerConfig = Field(default_factory=TrainerConfig) + ray: RayTrainConfig = Field(default_factory=RayTrainConfig) + + @validator("model", pre=True, always=True) + def validate_model_arg(cls, x): + """Allow for passing just a path string as the model argument.""" + if isinstance(x, str): + return AutoModelConfig(path=x) + return x + + @validator("dataset", pre=True, always=True) + def validate_dataset_arg(cls, x): + """Allow for passing just a path string as the dataset argument.""" + if isinstance(x, str): + return DatasetConfig(path=x) + return x - @validator("model") - def _validate_model_name(cls, v): - if 
-            return v
-        else:
-            raise ValueError(f"`{v}` is not a valid HuggingFace model name.")
+    @validator("tokenizer", pre=True, always=True)
+    def validate_tokenizer_arg(cls, x):
+        """Allow for passing just a path string as the tokenizer argument."""
+        if isinstance(x, str):
+            return AutoTokenizerConfig(path=x)
+        return x
diff --git a/src/flamingo/jobs/lm_harness_config.py b/src/flamingo/jobs/lm_harness_config.py
index 905c1078..0f5332e5 100644
--- a/src/flamingo/jobs/lm_harness_config.py
+++ b/src/flamingo/jobs/lm_harness_config.py
@@ -1,43 +1,34 @@
 import datetime
-from pathlib import Path

-from pydantic import validator
+from pydantic import Field

-from flamingo.integrations.huggingface import ModelNameOrCheckpointPath, QuantizationConfig
-from flamingo.jobs import BaseJobConfig
-from flamingo.types import SerializableTorchDtype
+from flamingo.integrations.huggingface import AutoModelConfig, QuantizationConfig
+from flamingo.integrations.wandb import WandbRunLink
+from flamingo.types import BaseFlamingoConfig


-class LMHarnessJobConfig(BaseJobConfig):
-    """Configuration to run an lm-evaluation-harness evaluation job.
+class RayComputeSettings(BaseFlamingoConfig):
+    """Misc settings for Ray compute in the LM harness job."""

-    This job loads an existing checkpoint path from Ray storage to run evaluation against,
-    OR a huggingface Model and logs the evaluation results to W&B.
+    use_gpu: bool = True
+    num_workers: int = 1
+    timeout: datetime.timedelta | None = None

-    This can be manually overwritten by specifying the `model_name_or_path` variable
-    which will take prescedence over the W&B checkpoint path.
-    """

-    class Config:
-        validate_assignment = True
+class LMHarnessEvaluatorSettings(BaseFlamingoConfig):
+    """Misc settings provided to an lm-harness evaluation job."""

     tasks: list[str]
     batch_size: int | None = None
     num_fewshot: int | None = None
     limit: int | float | None = None
-    trust_remote_code: bool = False
-    torch_dtype: SerializableTorchDtype = None
-    model_name_or_path: str | Path | ModelNameOrCheckpointPath | None = None
-    quantization: QuantizationConfig | None = None
-    num_cpus: int = 1
-    num_gpus: int = 1
-    timeout: datetime.timedelta | None = None

-    @validator("model_name_or_path", pre=True, always=True)
-    def _validate_model_name_or_path(cls, v):
-        if isinstance(v, dict):
-            return ModelNameOrCheckpointPath(**v)
-        elif v is None:
-            return None
-        else:
-            return ModelNameOrCheckpointPath(name=v)
+
+class LMHarnessJobConfig(BaseFlamingoConfig):
+    """Configuration to run an lm-evaluation-harness evaluation job."""
+
+    model: AutoModelConfig
+    evaluator: LMHarnessEvaluatorSettings
+    quantization: QuantizationConfig | None = None
+    tracking: WandbRunLink | None = None
+    ray: RayComputeSettings = Field(default_factory=RayComputeSettings)
diff --git a/src/flamingo/jobs/simple_config.py b/src/flamingo/jobs/simple_config.py
index bc18fa96..24fa60a8 100644
--- a/src/flamingo/jobs/simple_config.py
+++ b/src/flamingo/jobs/simple_config.py
@@ -1,5 +1,5 @@
-from flamingo.jobs import BaseJobConfig
+from flamingo.types import BaseFlamingoConfig


-class SimpleJobConfig(BaseJobConfig):
+class SimpleJobConfig(BaseFlamingoConfig):
     magic_number: int
diff --git a/tests/conftest.py b/tests/conftest.py
index ccd5abfd..c6630295 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -8,7 +8,7 @@

 import pytest

-from flamingo.integrations.wandb import WandbEnvironment
+from flamingo.integrations.wandb import WandbRunLink
 from flamingo.jobs import LMHarnessJobConfig
@@ -28,14 +28,14 @@ def mock_environment_without_keys():


 @pytest.fixture(scope="function")
 def default_wandb_env():
-    def generator(**kwargs) -> WandbEnvironment:
+    def generator(**kwargs) -> WandbRunLink:
         mine = {
             "name": "my-run",
             "project": "my-project",
             "entity": "mozilla-ai",
             "run_id": "gabbagool-123",
         }
-        return WandbEnvironment(**{**mine, **kwargs})
+        return WandbRunLink(**{**mine, **kwargs})

     yield generator
diff --git a/tests/integrations/wandb/test_wandb_environment.py b/tests/integrations/wandb/test_wandb_environment.py
index 69e815e8..55cfd397 100644
--- a/tests/integrations/wandb/test_wandb_environment.py
+++ b/tests/integrations/wandb/test_wandb_environment.py
@@ -1,7 +1,7 @@
 import pytest
 from pydantic import ValidationError

-from flamingo.integrations.wandb import WandbEnvironment
+from flamingo.integrations.wandb import WandbRunLink


 def test_env_vars(default_wandb_env):
@@ -13,15 +13,15 @@ def test_serde_round_trip(default_wandb_env):
-    assert WandbEnvironment.parse_raw(default_wandb_env().json()) == default_wandb_env()
+    assert WandbRunLink.parse_raw(default_wandb_env().json()) == default_wandb_env()


 def test_disallowed_kwargs():
     with pytest.raises(ValidationError):
-        WandbEnvironment(name="name", project="project", old_name="I will throw")
+        WandbRunLink(name="name", project="project", old_name="I will throw")


 def test_missing_key_warning(mock_environment_without_keys):
     with pytest.warns(UserWarning):
-        env = WandbEnvironment(name="I am missing an API key", project="I should warn the user")
+        env = WandbRunLink(name="I am missing an API key", project="I should warn the user")
-    assert "WANDB_API_KEY" not in env.env_vars
+    assert "WANDB_API_KEY" not in env.get_env_vars()
diff --git a/tests/configs/__init__.py b/tests/jobs/__init__.py
similarity index 100%
rename from tests/configs/__init__.py
rename to tests/jobs/__init__.py
diff --git a/tests/configs/test_finetuning_config.py b/tests/jobs/test_finetuning_config.py
similarity index 100%
rename from tests/configs/test_finetuning_config.py
rename to tests/jobs/test_finetuning_config.py
diff --git a/tests/configs/test_lm_harness_config.py b/tests/jobs/test_lm_harness_config.py
similarity index 100%
rename from tests/configs/test_lm_harness_config.py
rename to tests/jobs/test_lm_harness_config.py
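
Reviewer notes: the sketches below are illustrative only and not part of the patch.

The `pre=True` validators on `FinetuningJobConfig` accept bare repo-ID strings, and every
`TrainerConfig` field defaults to None, so only explicitly-set values are forwarded to
HuggingFace. A minimal sketch, assuming the package layout above; the repo IDs are
placeholder values:

    from flamingo.jobs.finetuning_config import FinetuningJobConfig

    # Bare strings are coerced into AutoModelConfig / DatasetConfig by the validators
    config = FinetuningJobConfig(model="my-org/my-model", dataset="my-org/my-dataset")

    # No trainer fields were set, so get_training_args() drops them all and the
    # HuggingFace TrainingArguments defaults apply
    assert config.trainer.get_training_args() == {}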
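`WandbRunLink` guarantees a `run_id` at construction time, and both link classes expose
`wandb_path` for use with the W&B API. A behavioral sketch of the classes above (entity
and project names are placeholders; constructing a link makes no network calls):

    from flamingo.integrations.wandb import WandbArtifactLink, WandbRunLink

    run = WandbRunLink(project="my-project", entity="mozilla-ai")
    assert len(run.run_id) == 8   # filled in by the `ensure_run_id` pre-validator
    print(run.wandb_path)         # "mozilla-ai/my-project/<run_id>"
    env = run.get_env_vars()      # only the non-None fields, keyed by WANDB_* names

    artifact = WandbArtifactLink(name="my-model", project="my-project")
    print(artifact.wandb_path)    # "my-project/my-model:latest"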
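`load_and_split_dataset` always returns a `DatasetDict`, which keeps downstream training
code uniform whether or not `DatasetConfig.test_size` is set. A usage sketch, with "imdb"
standing in for any HuggingFace dataset path:

    from flamingo.integrations.huggingface.utils import load_and_split_dataset

    # With test_size set, the loaded split is divided via train_test_split
    datasets = load_and_split_dataset("imdb", split="train", test_size=0.2, seed=42)
    print(datasets["train"].num_rows, datasets["test"].num_rows)

    # Without test_size, everything comes back under the "train" key
    train_only = load_and_split_dataset("imdb", split="train", test_size=None)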