-
Notifications
You must be signed in to change notification settings - Fork 39
/
learner.py
137 lines (109 loc) · 4.99 KB
/
learner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""Implementations of algorithms for continuous control."""
from typing import Optional, Sequence, Tuple
import jax
import jax.numpy as jnp
import numpy as np
import optax
import policy
import value_net
from actor import update as awr_update_actor
from common import Batch, InfoDict, Model, PRNGKey
from critic import update_q, update_v
def target_update(critic: Model, target_critic: Model, tau: float) -> Model:
    """Polyak-average the online critic into the target critic.

    Every leaf of the parameter tree is updated as
    ``tau * online + (1 - tau) * target``.

    Args:
        critic: Online critic whose parameters are blended in.
        target_critic: Current target critic.
        tau: Interpolation weight given to the online parameters.

    Returns:
        ``target_critic`` with its ``params`` replaced by the blended tree.
    """
    # ``jax.tree_multimap`` was removed from JAX (0.3.16); ``tree_map``
    # accepts any number of structurally identical trees.
    new_target_params = jax.tree_util.tree_map(
        lambda p, tp: p * tau + tp * (1 - tau), critic.params,
        target_critic.params)
    return target_critic.replace(params=new_target_params)
@jax.jit
def _update_jit(
    rng: PRNGKey, actor: Model, critic: Model, value: Model,
    target_critic: Model, batch: Batch, discount: float, tau: float,
    expectile: float, temperature: float
) -> Tuple[PRNGKey, Model, Model, Model, Model, InfoDict]:
    """Run one jitted training step over all networks.

    Update order:
      1. The value network is updated against ``target_critic`` using
         ``expectile``.
      2. The actor is updated with a fresh PRNG sub-key, using
         ``target_critic`` and the *new* value network, with
         ``temperature``.
      3. The critic is updated against the *new* value network with
         ``discount``.
      4. The target critic is soft-updated toward the new critic with
         ``tau``.

    Returns:
        ``(rng, new_actor, new_critic, new_value, new_target_critic, info)``.
        ``info`` merges the per-update info dicts; on key collisions
        actor entries win over value entries, which win over critic entries.
    """
    new_value, value_info = update_v(target_critic, value, batch, expectile)

    key, rng = jax.random.split(rng)
    new_actor, actor_info = awr_update_actor(key, actor, target_critic,
                                             new_value, batch, temperature)

    new_critic, critic_info = update_q(critic, new_value, batch, discount)

    new_target_critic = target_update(new_critic, target_critic, tau)

    return rng, new_actor, new_critic, new_value, new_target_critic, {
        **critic_info,
        **value_info,
        **actor_info
    }
class Learner(object):
    """Actor-critic learner with an expectile-regressed value function.

    Holds four models: an actor (``policy.NormalTanhPolicy``), a critic
    (``value_net.DoubleCritic``), a state-value network
    (``value_net.ValueCritic``) and a target critic.

    The update combines expectile regression for the value network with an
    advantage-weighted (AWR-style) actor update. This matches Implicit
    Q-Learning (IQL, https://arxiv.org/abs/2110.06169) rather than Soft
    Actor-Critic.
    """

    def __init__(self,
                 seed: int,
                 observations: jnp.ndarray,
                 actions: jnp.ndarray,
                 actor_lr: float = 3e-4,
                 value_lr: float = 3e-4,
                 critic_lr: float = 3e-4,
                 hidden_dims: Sequence[int] = (256, 256),
                 discount: float = 0.99,
                 tau: float = 0.005,
                 expectile: float = 0.8,
                 temperature: float = 0.1,
                 dropout_rate: Optional[float] = None,
                 max_steps: Optional[int] = None,
                 opt_decay_schedule: str = "cosine"):
        """Build the actor, critic, value and target-critic models.

        Args:
            seed: Seed for the JAX PRNG used for initialisation and sampling.
            observations: Example observations, used here only as
                initialisation inputs for the networks.
            actions: Example actions; the last dimension sets the action
                dimensionality of the policy.
            actor_lr: Actor learning rate (initial value when cosine-decayed).
            value_lr: Adam learning rate for the value network.
            critic_lr: Adam learning rate for the critic.
            hidden_dims: Hidden layer sizes shared by all networks.
            discount: Discount factor passed to the critic update.
            tau: Polyak weight for the target-critic soft update.
            expectile: Expectile passed to the value update.
            temperature: Temperature passed to the actor (AWR) update.
            dropout_rate: Optional dropout rate for the policy network.
            max_steps: Number of training steps; the cosine decay horizon.
                Required when ``opt_decay_schedule == "cosine"``.
            opt_decay_schedule: ``"cosine"`` decays the actor learning rate
                over ``max_steps``; any other value uses constant-rate Adam.

        Raises:
            ValueError: If the cosine schedule is requested without
                ``max_steps``.
        """
        self.expectile = expectile
        self.tau = tau
        self.discount = discount
        self.temperature = temperature

        rng = jax.random.PRNGKey(seed)
        rng, actor_key, critic_key, value_key = jax.random.split(rng, 4)

        action_dim = actions.shape[-1]
        actor_def = policy.NormalTanhPolicy(hidden_dims,
                                            action_dim,
                                            log_std_scale=1e-3,
                                            log_std_min=-5.0,
                                            dropout_rate=dropout_rate,
                                            state_dependent_std=False,
                                            tanh_squash_distribution=False)

        if opt_decay_schedule == "cosine":
            # Fail fast with a clear message: otherwise optax fails with an
            # opaque error when comparing ``None`` decay steps.
            if max_steps is None:
                raise ValueError(
                    "max_steps must be provided when "
                    "opt_decay_schedule='cosine'")
            # The schedule starts at -actor_lr. scale_by_schedule multiplies
            # the Adam direction by it, which both anneals the step and
            # negates it into a descent step. optax.adam does that negation
            # internally, which is why the other branch passes +actor_lr.
            schedule_fn = optax.cosine_decay_schedule(-actor_lr, max_steps)
            optimiser = optax.chain(optax.scale_by_adam(),
                                    optax.scale_by_schedule(schedule_fn))
        else:
            optimiser = optax.adam(learning_rate=actor_lr)

        actor = Model.create(actor_def,
                             inputs=[actor_key, observations],
                             tx=optimiser)

        critic_def = value_net.DoubleCritic(hidden_dims)
        critic = Model.create(critic_def,
                              inputs=[critic_key, observations, actions],
                              tx=optax.adam(learning_rate=critic_lr))

        value_def = value_net.ValueCritic(hidden_dims)
        value = Model.create(value_def,
                             inputs=[value_key, observations],
                             tx=optax.adam(learning_rate=value_lr))

        # Reuses critic_key, so the target presumably starts with the same
        # parameters as the critic. It is created without an optimiser
        # because it is only ever moved by target_update.
        target_critic = Model.create(
            critic_def, inputs=[critic_key, observations, actions])

        self.actor = actor
        self.critic = critic
        self.value = value
        self.target_critic = target_critic
        self.rng = rng

    def sample_actions(self,
                       observations: np.ndarray,
                       temperature: float = 1.0) -> np.ndarray:
        """Sample actions from the current policy and advance ``self.rng``.

        Args:
            observations: Observations to act on.
            temperature: Sampling temperature forwarded to
                ``policy.sample_actions``. This is distinct from the AWR
                ``self.temperature`` used in training.

        Returns:
            A NumPy array of actions clipped to ``[-1, 1]``.
        """
        rng, actions = policy.sample_actions(self.rng, self.actor.apply_fn,
                                             self.actor.params, observations,
                                             temperature)
        self.rng = rng

        actions = np.asarray(actions)
        # The policy is built with tanh_squash_distribution=False, so samples
        # are presumably not bounded; clip to the [-1, 1] action range.
        return np.clip(actions, -1, 1)

    def update(self, batch: Batch) -> InfoDict:
        """Run one training step on ``batch`` and store the updated state.

        Replaces the RNG, actor, critic, value and target-critic models with
        the results of the jitted step.

        Returns:
            The merged info dict from the critic, value and actor updates.
        """
        new_rng, new_actor, new_critic, new_value, new_target_critic, info = _update_jit(
            self.rng, self.actor, self.critic, self.value, self.target_critic,
            batch, self.discount, self.tau, self.expectile, self.temperature)

        self.rng = new_rng
        self.actor = new_actor
        self.critic = new_critic
        self.value = new_value
        self.target_critic = new_target_critic

        return info