From 051d5a1f615e73ab07b3e94563dd8b5babf2d981 Mon Sep 17 00:00:00 2001 From: Laura O'Mahony <64503534+lomahony@users.noreply.github.com> Date: Mon, 30 Oct 2023 17:22:53 +0000 Subject: [PATCH] updating PPOTrainer docstring (#897) * adding specific dict structure to tracker_kwargs doc string to enable changing tracker params like wandb experiment name for ease, avoids needing to go deep into accelerate source * push changes * set default dict * refactor * use typing extension --------- Co-authored-by: Laura O'Mahony Co-authored-by: Costa Huang --- trl/trainer/ppo_config.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/trl/trainer/ppo_config.py b/trl/trainer/ppo_config.py index 92aac8509f..feda1c2b71 100644 --- a/trl/trainer/ppo_config.py +++ b/trl/trainer/ppo_config.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json import os import sys import warnings @@ -19,6 +20,7 @@ import numpy as np import tyro +from typing_extensions import Annotated from trl.trainer.utils import exact_div @@ -26,6 +28,9 @@ from ..import_utils import is_wandb_available +JSONDict = Annotated[Optional[dict], tyro.conf.arg(metavar="JSON", constructor=json.loads)] + + @dataclass class PPOConfig: """ @@ -49,15 +54,15 @@ class PPOConfig: """The reward model to use - used only for tracking purposes""" remove_unused_columns: bool = True """Remove unused columns from the dataset if `datasets.Dataset` is used""" - tracker_kwargs: dict = field(default_factory=dict) - """Keyword arguments for the tracker (e.g. wandb_project)""" - accelerator_kwargs: dict = field(default_factory=dict) + tracker_kwargs: JSONDict = field(default_factory=dict) + """Keyword arguments for the tracker (e.g. python ppo.py --ppo_config.tracker_kwargs='{"wandb": {"entity": "my_wandb_entity", "name": "my_exp_name"}}'""" + accelerator_kwargs: JSONDict = field(default_factory=dict) """Keyword arguments for the accelerator""" - project_kwargs: dict = field(default_factory=dict) + project_kwargs: JSONDict = field(default_factory=dict) """Keyword arguments for the accelerator project config (e.g. `logging_dir`)""" tracker_project_name: str = "trl" """Name of project to use for tracking""" - push_to_hub_if_best_kwargs: dict = field(default_factory=dict) + push_to_hub_if_best_kwargs: JSONDict = field(default_factory=dict) """Keyword arguments for pushing model to the hub during training (e.g. repo_id)""" # hyperparameters