Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

init add of wandb ai writer #5741

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion ml-agents/mlagents/plugins/stats_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from mlagents_envs import logging_util
from mlagents.plugins import ML_AGENTS_STATS_WRITER
from mlagents.trainers.settings import RunOptions
from mlagents.trainers.stats import TensorboardWriter, GaugeWriter, ConsoleWriter
from mlagents.trainers.stats import TensorboardWriter, GaugeWriter, ConsoleWriter, WandbWriter


logger = logging_util.get_logger(__name__)
Expand All @@ -25,6 +25,7 @@ def get_default_stats_writers(run_options: RunOptions) -> List[StatsWriter]:
* A TensorboardWriter to write information to TensorBoard
* A GaugeWriter to record our internal stats
* A ConsoleWriter to output to stdout.
* A Wandb.AI Writer
"""
checkpoint_settings = run_options.checkpoint_settings
return [
Expand All @@ -35,6 +36,7 @@ def get_default_stats_writers(run_options: RunOptions) -> List[StatsWriter]:
),
GaugeWriter(),
ConsoleWriter(),
WandbWriter()
]


Expand Down
24 changes: 24 additions & 0 deletions ml-agents/mlagents/trainers/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
import time
from threading import RLock
import wandb

from mlagents_envs.side_channel.stats_side_channel import StatsAggregationMethod

Expand Down Expand Up @@ -286,6 +287,29 @@ def add_property(
self.summary_writers[category].flush()


class WandbWriter(StatsWriter):
def __init__(
self,
config: dict
):
"""
A Weights and Biases Wrapper that will add stats to your wandb.ai board.
"""
wandb.init(reinit=True,
config=config)

def write_stats(
self,
category : str,
values : dict,
step : int
) -> None:
"""
Write some stats for a given category and step
"""
wandb.log({category : values}, step=step)


class StatsReporter:
writers: List[StatsWriter] = []
stats_dict: Dict[str, Dict[str, List]] = defaultdict(lambda: defaultdict(list))
Expand Down
18 changes: 18 additions & 0 deletions ml-agents/mlagents/trainers/tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
StatsSummary,
GaugeWriter,
ConsoleWriter,
WandbWriter,
StatsPropertyType,
StatsAggregationMethod,
)
Expand Down Expand Up @@ -248,3 +249,20 @@ def test_selfplay_console_writer(self):
self.assertIn(
"Mean Reward: 1.000. Std of Reward: 0.000. Training.", cm.output[0]
)


class WandbWriterTest(unittest.TestCase):
def test_wandb_full(self):
category = "GeneralStuff"
config = {"caller" : "ml-agents"}
wandb_writer = WandbWriter(config=config)
wandb_writer.write_stats(
category = category,
values = {
"Environment/Cumulative Reward": -15,
"Is Training": True,
"Self-play/ELO": 1.0,
},
step = 10,
)