Skip to content

Commit

Permalink
Merged in P2E-Dreamer-V2 (pull request #14)
Browse files Browse the repository at this point in the history
P2E Dreamer V2
  • Loading branch information
belerico committed Jul 17, 2023
2 parents 6789bb6 + f77389e commit b223f32
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 501 deletions.
22 changes: 20 additions & 2 deletions sheeprl/algos/p2e_dv2/p2e_dv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Dict, Sequence

import gymnasium as gym
import jsonlines
import numpy as np
import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -37,8 +38,7 @@
from sheeprl.utils.registry import register_algorithm
from sheeprl.utils.utils import compute_lambda_values, polynomial_decay

# Decomment the following two lines if you are using MineDojo on an headless machine
# os.environ["MINEDOJO_HEADLESS"] = "1"
os.environ["MINEDOJO_HEADLESS"] = "1"


def train(
Expand Down Expand Up @@ -729,6 +729,12 @@ def main():
max_decay_steps=max_step_expl_decay,
)

# Stats file info
if "minedojo" in args.env_id:
stats_dir = os.path.join(log_dir, "stats")
os.makedirs(stats_dir, exist_ok=True)
stats_filename = os.path.join(stats_dir, "stats.jsonl")

# Get the first environment observation and start the optimization
episode_steps = []
o, infos = env.reset(seed=args.seed)
Expand Down Expand Up @@ -785,6 +791,18 @@ def main():
else:
real_actions = np.array([real_act.cpu().argmax() for real_act in real_actions])

# Save stats
if "minedojo" in args.env_id:
with jsonlines.open(stats_filename, mode="a") as writer:
writer.write(
{
"life_stats": infos["life_stats"],
"location_stats": infos["location_stats"],
"action": real_actions.tolist(),
"biomeid": infos["biomeid"],
}
)

step_data["is_first"] = copy.deepcopy(step_data["dones"])
o, rewards, dones, truncated, infos = env.step(real_actions.reshape(env.action_space.shape))
dones = np.logical_or(dones, truncated)
Expand Down
Loading

0 comments on commit b223f32

Please sign in to comment.