Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trigger more meaningful validation error messages when trainer registration failed #40

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

sfc-gh-lmerrick
Copy link
Collaborator

If a user fails to register their custom Trainer class, or they accidentally supply the incorrect trainer type in the trainer config, validation can proceed in a difficult-to-debug manner. This PR introduces more explicit error messages that trigger in this case and can help the user understand the root cause of the issue faster.

Copy link
Collaborator

@sfc-gh-caxu sfc-gh-caxu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not obvious to me whether this change is needed because the error has been raised from get_registered_trainer

Comment on lines -133 to +125
raise ValueError(f"{trainer_name} is not a registered Trainer.")
raise KeyError(f"{trainer_name} is not a registered Trainer.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be ValueError. ValueError suggests the function (get_registered_trainer) receives an invalid value. https://docs.python.org/3/library/exceptions.html#ValueError

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the registry not a key-value mapping? If it's a key-value mapping, I believe KeyError is the more specific, and thus more useful, exception to raise.

Comment on lines +195 to +196
except KeyError as e:
raise KeyError(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ValueError, see explanations. I think KeyError is rarely used and only for a local data type (e.g., a dict, set etc.)

@sfc-gh-lmerrick
Copy link
Collaborator Author

Not obvious to me whether this change is needed because the error has been raised from get_registered_trainer

In my tests, the rust-implemented SchemaValidator.validate_python function driving the valiation suppressed this error and passed an info: ValidationInfo object to the parse_sub_config field validator method that was simply missing the extra fields that should have been added by the tried-but-failed previous calls to this function (i.e. no data object).

@sfc-gh-lmerrick
Copy link
Collaborator Author

Added a test. The test fails before the changes, but succeeds after.

Previous test failure output:

========================================================================================= test session starts ==========================================================================================
platform linux -- Python 3.10.12, pytest-8.1.1, pluggy-1.5.0
rootdir: /scratch/tests
configfile: pytest.ini
plugins: devtools-0.12.2, anyio-4.3.0, hypothesis-5.35.1, flakefinder-1.1.0, rerunfailures-14.0, shard-0.1.2, xdist-3.6.1, xdoctest-1.0.2
collected 1 item                                                                                                                                                                                       
Running 1 items in this shard

tests/trainer/test_trainer_validation.py F                                                                                                                                                       [100%]

=============================================================================================== FAILURES ===============================================================================================
______________________________________________________________________________________ test_unregistered_trainer _______________________________________________________________________________________

tmp_path = PosixPath('/tmp/pytest-of-lmerrick/pytest-4/test_unregistered_trainer0')

    @pytest.mark.cpu
    def test_unregistered_trainer(tmp_path):
        config_dict = {
            "type": "unregistered_or_nonexistent",
            "exit_iteration": 2,
            "micro_batch_size": 1,
            "model": {
                "type": "random-weight-hf",
                "name_or_path": "HuggingFaceTB/SmolLM-135M-Instruct",
                "attn_implementation": "eager",
                "dtype": "float32",
            },
            "data": {
                "max_length": 2048,
                "sources": ["HuggingFaceH4/ultrachat_200k-truncated"],
            },
            "deepspeed": {"zero_optimization": {"stage": 0}},
            "optimizer": {"type": "cpu-adam"},
        }
        # Fails in previous implementation of `TrainerConfig.parse_sub_config`, despite
        # the implementation intending for this to succeed.
        with pytest.raises(ValueError) as ctx:
>           config = TrainerConfig(**config_dict)

tests/trainer/test_trainer_validation.py:43: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
arctic_training/config/base.py:25: in __init__
    super().__init__(**data)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

cls = <class 'arctic_training.config.trainer.TrainerConfig'>, v = 0
info = ValidationInfo(config={'title': 'TrainerConfig', 'extra_fields_behavior': 'forbid', 'validate_default': True, 'populat...1, 'gradient_accumulation_steps': 1, 'micro_batch_size': 1, 'seed': 42, 'train_iters': 0}, field_name='eval_frequency')

    @field_validator("eval_frequency", mode="after")
    def validate_eval_frequency(cls, v: int, info: ValidationInfo) -> int:
        if (
>           info.data["data"].eval_sources
            or info.data["data"].train_eval_split[1] > 0.0
        ):
E       KeyError: 'data'

arctic_training/config/trainer.py:158: KeyError
======================================================================================= short test summary info ========================================================================================
FAILED tests/trainer/test_trainer_validation.py::test_unregistered_trainer - KeyError: 'data'
========================================================================================== 1 failed in 0.72s ===========================================================================================

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants