Skip to content

Commit

Permalink
Updated training and inference tutorials along with small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
uchendui committed Aug 29, 2024
1 parent 39360bb commit c40fd44
Show file tree
Hide file tree
Showing 10 changed files with 192 additions and 77 deletions.
43 changes: 15 additions & 28 deletions a2perf/domains/tfa/suite_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,31 +22,24 @@
for the final step of an episode. To prevent that we extract the step limit
from the environment specs and utilize our TimeLimit wrapper.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import, division, print_function

import json
import os
from typing import Any
from typing import Callable
from typing import Dict
from typing import Optional
from typing import Sequence
from typing import Text
from typing import Any, Callable, Dict, Optional, Sequence, Text

import gin
import gymnasium as gym
import numpy as np
from absl import logging
from tf_agents.environments import py_environment
from tf_agents.environments import wrappers
from tf_agents.environments import py_environment, wrappers
from tf_agents.typing import types

from a2perf.domains import circuit_training # noqa: F401
from a2perf.domains import quadruped_locomotion # noqa: F401
from a2perf.domains import web_navigation # noqa: F401
from a2perf.domains.tfa import gym_wrapper
from a2perf.domains.web_navigation.gwob.CoDE import vocabulary_node

TimeLimitWrapperType = Callable[
[py_environment.PyEnvironment, int], py_environment.PyEnvironment
Expand Down Expand Up @@ -184,16 +177,14 @@ def create_domain(
):
if env_name in WEB_NAVIGATION_ENVS:
# noinspection PyUnresolvedReferences
from a2perf.domains import web_navigation # noqa: F401
from a2perf.domains.web_navigation.gwob.CoDE import vocabulary_node

save_vocab_dir = os.path.join(root_dir, "vocabulary")
reload_vocab = env_kwargs.pop("reload_vocab", True)
vocab_type = env_kwargs.pop("vocab_type", "threaded")
if vocab_type == "threaded":
global_vocab = vocabulary_node.LockedThreadedVocabulary()
elif vocab_type == "unlocked":
global_vocab = vocabulary_node.UnlockedVocabulary()
vocabulary_node.UnlockedVocabulary()
elif vocab_type == "multiprocessing":
global_vocab = vocabulary_node.LockedMultiprocessingVocabulary()
else:
Expand All @@ -209,14 +200,14 @@ def create_domain(
global_vocab.restore(state=global_vocab_dict)
seed = int(os.environ.get("SEED", None))
num_websites = int(os.environ.get("NUM_WEBSITES", None))
difficulty = int(os.environ.get("DIFFICULTY_LEVEL", None))
# difficulty = int(os.environ.get("DIFFICULTY_LEVEL", None))

env_kwargs.update(
{
"global_vocabulary": global_vocab,
"seed": seed,
"num_websites": num_websites,
"difficulty": difficulty,
# "difficulty": difficulty,
"browser_args": dict(
threading=False,
chrome_options={
Expand All @@ -230,30 +221,26 @@ def create_domain(
)
env_wrappers = [wrappers.ActionClipWrapper] + list(env_wrappers)
elif env_name in CIRCUIT_TRAINING_ENVS:
# noinspection PyUnresolvedReferences
from a2perf.domains import circuit_training # noqa: F401

env_kwargs.pop("netlist", None)
netlist_file_path = os.environ.get("NETLIST_PATH", None)
# netlist_file_path = os.environ.get("NETLIST_PATH", None)
seed = int(os.environ.get("SEED", None))
init_placement_file_path = os.environ.get("INIT_PLACEMENT_PATH", None)
std_cell_placer_mode = os.environ.get("STD_CELL_PLACER_MODE", None)
# init_placement_file_path = os.environ.get("INIT_PLACEMENT_PATH", None)
# std_cell_placer_mode = os.environ.get("STD_CELL_PLACER_MODE", None)
env_kwargs.update(
{
"global_seed": seed,
"netlist_file": netlist_file_path,
"init_placement": init_placement_file_path,
# "netlist_file": netlist_file_path,
# "init_placement": init_placement_file_path,
"output_plc_file": os.path.join(root_dir, "output.plc"),
"std_cell_placer_mode": std_cell_placer_mode,
# "std_cell_placer_mode": std_cell_placer_mode,
}
)
env_wrappers = [wrappers.ActionClipWrapper] + list(env_wrappers)
elif env_name in QUADRUPED_LOCOMOTION_ENVS:
# noinspection PyUnresolvedReferences
from a2perf.domains import quadruped_locomotion # noqa: F401

motion_file_path = os.environ.get("MOTION_FILE_PATH", None)
env_kwargs["motion_files"] = [motion_file_path]
# motion_file_path = os.environ.get("MOTION_FILE_PATH", None)
# env_kwargs["motion_files"] = [motion_file_path]
env_wrappers = [wrappers.ActionClipWrapper] + list(env_wrappers)
else:
raise NotImplementedError(f"Unknown environment: {env_name}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@ import a2perf.domains.tfa.suite_gym
# Set up submission object
Submission.mode = %BenchmarkMode.INFERENCE
Submission.domain = %BenchmarkDomain.QUADRUPED_LOCOMOTION
#Submission.run_offline_metrics_only = True
Submission.run_offline_metrics_only = False
Submission.measure_emissions = True

####################################
# Set up domain
####################################
suite_gym.create_domain.env_name = "QuadrupedLocomotion-DogPace-v0"
suite_gym.create_domain.mode='test'
suite_gym.create_domain.num_parallel_envs=1

Expand All @@ -30,11 +31,12 @@ Submission.time_participant_code = True
# SYSTEM METRICS SETUP
# ----------------------
# Set up codecarbon for system metrics
track_emissions_decorator.project_name = 'a2perf_quadruped_locomotion_inference_debug'
track_emissions_decorator.project_name = 'a2perf_quadruped_locomotion_inference'
track_emissions_decorator.measure_power_secs = 1
track_emissions_decorator.save_to_file = True # Save data to file
track_emissions_decorator.save_to_logger = False # Do not save data to logger
track_emissions_decorator.gpu_ids = None # Enter a list of specific GPU IDs to track if desired
track_emissions_decorator.log_level = 'info' # Log level set to 'info'
track_emissions_decorator.country_iso_code = 'USA'
track_emissions_decorator.region = 'Massachusetts'
track_emissions_decorator.offline = True
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Submission.measure_emissions = True
####################################
# Set up domain
####################################
suite_gym.create_domain.mode.env_name = "QuadrupedLocomotion-DogSpin-v0"
suite_gym.create_domain.mode='test'
suite_gym.create_domain.num_parallel_envs=1

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Submission.measure_emissions = True
####################################
# Set up domain
####################################
suite_gym.create_domain.env_name = "QuadrupedLocomotion-DogTrot-v0"
suite_gym.create_domain.mode='test'
suite_gym.create_domain.num_parallel_envs=1

Expand Down
2 changes: 1 addition & 1 deletion a2perf/submission/configs/quadruped_locomotion/train.gin
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import a2perf.submission.submission_util
# Set up submission object
Submission.mode = %a2perf.constants.BenchmarkMode.TRAIN
Submission.domain = %a2perf.constants.BenchmarkDomain.QUADRUPED_LOCOMOTION
Submission.run_offline_metrics_only=False
Submission.run_offline_metrics_only = False
Submission.measure_emissions=True


Expand Down
8 changes: 0 additions & 8 deletions a2perf/submission/main_submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import gin
from absl import app, flags, logging

from a2perf.constants import BenchmarkMode
from a2perf.submission import submission_util

_GIN_CONFIG = flags.DEFINE_string(
Expand Down Expand Up @@ -32,12 +31,6 @@
_RUN_OFFLINE_METRICS_ONLY = flags.DEFINE_bool(
"run-offline-metrics-only", False, "Whether to run offline metrics only."
)
_MODE = flags.DEFINE_enum(
"mode",
"train",
["train", "inference", "generalization"],
"Mode of the submission. train, inference, or generalization.",
)


def main(_):
Expand All @@ -54,7 +47,6 @@ def main(_):
logging.info("Adding extra gin binding: %s", binding)

submission = submission_util.Submission(
mode=BenchmarkMode(_MODE.value),
root_dir=_ROOT_DIR.value,
metric_values_dir=_METRIC_VALUES_DIR.value,
participant_module_path=_PARTICIPANT_MODULE_PATH.value,
Expand Down
32 changes: 20 additions & 12 deletions a2perf/submission/submission_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,13 @@ def _load_module(module_path, filename):
return module, spec


def _load_policy(module_path, env):
def _load_policy(module_path, env, participant_args=None):
"""Loads the policy from the participant's module."""
with working_directory(module_path):
participant_module, participant_module_spec = _load_module(
module_path, "inference.py"
)
policy = participant_module.load_policy(env)
policy = participant_module.load_policy(env, **(participant_args or {}))
return policy, participant_module


Expand All @@ -159,6 +159,7 @@ def perform_rollouts(
gin_config_str=None,
absl_flags=None,
rollout_rewards_queue=None,
participant_args=None,
):
"""Performs rollouts using the given policy.
Expand All @@ -175,7 +176,11 @@ def perform_rollouts(
"""
setup_subprocess_env(gin_config_str, absl_flags)
env = create_domain_fn()
policy, participant_module = _load_policy(module_path, env)
if participant_args is None:
participant_args = {}
policy, participant_module = _load_policy(
module_path, env, participant_args=participant_args
)
episode_reward_metric = py_metrics.AverageReturnMetric()
rollout_actor = actor.Actor(
env=env,
Expand Down Expand Up @@ -284,9 +289,7 @@ def _perform_rollout_task(
for key, value in generalization_env_vars.items():
os.environ[key] = value

create_domain_fn = functools.partial(
suite_gym.create_domain, env_name=domain.value, root_dir=root_dir
)
create_domain_fn = functools.partial(suite_gym.create_domain, root_dir=root_dir)
all_rewards = perform_rollouts(
module_path=participant_module_path,
create_domain_fn=create_domain_fn,
Expand Down Expand Up @@ -456,7 +459,10 @@ def _perform_rollouts(
setup_subprocess_env(self.gin_config_str, self.absl_flags)

create_domain_fn = functools.partial(
suite_gym.create_domain, env_name=self.domain.value, root_dir=self.root_dir
suite_gym.create_domain,
# env_name=self.domain.value,
root_dir=self.root_dir,
# load_kwargs=self.participant_args,
)
if measure_emissions:

Expand All @@ -473,6 +479,7 @@ def perform_rollouts_and_track_emissions():
self.gin_config_str,
self.absl_flags,
rollout_rewards_queue,
self.participant_args,
),
)
rollout_process.start()
Expand All @@ -486,6 +493,7 @@ def perform_rollouts_and_track_emissions():
module_path=self.participant_module_path,
gin_config_str=self.gin_config_str,
absl_flags=self.absl_flags,
participant_args=self.participant_args,
)

def _run_training_benchmark(self):
Expand Down Expand Up @@ -552,10 +560,8 @@ def _run_generalization_benchmark(self):

def _run_inference_benchmark(self):
if not self.run_offline_metrics_only:
logging.info("Creating Gymnasium domain...")
env = suite_gym.create_domain(
env_name=self.domain.value, root_dir=self.root_dir
)
logging.info("Creating Gymnasium environment...")
env = suite_gym.create_domain(root_dir=self.root_dir)
logging.info("Successfully created domain")

logging.info("Generating inference data...")
Expand All @@ -566,7 +572,9 @@ def _run_inference_benchmark(self):

logging.info("Loading the policy for inference...")
participant_policy, participant_module = _load_policy(
module_path=self.participant_module_path, env=env
module_path=self.participant_module_path,
env=env,
participant_args=self.participant_args,
)

# Only include time_step_spec if the participant policy has it as an
Expand Down
Loading

0 comments on commit c40fd44

Please sign in to comment.