# learner_pinch.py -- forked from Kaiyotech/Opti
import wandb
import torch.jit
from torch.nn import Linear, Sequential, LeakyReLU
from redis import Redis
from rocket_learn.agent.actor_critic_agent import ActorCriticAgent
from rocket_learn.agent.discrete_policy import DiscretePolicy
from rocket_learn.ppo import PPO
from rocket_learn.rollout_generator.redis.redis_rollout_generator import RedisRolloutGenerator
from CoyoteObs import CoyoteObsBuilder
from CoyoteParser import CoyoteAction
import numpy as np
from rewards import ZeroSumReward
import Constants_pinch
from utils.misc import count_parameters
import os
from torch import set_num_threads
from rocket_learn.utils.stat_trackers.common_trackers import Speed, Demos, TimeoutRate, Touch, EpisodeLength, Boost, \
BehindBall, TouchHeight, DistToBall, AirTouch, AirTouchHeight, BallHeight, BallSpeed, CarOnGround, GoalSpeed,\
MaxGoalSpeed
# Ideas for models:
# - Get to the ball as fast as possible, sometimes with no boost (rewards exist).
# - Pinches (ceiling, Kuxir, and team?): score in as few touches as possible with high velocity.
# - Half flip, wavedash, wall dash (how to do this one?).
# - Lix reset?
# - Normal play, as well as possible (rewards exist).
# - Aerial play without pinch (rewards exist).
# - Kickoff: 5-second terminal, reward ball distance into the opponent's half.
# Limit torch to a single thread for this process.
set_num_threads(1)

if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Discount factor: gamma is chosen so that 0.5 == gamma ** N, where N is
    # the number of steps in Constants_pinch.TIME_HORIZON seconds at a step
    # rate of 120 / FRAME_SKIP per second.
    # ------------------------------------------------------------------
    frame_skip = Constants_pinch.FRAME_SKIP
    half_life_seconds = Constants_pinch.TIME_HORIZON
    fps = 120 / frame_skip
    steps_per_half_life = fps * half_life_seconds
    gamma = np.exp(np.log(0.5) / steps_per_half_life)

    # Hyperparameters; handed to wandb and read back through logger.config.
    config = {
        "actor_lr": 1e-4,
        "critic_lr": 1e-4,
        "n_steps": Constants_pinch.STEP_SIZE,
        "batch_size": 100_000,
        "minibatch_size": 50_000,
        "epochs": 30,
        "gamma": gamma,
        "save_every": 10,
        "model_every": 100,
        "ent_coef": 0.01,
    }

    # ------------------------------------------------------------------
    # Experiment tracking (wandb); resume=True continues the run with this id.
    # ------------------------------------------------------------------
    run_id = "pinch_run1.01"
    wandb.login(key=os.environ["WANDB_KEY"])
    wandb_settings = wandb.Settings(_disable_stats=True, _disable_meta=True)
    logger = wandb.init(
        dir="./wandb_store",
        name="Pinch_Run1.01",
        project="Opti",
        entity="kaiyotech",
        id=run_id,
        config=config,
        settings=wandb_settings,
        resume=True,
    )
    cfg = logger.config

    # ------------------------------------------------------------------
    # Redis connection used by the rollout generator.
    # ------------------------------------------------------------------
    redis_password = os.environ["redis_user1_key"]
    redis = Redis(
        username="user1",
        password=redis_password,
        db=Constants_pinch.DB_NUM,
    )  # host="192.168.0.201",
    # Drop the "worker-ids" key before starting.
    redis.delete("worker-ids")

    stat_trackers = [
        Speed(normalize=True),
        Demos(),
        TimeoutRate(),
        Touch(),
        EpisodeLength(),
        Boost(),
        BehindBall(),
        TouchHeight(),
        DistToBall(),
        AirTouch(),
        AirTouchHeight(),
        BallHeight(),
        BallSpeed(normalize=True),
        CarOnGround(),
        GoalSpeed(),
        MaxGoalSpeed(),
    ]

    # Reward weights for the pinch run (everything except zero_sum, which is
    # read from Constants_pinch each time a reward object is built).
    reward_weights = {
        "goal_w": 10,
        "aerial_goal_w": 5,
        "double_tap_w": 20,
        "concede_w": -10,
        "velocity_pb_w": 0.025,
        "velocity_bg_w": 1,
        "acel_ball_w": 2.5,
        "punish_low_touch_w": -0.5,  # increase later
        "team_spirit": 1,
        "cons_air_touches_w": 0.75,
        "jump_touch_w": 1,
        "wall_touch_w": 1,
    }

    def make_obs_builder():
        """Build a fresh observation builder."""
        return CoyoteObsBuilder(
            expanding=True,
            tick_skip=Constants_pinch.FRAME_SKIP,
            team_size=3,
            extra_boost_info=False,
        )

    def make_reward():
        """Build a fresh reward function with the pinch-run weights."""
        return ZeroSumReward(zero_sum=Constants_pinch.ZERO_SUM, **reward_weights)

    def make_action_parser():
        """Build a fresh action parser."""
        return CoyoteAction()

    rollout_gen = RedisRolloutGenerator(
        "Opti_pinch",
        redis,
        make_obs_builder,
        make_reward,
        make_action_parser,
        save_every=cfg.save_every,
        model_every=cfg.model_every,
        logger=logger,
        clear=False,
        stat_trackers=stat_trackers,
        # gamemodes=("1v1", "2v2", "3v3"),
        max_age=1,
    )

    # ------------------------------------------------------------------
    # Networks. Layers are created in the same order as a literal
    # Sequential(...) listing would create them (critic first, then actor),
    # so module indices and parameter initialisation order are unchanged.
    # ------------------------------------------------------------------
    obs_size = 222
    n_actions = 373
    hidden = 512

    # Critic: input layer, three hidden layers, scalar output.
    critic_layers = [Linear(obs_size, hidden), LeakyReLU()]
    for _ in range(3):
        critic_layers.extend((Linear(hidden, hidden), LeakyReLU()))
    critic_layers.append(Linear(hidden, 1))
    critic = Sequential(*critic_layers)

    # Actor: input layer, two hidden layers, one output per action.
    actor_layers = [Linear(obs_size, hidden), LeakyReLU()]
    for _ in range(2):
        actor_layers.extend((Linear(hidden, hidden), LeakyReLU()))
    actor_layers.append(Linear(hidden, n_actions))
    actor = DiscretePolicy(Sequential(*actor_layers), (n_actions,))

    # One optimizer, two parameter groups: index 0 = actor, index 1 = critic.
    param_groups = [
        {"params": actor.parameters(), "lr": cfg.actor_lr},
        {"params": critic.parameters(), "lr": cfg.critic_lr},
    ]
    optim = torch.optim.Adam(param_groups)

    agent = ActorCriticAgent(actor=actor, critic=critic, optimizer=optim)
    print(f"Gamma is: {gamma}")
    count_parameters(agent)

    alg = PPO(
        rollout_gen,
        agent,
        ent_coef=cfg.ent_coef,
        n_steps=cfg.n_steps,
        batch_size=cfg.batch_size,
        minibatch_size=cfg.minibatch_size,
        epochs=cfg.epochs,
        gamma=cfg.gamma,
        logger=logger,
        zero_grads_with_none=True,
        disable_gradient_logging=True,
    )

    # Resume from a saved checkpoint, then re-apply the configured learning
    # rates (presumably because load() restores the optimizer's saved rates --
    # TODO confirm against rocket_learn's PPO.load).
    alg.load("pinch_saves/Opti_1671075022.2357047/Opti_8470/checkpoint.pt")
    loaded_groups = alg.agent.optimizer.param_groups
    loaded_groups[0]["lr"] = cfg.actor_lr
    loaded_groups[1]["lr"] = cfg.critic_lr

    alg.run(iterations_per_save=cfg.save_every, save_dir="pinch_saves")