Skip to content

Commit

Permalink
Merge pull request #198 from edbeeching/add_sb3_sac_onnx_export
Browse files Browse the repository at this point in the history
Add sb3 sac onnx export
  • Loading branch information
Ivan-267 authored Aug 27, 2024
2 parents 931cd28 + 14bfefe commit 5ccfa8e
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 39 deletions.
4 changes: 2 additions & 2 deletions examples/stable_baselines3_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor

from godot_rl.core.utils import can_import
from godot_rl.wrappers.onnx.stable_baselines_export import export_ppo_model_as_onnx
from godot_rl.wrappers.onnx.stable_baselines_export import export_model_as_onnx
from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv

# To download the env source and binary:
Expand Down Expand Up @@ -115,7 +115,7 @@ def handle_onnx_export():
if args.onnx_export_path is not None:
path_onnx = pathlib.Path(args.onnx_export_path).with_suffix(".onnx")
print("Exporting onnx to: " + os.path.abspath(path_onnx))
export_ppo_model_as_onnx(model, str(path_onnx))
export_model_as_onnx(model, str(path_onnx))


def handle_model_save():
Expand Down
92 changes: 57 additions & 35 deletions godot_rl/wrappers/onnx/stable_baselines_export.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import torch
from gymnasium.vector.utils import spaces
from stable_baselines3 import PPO
from stable_baselines3 import PPO, SAC


class OnnxableMultiInputPolicy(torch.nn.Module):
class OnnxablePolicy(torch.nn.Module):
def __init__(
self,
obs_keys,
features_extractor,
mlp_extractor,
action_net,
value_net,
use_obs_array,
obs_keys=None,
features_extractor=None,
mlp_extractor=None,
action_net=None,
value_net=None,
use_obs_array=None,
actor=None,
):
super().__init__()
self.obs_keys = obs_keys
Expand All @@ -20,10 +21,12 @@ def __init__(
self.action_net = action_net
self.value_net = value_net
self.use_obs_array = use_obs_array
self.actor = actor

def forward(self, obs, state_ins):
# NOTE: You may have to process (normalize) observation in the correct
# way before using this. See `common.preprocessing.preprocess_obs`
def forward_sac(self, observation: torch.Tensor, state_ins):
return self.actor(observation, deterministic=True), state_ins

def forward_ppo(self, obs, state_ins):
features = None

if self.use_obs_array:
Expand All @@ -35,31 +38,47 @@ def forward(self, obs, state_ins):
action_hidden, value_hidden = self.mlp_extractor(features)
return self.action_net(action_hidden), state_ins

def forward(self, obs, state_ins):
if self.actor:
return self.forward_sac(obs, state_ins)
else:
return self.forward_ppo(obs, state_ins)


def export_model_as_onnx(model, onnx_model_path: str, use_obs_array: bool = False):
policy = model.policy.to("cpu")
dummy_input = None
onnxable_model = None

if isinstance(model, SAC):
assert use_obs_array, "SAC ONNX export works with use_obs_array=True, MLPPolicy and SBGSingleObsEnv only."

if isinstance(model, PPO):
onnxable_model = OnnxablePolicy(
["obs"],
policy.features_extractor,
policy.mlp_extractor,
policy.action_net,
policy.value_net,
use_obs_array,
)
if use_obs_array:
dummy_input = torch.unsqueeze(torch.tensor(model.observation_space.sample()), 0)
else:
dummy_input = dict(model.observation_space.sample())
for k, v in dummy_input.items():
dummy_input[k] = torch.from_numpy(v).unsqueeze(0)
dummy_input = [v for v in dummy_input.values()]

def export_ppo_model_as_onnx(ppo: PPO, onnx_model_path: str, use_obs_array: bool = False):
ppo_policy = ppo.policy.to("cpu")
onnxable_model = OnnxableMultiInputPolicy(
["obs"],
ppo_policy.features_extractor,
ppo_policy.mlp_extractor,
ppo_policy.action_net,
ppo_policy.value_net,
use_obs_array,
)

if use_obs_array:
dummy_input = torch.unsqueeze(torch.tensor(ppo.observation_space.sample()), 0)
else:
dummy_input = dict(ppo.observation_space.sample())
for k, v in dummy_input.items():
dummy_input[k] = torch.from_numpy(v).unsqueeze(0)
dummy_input = [v for v in dummy_input.values()]
elif isinstance(model, SAC):
onnxable_model = OnnxablePolicy(actor=model.policy.actor)
dummy_input = torch.randn(1, *model.observation_space.shape)

torch.onnx.export(
onnxable_model,
args=(dummy_input, torch.zeros(1).float()),
f=onnx_model_path,
opset_version=9,
opset_version=17,
input_names=["obs", "state_ins"],
output_names=["output", "state_outs"],
dynamic_axes={
Expand All @@ -70,11 +89,14 @@ def export_ppo_model_as_onnx(ppo: PPO, onnx_model_path: str, use_obs_array: bool
},
)

# If the space is MultiDiscrete, we skip verifying as action output will have an expected mismatch
# (the output from onnx will be the action logits for each discrete action,
# while the output from sb3 will be a single int)
if not isinstance(ppo.action_space, spaces.MultiDiscrete):
verify_onnx_export(ppo, onnx_model_path, use_obs_array=use_obs_array)
# We only verify with PPO currently due to different output shape with SAC
# (this can be updated in the future)
if isinstance(model, PPO):
# If the space is MultiDiscrete, we skip verifying as action output will have an expected mismatch
# (the output from onnx will be the action logits for each discrete action,
# while the output from sb3 will be a single int)
if not isinstance(model.action_space, spaces.MultiDiscrete):
verify_onnx_export(model, onnx_model_path, use_obs_array=use_obs_array)


def verify_onnx_export(ppo: PPO, onnx_model_path: str, num_tests=10, use_obs_array: bool = False):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_sb3_onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
def test_pytorch_vs_onnx(env_name, port):
from stable_baselines3 import PPO

from godot_rl.wrappers.onnx.stable_baselines_export import export_ppo_model_as_onnx, verify_onnx_export
from godot_rl.wrappers.onnx.stable_baselines_export import export_model_as_onnx, verify_onnx_export
from godot_rl.wrappers.stable_baselines_wrapper import StableBaselinesGodotEnv

env_path = f"examples/godot_rl_{env_name}/bin/{env_name}.x86_64"
Expand All @@ -39,6 +39,6 @@ def test_pytorch_vs_onnx(env_name, port):
tensorboard_log="logs/log",
)

export_ppo_model_as_onnx(ppo, f"{env_name}_tmp.onnx")
export_model_as_onnx(ppo, f"{env_name}_tmp.onnx")
verify_onnx_export(ppo, f"{env_name}_tmp.onnx")
os.remove(f"{env_name}_tmp.onnx")

0 comments on commit 5ccfa8e

Please sign in to comment.