Skip to content

Commit

Permalink
updating PPOTrainer docstring (huggingface#897)
Browse files Browse the repository at this point in the history
* Add the specific dict structure to the tracker_kwargs docstring so that tracker params (e.g. the wandb experiment name) can be changed easily, without needing to dig deep into the accelerate source

* push changes

* set default dict

* refactor

* use typing extension

---------

Co-authored-by: Laura O'Mahony <[email protected]>
Co-authored-by: Costa Huang <[email protected]>
  • Loading branch information
3 people authored Oct 30, 2023
1 parent 2068fdc commit 051d5a1
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions trl/trainer/ppo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import sys
import warnings
Expand All @@ -19,13 +20,17 @@

import numpy as np
import tyro
from typing_extensions import Annotated

from trl.trainer.utils import exact_div

from ..core import flatten_dict
from ..import_utils import is_wandb_available


# Type alias for optional dict-valued config fields that are set from the command line.
# The tyro annotation labels the argument with a "JSON" metavar and registers
# `json.loads` as the constructor, so the value is supplied as a JSON string
# (e.g. '{"wandb": {"name": "my_exp_name"}}').
JSONDict = Annotated[Optional[dict], tyro.conf.arg(metavar="JSON", constructor=json.loads)]


@dataclass
class PPOConfig:
"""
Expand All @@ -49,15 +54,15 @@ class PPOConfig:
"""The reward model to use - used only for tracking purposes"""
remove_unused_columns: bool = True
"""Remove unused columns from the dataset if `datasets.Dataset` is used"""
tracker_kwargs: dict = field(default_factory=dict)
"""Keyword arguments for the tracker (e.g. wandb_project)"""
accelerator_kwargs: dict = field(default_factory=dict)
tracker_kwargs: JSONDict = field(default_factory=dict)
"""Keyword arguments for the tracker (e.g. python ppo.py --ppo_config.tracker_kwargs='{"wandb": {"entity": "my_wandb_entity", "name": "my_exp_name"}}'"""
accelerator_kwargs: JSONDict = field(default_factory=dict)
"""Keyword arguments for the accelerator"""
project_kwargs: dict = field(default_factory=dict)
project_kwargs: JSONDict = field(default_factory=dict)
"""Keyword arguments for the accelerator project config (e.g. `logging_dir`)"""
tracker_project_name: str = "trl"
"""Name of project to use for tracking"""
push_to_hub_if_best_kwargs: dict = field(default_factory=dict)
push_to_hub_if_best_kwargs: JSONDict = field(default_factory=dict)
"""Keyword arguments for pushing model to the hub during training (e.g. repo_id)"""

# hyperparameters
Expand Down

0 comments on commit 051d5a1

Please sign in to comment.