From aa3ddaeccec764fbd680ee23e64489ac07ce2edb Mon Sep 17 00:00:00 2001 From: phisn Date: Fri, 3 May 2024 22:13:27 +0200 Subject: [PATCH] Refactor SAC --- packages/learning/src/main copy.ts | 692 ------------------ packages/learning/src/main.ts | 14 +- .../soft-actor-critic/soft-actor-critic.ts | 199 +++-- 3 files changed, 109 insertions(+), 796 deletions(-) delete mode 100644 packages/learning/src/main copy.ts diff --git a/packages/learning/src/main copy.ts b/packages/learning/src/main copy.ts deleted file mode 100644 index f8feb496..00000000 --- a/packages/learning/src/main copy.ts +++ /dev/null @@ -1,692 +0,0 @@ -import * as tf from "@tensorflow/tfjs" -import { Buffer } from "buffer" -import { EntityWith, MessageCollector } from "runtime-framework" -import { WorldModel } from "runtime/proto/world" -import { LevelCapturedMessage } from "runtime/src/core/level-capture/level-captured-message" -import { RocketDeathMessage } from "runtime/src/core/rocket/rocket-death-message" -import { RuntimeComponents } from "runtime/src/core/runtime-components" -import { Runtime, newRuntime } from "runtime/src/runtime" -import { Environment, PPO } from "./ppo/ppo" - -export class PolyburnEnvironment implements Environment { - private runtime: Runtime - private currentRotation: number - private nearestLevel: EntityWith - - private captureMessages: MessageCollector - private deathMessages: MessageCollector - - private bestDistance: number - private maxTime = 60 * 30 - private remainingTime = 60 * 30 - - private worldModel: any - - private touchedFlag = false - - constructor() { - const worldStr = - "CqAJCgZOb3JtYWwSlQkKDw0fhZ3BFR+FB0Id2w/JQBItDR+FtsEVgZUDQh3bD8lAJQAAEMItpHBhQjWuR9lBPR+Fm0FFAAAAQE0AAABAEi0Nrkc/QRVt5wZCHdsPyUAlAAD4QC2kcBZCNezRjUI94KMwP0UAAABATQAAAEASLQ2k8B5CFX9qWEEd2w/JQCUAAP5BLaRwFkI17NG9Qj3gozA/RQAAAEBNAAAAQBItDeyRm0IVPzWGQR3bD8lAJQCAjUItSOHsQTX26AVDPYTr6cBFAAAAQE0AAABAEi0Nw0XwQhUcd4lAHTMeejwlAIDnQi2kcA5CNfboMkM9EK6nv0UAAABATQAAAEASLQ2PYhxDFT813EEd2w/JQCUAAM9CLaRwbEI1AMAmQz0fhbFBRQAAAEBNAAAAQBItDcM15UIVYxBJQh3bD8lAJQAAeUItUrijQjXs0fpCPZDCM0JFAAAAQE0AAABAEi0N9WiFQhXVeIhCHdsPyUAlw7WBQi3sUY9CNcO1kUI9AACBQkUAAABATQAAAEAaTgpMpHA9wXE9ukHAwP8AAEAAPYCA/wAAtIBDAAD/AIDFAEBAQP8AgMgAAICA/wBAxgC+oKD/AABGAMf///8AV0dxQry8+QBSQPHA////ABpOCkyuR3FBSOHKQf/++ABAxgAA//3wAAA/QMT/++AAQEoAQv/3wAAAPkBF/++AAADHAD//3gAAgMYAAP/vgAAAAIDD////AKxGCq////8AGpcCCpQC9qjBQpqZJEL///8AMNEAOv///wDqy9pH////AOzHNML///8AAMIAx////wAAQkDE////AABFAL3///8AAELAx////wCARgBF////AEBGgMb///8AwEYAv////wAgSQBF////AOBIgMP///8A4EjAR////wAARYDE////AAC+oMj///8AAD8AAP///wAAAODK////AGBJAEf///8AwMTASP///wAgSQAA////AEBEwMb///8AAEOAQ////wBASQC/////AAA+wEj///8AwEqAw////wAAvMBL////AODIAAD///8AQMoAQP///wAAPgBI////ACDIAAD///8AgMCARv///wCAyQAA////AEBFgMb///8AGqcCCqQCpHAZQqRwOcH///8AmFgAwP///wCAxwhU////AGDK4E3///8AwM1gyf///wAAv+DI////AKBLAMP///8AADpgyf///wCARgAA////AAA6YMv///8AQMgAAP///wAAvuDJ////AIBFYMj///8AQMyAwf///wAAtMDG////AGDLAL3///8AOMAMSP///wAkxgCu////AADC4Mj///8AAMNARv///wBgyQAA////AEDHgMP///8AwMeAQf///wAAAEBM////ACDJAAD///8AgMMAx////wAAyoBC////AAC9AMb///8AgMTARf///wCAwIDB////AABFAML///8AAMgANP///wBAxEBG////AADHAAD///8AAMFAyP///wBgyEDE////ABomCiSPQopCcT2DQv/AjQAAxAAA/+R0AAAAAMT/kwAAAEQAAP+bAAASEgoGTm9ybWFsEggKBk5vcm1hbA==" - this.worldModel = WorldModel.decode(Buffer.from(worldStr, "base64")) - - this.runtime = newRuntime(this.worldModel, "Normal") - - this.currentRotation = 0 - - const rocket = this.runtime.factoryContext.store.find("rocket", "rigidBody")[0] - const rocketPosition = rocket.components.rigidBody.translation() - - this.captureMessages 
= this.runtime.factoryContext.messageStore.collect("levelCaptured") - this.deathMessages = this.runtime.factoryContext.messageStore.collect("rocketDeath") - - this.nearestLevel = this.runtime.factoryContext.store - .find("level") - .filter(level => level.components.level.captured === false) - .sort( - (a, b) => - Math.abs(a.components.level.flag.x - rocketPosition.x) - - Math.abs(b.components.level.flag.y - rocketPosition.x), - )[0] - - const { distance } = this.state() - this.bestDistance = distance - } - - inputFromAction(action: number[]) { - const input = { - rotation: this.currentRotation + action[0], - thrust: action[1] > 0 ? true : false, - } - - return input - } - - step(action: number | number[]): [number[], number, boolean] { - if (typeof action === "number") { - throw new Error("Wrong action type") - } - - this.remainingTime-- - - const input = this.inputFromAction(action) - this.currentRotation += action[0] - - this.runtime.step(input) - - const { distance, observation, velMag, angDiff } = this.state() - - let newTouch = false - - if (this.nearestLevel.components.level.inCapture) { - if (!this.touchedFlag) { - newTouch = true - } - - this.touchedFlag = true - } - - const captureMessage = [...this.captureMessages].at(-1) - - if (captureMessage) { - const reward = 10000 + (this.maxTime - this.remainingTime) * 100 - return [observation, reward, true] - } - - const deathMessage = [...this.deathMessages].at(-1) - - if (deathMessage) { - const reward = -(velMag + angDiff) - return [observation, reward, true] - } - - const reward = Math.max(0, this.bestDistance - distance) - this.bestDistance = Math.min(this.bestDistance, distance) - - const done = this.remainingTime <= 0 - - return [observation, reward * 10 + (newTouch ? 100 : 0), done] - } - - state() { - const rocket = this.runtime.factoryContext.store.find("rocket", "rigidBody")[0] - - const rocketPosition = rocket.components.rigidBody.translation() - const rocketRotation = rocket.components.rigidBody.rotation() - const rocketVelocity = rocket.components.rigidBody.linvel() - - const dx = this.nearestLevel.components.level.flag.x - rocketPosition.x - const dy = this.nearestLevel.components.level.flag.y - rocketPosition.y - - const distanceToLevel = Math.sqrt(dx * dx + dy * dy) - - const angDiff = - (this.nearestLevel.components.level.flagRotation - - rocket.components.rigidBody.rotation()) % - (Math.PI * 2) - - const velMag = Math.sqrt( - rocketVelocity.x * rocketVelocity.x + rocketVelocity.y * rocketVelocity.y, - ) - - return { - distance: distanceToLevel, - observation: [ - this.nearestLevel.components.level.flag.x - rocketPosition.x, - this.nearestLevel.components.level.flag.y - rocketPosition.y, - rocketRotation, - rocketVelocity.x, - rocketVelocity.y, - ], - touched: this.touchedFlag, - angDiff, - velMag, - } - } - - reset(): number[] { - this.runtime = newRuntime(this.worldModel, "Normal") - - this.currentRotation = 0 - - const rocket = this.runtime.factoryContext.store.find("rocket", "rigidBody")[0] - const rocketPosition = rocket.components.rigidBody.translation() - - this.captureMessages = this.runtime.factoryContext.messageStore.collect("levelCaptured") - this.deathMessages = this.runtime.factoryContext.messageStore.collect("rocketDeath") - - this.nearestLevel = this.runtime.factoryContext.store - .find("level") - .filter(level => level.components.level.captured === false) - .sort( - (a, b) => - Math.abs(a.components.level.flag.x - rocketPosition.x) - - Math.abs(b.components.level.flag.y - rocketPosition.x), - )[0] - - 
const { distance, observation } = this.state() - - this.bestDistance = distance - this.remainingTime = this.maxTime - this.touchedFlag = false - - return observation - } -} - -export class CartPole implements Environment { - private gravity: number - private massCart: number - private massPole: number - private totalMass: number - private cartWidth: number - private cartHeight: number - private length: number - private poleMoment: number - private forceMag: number - private tau: number - - private xThreshold: number - private thetaThreshold: number - - private x: number = 0 - private xDot: number = 0 - private theta: number = 0 - private thetaDot: number = 0 - - /** - * Constructor of CartPole. - */ - constructor() { - // Constants that characterize the system. - this.gravity = 9.8 - this.massCart = 1.0 - this.massPole = 0.1 - this.totalMass = this.massCart + this.massPole - this.cartWidth = 0.2 - this.cartHeight = 0.1 - this.length = 0.5 - this.poleMoment = this.massPole * this.length - this.forceMag = 10.0 - this.tau = 0.02 // Seconds between state updates. - - // Threshold values, beyond which a simulation will be marked as failed. - this.xThreshold = 2.4 - this.thetaThreshold = (12 / 360) * 2 * Math.PI - - this.reset() - } - - /** - * Get current state as a tf.Tensor of shape [1, 4]. - */ - getStateTensor() { - return [this.x, this.xDot, this.theta, this.thetaDot] - } - - private i = 0 - private max = 0 - - /** - * Update the cart-pole system using an action. - * @param {number} action Only the sign of `action` matters. - * A value > 0 leads to a rightward force of a fixed magnitude. - * A value <= 0 leads to a leftward force of the same fixed magnitude. - */ - step(action: number | number[]): [number[], number, boolean] { - if (Array.isArray(action)) { - action = action[0] - } - - const force = action * this.forceMag - - const cosTheta = Math.cos(this.theta) - const sinTheta = Math.sin(this.theta) - - const temp = - (force + this.poleMoment * this.thetaDot * this.thetaDot * sinTheta) / this.totalMass - const thetaAcc = - (this.gravity * sinTheta - cosTheta * temp) / - (this.length * (4 / 3 - (this.massPole * cosTheta * cosTheta) / this.totalMass)) - const xAcc = temp - (this.poleMoment * thetaAcc * cosTheta) / this.totalMass - - // Update the four state variables, using Euler's method. - this.x += this.tau * this.xDot - this.xDot += this.tau * xAcc - this.theta += this.tau * this.thetaDot - this.thetaDot += this.tau * thetaAcc - - const reward = this.isDone() ? -100 : 1 - return [this.getStateTensor(), reward, this.isDone()] - } - - /** - * Set the state of the cart-pole system randomly. - */ - reset() { - this.i = 0 - // The control-theory state variables of the cart-pole system. - // Cart position, meters. - this.x = Math.random() - 0.5 - // Cart velocity. - this.xDot = (Math.random() - 0.5) * 1 - // Pole angle, radians. - this.theta = (Math.random() - 0.5) * 2 * ((6 / 360) * 2 * Math.PI) - // Pole angle velocity. - this.thetaDot = (Math.random() - 0.5) * 0.5 - - return this.getStateTensor() - } - - /** - * Determine whether this simulation is done. - * - * A simulation is done when `x` (position of the cart) goes out of bound - * or when `theta` (angle of the pole) goes out of bound. - * - * @returns {bool} Whether the simulation is done. 
- */ - isDone() { - return ( - this.x < -this.xThreshold || - this.x > this.xThreshold || - this.theta < -this.thetaThreshold || - this.theta > this.thetaThreshold - ) - } -} - -import "@tensorflow/tfjs-backend-webgl" -import "@tensorflow/tfjs-backend-webgpu" -import { SoftActorCritic } from "./soft-actor-critic/soft-actor-critic" - -if (false) { - tf.setBackend("cpu").then(() => { - const sac = new SoftActorCritic({ - mlpSpec: { - sizes: [32, 32], - activation: "relu", - outputActivation: "relu", - }, - - actionSize: 1, - observationSize: 4, - - maxEpisodeLength: 1000, - bufferSize: 1e6, - batchSize: 100, - updateAfter: 1000, - updateEvery: 50, - - learningRate: 1e-3, - alpha: 0.2, - gamma: 0.99, - polyak: 0.995, - }) - - sac.test() - - /* - const actor = new Actor(4, 2, { - sizes: [32, 32], - activation: "relu", - outputActivation: "relu", - }) - - actor.trainableWeights.forEach(w => { - w.write(tf.zeros(w.shape, w.dtype)) - }) - - /* - x = torch.tensor([[0.1, 0.2, 0.3, 0.4]], dtype=torch.float32) - x = actor(x, True) - - const x = tf.tensor2d([[0.1, 0.2, 0.3, 0.4]]) - const r = actor.apply(x, { deterministic: true }) as tf.Tensor[] - - console.log(r[0].dataSync()) - console.log(r[1].dataSync()) - */ - }) -} - -if (true) { - tf.setBackend("cpu").then(() => { - const env = new CartPole() - - const sac = new SoftActorCritic({ - mlpSpec: { - sizes: [32, 32], - activation: "relu", - outputActivation: "relu", - }, - - actionSize: 1, - observationSize: 4, - - maxEpisodeLength: 1000, - bufferSize: 1e6, - batchSize: 100, - updateAfter: 1000, - updateEvery: 50, - - learningRate: 1e-3, - alpha: 0.2, - gamma: 0.99, - polyak: 0.995, - }) - - function currentReward() { - const acc = [] - - for (let j = 0; j < 10; ++j) { - env.reset() - - let t = 0 - - while (!env.isDone() && t < 1000) { - env.step(sac.act(env.getStateTensor(), true)) - t++ - } - - acc.push(t) - } - - // average of top 10% lifetimes - acc.sort((a, b) => b - a) - - const best10avg = acc.slice(0, 10).reduce((a, b) => a + b, 0) / 10 - const worst10avg = acc.slice(-10).reduce((a, b) => a + b, 0) / 10 - const avg = acc.reduce((a, b) => a + b, 0) / acc.length - - return { avg, best10avg, worst10avg } - } - - let t = 0 - let updated = false - - function iteration() { - for (let i = 0; i < 16; ++i) { - t++ - - const observation = env.getStateTensor() - - let action: number[] - - if (t < 1000) { - action = [Math.random()] - } else { - action = sac.act(observation, false) - } - - const [nextObservation, reward, done] = env.step(action) - - const thisTimeUpdated = sac.observe({ - observation, - action, - reward, - nextObservation, - done, - }) - - updated ||= thisTimeUpdated - - if (done) { - if (updated) { - const { avg, best10avg, worst10avg } = currentReward() - - console.log(`Leaks: ${tf.memory().numTensors}`) - console.log(`10%: ${best10avg}, 90%: ${worst10avg}, avg: ${avg}`) - } - - env.reset() - - updated = false - } - } - - requestAnimationFrame(iteration) - } - - console.log("Start") - requestAnimationFrame(iteration) - - /* - const ppo = new PPO( - { - steps: 512, - epochs: 15, - policyLearningRate: 1e-3, - valueLearningRate: 1e-3, - clipRatio: 0.1, - targetKL: 0.01, - gamma: 0.99, - lambda: 0.95, - observationDimension: 4, - actionSpace: { - class: "Discrete", - len: 2, - }, - }, - env, - tf.sequential({ - layers: [ - tf.layers.dense({ - inputDim: 4, - units: 32, - activation: "relu", - }), - tf.layers.dense({ - units: 32, - activation: "relu", - }), - ], - }), - tf.sequential({ - layers: [ - tf.layers.dense({ - inputDim: 4, - units: 
32, - activation: "relu", - }), - tf.layers.dense({ - units: 32, - activation: "relu", - }), - ], - }), - ) - - function possibleLifetime() { - const acc = [] - - for (let j = 0; j < 25; ++j) { - env.reset() - - let t = 0 - - while (!env.isDone() && t < 1000) { - env.step(ppo.act(env.getStateTensor()) as number[]) - t++ - } - - acc.push(t) - } - - // average of top 10% lifetimes - acc.sort((a, b) => b - a) - - const best10avg = acc.slice(0, 10).reduce((a, b) => a + b, 0) / 10 - const worst10avg = acc.slice(-10).reduce((a, b) => a + b, 0) / 10 - const avg = acc.reduce((a, b) => a + b, 0) / acc.length - - return { avg, best10avg, worst10avg } - } - - let currentAverage = 0 - let i = 0 - - function iteration() { - ppo.learn(512 * i) - - const { avg, best10avg, worst10avg } = possibleLifetime() - - console.log(`Leaks: ${tf.memory().numTensors}`) - console.log(`10%: ${best10avg}, 90%: ${worst10avg}, avg: ${avg}`) - - if (avg > currentAverage) { - // await ppo.save() - currentAverage = avg - console.log("Saved") - } - - i++ - - requestAnimationFrame(iteration) - } - - console.log("Initial: ", possibleLifetime()) - - console.log("Start") - requestAnimationFrame(iteration) - - */ - }) -} - -if (false) { - tf.setBackend("cpu").then(() => { - const env = new PolyburnEnvironment() - - const inputDim = 5 - - const ppo = new PPO( - { - steps: 512, - epochs: 15, - policyLearningRate: 1e-3, - valueLearningRate: 1e-3, - clipRatio: 0.2, - targetKL: 0.01, - gamma: 0.99, - lambda: 0.95, - observationDimension: inputDim, - actionSpace: { - class: "Box", - len: 2, - low: -1, - high: 1, - }, - }, - env, - tf.sequential({ - layers: [ - tf.layers.dense({ - inputDim: inputDim, - units: 64, - activation: "relu", - }), - tf.layers.dense({ - units: 64, - activation: "relu", - }), - ], - }), - tf.sequential({ - layers: [ - tf.layers.dense({ - inputDim: inputDim, - units: 64, - activation: "relu", - }), - tf.layers.dense({ - units: 64, - activation: "relu", - }), - ], - }), - ) - - function possibleLifetime() { - let observation = env.reset() - - let totalReward = 0 - const inputs = [] - - while (true) { - const action = ppo.act(observation) - inputs.push(env.inputFromAction(action as number[])) - - const [nextObservation, reward, done] = env.step(action) - - totalReward += reward - observation = nextObservation - - if (done) { - break - } - } - - return { - totalReward, - touched: env.state().touched, - distance: env.state().distance, - inputs, - } - } - - let currentAverage = 0 - let i = 0 - - const previousTwenty: number[] = [] - - function iteration() { - ppo.learn(512 * i) - const info = possibleLifetime() - - console.log( - `Reward ${i}: reward(${info.totalReward}), distance(${info.distance}), touched(${info.touched})`, - ) - - if (info.totalReward > currentAverage && previousTwenty.length === 20) { - currentAverage = info.totalReward - console.log("Saved") - ppo.save() - } - - if (previousTwenty.length === 20) { - previousTwenty.shift() - } - - previousTwenty.push(info.totalReward) - - const avgPreviousTwenty = - previousTwenty.reduce((a, b) => a + b, 0) / previousTwenty.length - - ++i - - if ( - avgPreviousTwenty < 50 && - avgPreviousTwenty < Math.max(currentAverage, 10) * 0.5 && - previousTwenty.length === 20 - ) { - console.log("Restoring") - - ppo.restore().finally(() => { - requestAnimationFrame(iteration) - }) - } else { - requestAnimationFrame(iteration) - } - } - - ppo.restore().finally(() => { - const { totalReward, inputs } = possibleLifetime() - currentAverage = totalReward - - 
console.log(JSON.stringify(inputs)) - - console.log("Start with: ", currentAverage) - requestAnimationFrame(iteration) - }) - }) -} diff --git a/packages/learning/src/main.ts b/packages/learning/src/main.ts index 71040b7f..f2399e39 100644 --- a/packages/learning/src/main.ts +++ b/packages/learning/src/main.ts @@ -307,7 +307,7 @@ import "@tensorflow/tfjs-backend-webgl" import "@tensorflow/tfjs-backend-webgpu" import { SoftActorCritic } from "./soft-actor-critic/soft-actor-critic" -if (false) { +if (true) { tf.setBackend("cpu").then(() => { const env = new CartPole() @@ -333,7 +333,15 @@ if (false) { polyak: 0.995, }) - sac.learn(new CartPole()) + sac.learn(new CartPole(), { + epochs: 1000, + stepsPerEpoch: 100, + startAfter: 1000, + renderEvery: 1, + onEpochFinish: info => { + console.log("Length: ", info.episodeLength) + }, + }) function iteration() { requestAnimationFrame(iteration) @@ -424,7 +432,7 @@ if (false) { }) } -if (true) { +if (false) { tf.setBackend("cpu").then(() => { const env = new CartPole() diff --git a/packages/learning/src/soft-actor-critic/soft-actor-critic.ts b/packages/learning/src/soft-actor-critic/soft-actor-critic.ts index 61ef71b7..676ebaab 100644 --- a/packages/learning/src/soft-actor-critic/soft-actor-critic.ts +++ b/packages/learning/src/soft-actor-critic/soft-actor-critic.ts @@ -23,70 +23,31 @@ export interface Config { polyak: number } -/* -class ReplayBuffer: - """ - A simple FIFO experience replay buffer for SAC agents. - """ - - def __init__(self, obs_dim, act_dim, size): - self.obs_buf = np.zeros(core.combined_shape(size, obs_dim), dtype=np.float32) - self.obs2_buf = np.zeros(core.combined_shape(size, obs_dim), dtype=np.float32) - self.act_buf = np.zeros(core.combined_shape(size, act_dim), dtype=np.float32) - self.rew_buf = np.zeros(size, dtype=np.float32) - self.done_buf = np.zeros(size, dtype=np.float32) - self.ptr, self.size, self.max_size = 0, 0, size - - def store(self, obs, act, rew, next_obs, done): - self.obs_buf[self.ptr] = obs - self.obs2_buf[self.ptr] = next_obs - self.act_buf[self.ptr] = act - self.rew_buf[self.ptr] = rew - self.done_buf[self.ptr] = done - self.ptr = (self.ptr+1) % self.max_size - self.size = min(self.size+1, self.max_size) - - def sample_batch(self, batch_size=32): - idxs = np.random.randint(0, self.size, size=batch_size) - batch = dict(obs=self.obs_buf[idxs], - obs2=self.obs2_buf[idxs], - act=self.act_buf[idxs], - rew=self.rew_buf[idxs], - done=self.done_buf[idxs]) - - batches.append( - dict( - observation=self.obs_buf[idxs], - nextObservation=self.obs2_buf[idxs], - action=self.act_buf[idxs], - reward=self.rew_buf[idxs], - done=self.done_buf[idxs] - ) - ) - - return {k: torch.as_tensor(v, dtype=torch.float32) for k,v in batch.items()} -*/ - -export interface LearningProgress { +export interface SacLearningInfo { episodeReturn: number episodeLength: number - maxEpisodeReturn: number - maxEpisodeLength: number + lastUpdateInfo: SacUpdateInfo +} +export interface SacUpdateInfo { lossQ: number lossPolicy: number } export interface SacLearningConfig { + epochs: number + stepsPerEpoch: number + startAfter: number renderEvery: number - onStart?: () => void - onStep?: () => void - onUpdate?: () => void - onEnd?: () => void - onEpisodeEnd?: () => void + onEpochFinish?: (progress: SacLearningInfo) => void + + // - guaranteed to reset environment after call + // can be used if you want to evaluate the model after each episode manually + // has the advantage of not breaking episodes in the middle + onFirstEpisodeInEpoch?: (progress: 
SacLearningInfo) => void } export class SoftActorCritic { @@ -101,6 +62,7 @@ export class SoftActorCritic { private targetQ2: tf.LayersModel private qOptimizer: tf.Optimizer + private deterministic: boolean private episodeReturn: number private episodeLength: number @@ -116,6 +78,7 @@ export class SoftActorCritic { this.targetQ1 = critic(config.observationSize, config.actionSize, config.mlpSpec) this.targetQ2 = critic(config.observationSize, config.actionSize, config.mlpSpec) + this.deterministic = false this.episodeReturn = 0 this.episodeLength = 0 this.t = 0 @@ -124,20 +87,23 @@ export class SoftActorCritic { this.qOptimizer = tf.train.adam(config.learningRate) } - async learn(env: Environment, sacConfig: SacLearningConfig) { - const buffer = new ReplayBuffer(this.config.bufferSize, this.config.batchSize) - + async learn(env: Environment, learningConfig: SacLearningConfig) { let observation = env.reset() let episodeLength = 0 let episodeReturn = 0 - let maxEpisodeReturn = 0 - let maxEpisodeLength = 0 + let maxEpochEpisodeReturn = 0 + let maxEpochEpisodeLength = 0 - for (let t = 0; ; ++t) { + let hasEpisodeAlreadyEndedInEpoch = false + let lastUpdateInfoInEpoch = { lossQ: 0, lossPolicy: 0 } + + const steps = learningConfig.epochs * learningConfig.stepsPerEpoch + this.t + + while (this.t < steps) { let action: number[] - if (t > startFrom) { + if (this.t > learningConfig.startAfter) { action = this.act(observation, false) } else { action = [Math.random() * 2 - 1] @@ -145,8 +111,12 @@ export class SoftActorCritic { const [nextObservation, reward, done] = env.step(action) episodeLength += 1 + episodeReturn += reward - buffer.store({ + maxEpochEpisodeLength = Math.max(maxEpochEpisodeLength, episodeLength) + maxEpochEpisodeReturn = Math.max(maxEpochEpisodeReturn, episodeReturn) + + const updateInfo = this.observe({ observation, action, reward, @@ -154,27 +124,42 @@ export class SoftActorCritic { done: episodeLength === this.config.maxEpisodeLength ? false : done, }) + lastUpdateInfoInEpoch = updateInfo ?? 
lastUpdateInfoInEpoch observation = nextObservation if (done || episodeLength === this.config.maxEpisodeLength) { + if (hasEpisodeAlreadyEndedInEpoch === false) { + if (learningConfig.onFirstEpisodeInEpoch) { + learningConfig.onFirstEpisodeInEpoch?.({ + episodeReturn, + episodeLength, + lastUpdateInfo: lastUpdateInfoInEpoch, + }) + } + + hasEpisodeAlreadyEndedInEpoch = true + } + observation = env.reset() episodeLength = 0 } - if (t >= this.config.updateAfter && t % this.config.updateEvery === 0) { - for (let i = 0; i < this.config.updateEvery; i++) { - const batch = buffer.sample() - this.update(batch, false) - - tf.dispose(batch.observation) - tf.dispose(batch.action) - tf.dispose(batch.reward) - tf.dispose(batch.nextObservation) - tf.dispose(batch.done) + if (this.t % learningConfig.stepsPerEpoch === 0) { + if (learningConfig.onEpochFinish) { + learningConfig.onEpochFinish({ + episodeReturn: maxEpochEpisodeReturn, + episodeLength: maxEpochEpisodeLength, + lastUpdateInfo: lastUpdateInfoInEpoch, + }) } + + maxEpochEpisodeReturn = 0 + maxEpochEpisodeLength = 0 + + hasEpisodeAlreadyEndedInEpoch = false } - if (renderEvery > 0 && t % renderEvery === 0) { + if (learningConfig.renderEvery > 0 && this.t % learningConfig.renderEvery === 0) { await tf.nextFrame() } } @@ -190,33 +175,36 @@ export class SoftActorCritic { }) } - observe(experience: Experience) { - this.episodeReturn += experience.reward - this.episodeLength += 1 + observe(experience: Experience): SacUpdateInfo | undefined { this.t += 1 + this.replayBuffer.store(experience) - const done = this.episodeLength < this.config.maxEpisodeLength && experience.done + if (this.t > this.config.updateAfter && this.t % this.config.updateEvery === 0) { + let averageLossQ = 0 + let averageLossPolicy = 0 - this.replayBuffer.store({ - ...experience, - }) + for (let i = 0; i < this.config.updateEvery; i++) { + const batch = this.replayBuffer.sample() + const updateInfo = this.update(batch) - if (done || this.episodeLength === this.config.maxEpisodeLength) { - this.episodeReturn = 0 - this.episodeLength = 0 - } + tf.dispose(batch.observation) + tf.dispose(batch.action) + tf.dispose(batch.reward) + tf.dispose(batch.nextObservation) + tf.dispose(batch.done) - if (this.t > this.config.updateAfter && this.t % this.config.updateEvery === 0) { - for (let i = 0; i < this.config.updateEvery; i++) { - tf.tidy(() => { - this.update(this.replayBuffer.sample(), i === this.config.updateEvery - 1) - }) + averageLossQ += updateInfo.lossQ + averageLossPolicy += updateInfo.lossPolicy } - return true - } + averageLossQ /= this.config.updateEvery + averageLossPolicy /= this.config.updateEvery - return false + return { + lossQ: averageLossQ, + lossPolicy: averageLossPolicy, + } + } } predictQ1(observation: tf.Tensor2D, action: tf.Tensor2D) { @@ -233,9 +221,7 @@ export class SoftActorCritic { ) as tf.Tensor } - private deterministic = false - - update(batch: ExperienceTensor, last: boolean = false) { + update(batch: ExperienceTensor): SacUpdateInfo { const lossQ = () => { const q1 = this.predictQ1(batch.observation, batch.action) const q2 = this.predictQ2(batch.observation, batch.action) @@ -248,9 +234,12 @@ export class SoftActorCritic { return tf.add(errorQ1, errorQ2) as tf.Scalar } - const gradsQ = tf.variableGrads(lossQ, this.q1.getWeights().concat(this.q2.getWeights())) - this.qOptimizer.applyGradients(gradsQ.grads) - tf.dispose(gradsQ) + const { value: lossValueQ, grads: gradsQ } = tf.variableGrads( + lossQ, + this.q1.getWeights().concat(this.q2.getWeights()) as 
tf.Variable[], + ) + + this.qOptimizer.applyGradients(gradsQ) const lossPolicy = () => { const [pi, logpPi] = this.policy.apply(batch.observation, { @@ -265,9 +254,12 @@ export class SoftActorCritic { return tf.mean(logpPi.mul(this.config.alpha).sub(minPiQ)) as tf.Scalar } - const gradsPolicy = tf.variableGrads(lossPolicy, this.policy.getWeights()) - this.policyOptimizer.applyGradients(gradsPolicy.grads) - tf.dispose(gradsPolicy) + const { value: lossValuePolicy, grads: gradsPolicy } = tf.variableGrads( + lossPolicy, + this.policy.getWeights() as tf.Variable[], + ) + + this.policyOptimizer.applyGradients(gradsPolicy) // do polyak averaging const tau = tf.scalar(this.config.polyak) @@ -283,6 +275,11 @@ export class SoftActorCritic { this.targetQ1.setWeights(updateTargetQ1) this.targetQ2.setWeights(updateTargetQ2) + + return { + lossQ: lossValueQ.arraySync(), + lossPolicy: lossValuePolicy.arraySync(), + } } private computeBackup(batch: ExperienceTensor) { @@ -291,12 +288,12 @@ export class SoftActorCritic { }) as tf.Tensor[] const targetQ1 = tf.squeeze( - this.targetQ1.apply(tf.concat([batch.nextObservation, action], 1)), + this.targetQ1.apply(tf.concat([batch.nextObservation, action], 1)) as tf.Tensor2D, [-1], ) as tf.Tensor const targetQ2 = tf.squeeze( - this.targetQ2.apply(tf.concat([batch.nextObservation, action], 1)), + this.targetQ2.apply(tf.concat([batch.nextObservation, action], 1)) as tf.Tensor2D, [-1], ) as tf.Tensor
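A minimal usage sketch (TypeScript) of the refactored learn() entry point, based on the SacLearningConfig and SacLearningInfo interfaces added above; the Environment import path, the hyperparameter values, and the callback bodies are illustrative assumptions, not part of the change.

import * as tf from "@tensorflow/tfjs"
import { Environment } from "./ppo/ppo" // Environment interface consumed by learn(); path assumed from main.ts
import { SoftActorCritic } from "./soft-actor-critic/soft-actor-critic"

async function train(env: Environment) {
    await tf.setBackend("cpu")

    const sac = new SoftActorCritic({
        mlpSpec: { sizes: [32, 32], activation: "relu", outputActivation: "relu" },
        actionSize: 1,
        observationSize: 4,
        maxEpisodeLength: 1000,
        bufferSize: 1e6,
        batchSize: 100,
        updateAfter: 1000,
        updateEvery: 50,
        learningRate: 1e-3,
        alpha: 0.2,
        gamma: 0.99,
        polyak: 0.995,
    })

    await sac.learn(env, {
        epochs: 100,
        stepsPerEpoch: 1000,
        startAfter: 1000, // steps of uniform random actions before the policy is queried
        renderEvery: 1, // await tf.nextFrame() this often so the browser stays responsive
        onEpochFinish: info => {
            // episodeReturn / episodeLength report the best episode seen during the epoch
            console.log(
                `return ${info.episodeReturn}, length ${info.episodeLength}, ` +
                    `lossQ ${info.lastUpdateInfo.lossQ}, lossPi ${info.lastUpdateInfo.lossPolicy}`,
            )
        },
        // fires once per epoch, after the first episode of that epoch completes,
        // so manual evaluation never interrupts an episode in the middle
        onFirstEpisodeInEpoch: info => {
            console.log("first episode this epoch:", info.episodeLength)
        },
    })
}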
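For reviewing update() and computeBackup(): a short sketch of the quantities they compute, written against the standard SAC formulation this implementation appears to follow (the same one excerpted in the removed Python comment); exact tensor shapes and the squashed-Gaussian log-probabilities are assumed to come from the actor and critic modules.

import * as tf from "@tensorflow/tfjs"

// Entropy-regularized Bellman target, used as the regression target for both critics:
//   y = r + gamma * (1 - done) * (min(Q1targ(s', a'), Q2targ(s', a')) - alpha * log pi(a' | s'))
// where a' is sampled from the current policy at the next observation s'.
function sacBackup(
    reward: tf.Tensor,
    done: tf.Tensor,
    minTargetQ: tf.Tensor,
    logpNextAction: tf.Tensor,
    gamma: number,
    alpha: number,
): tf.Tensor {
    const notDone = tf.scalar(1).sub(done)
    return reward.add(notDone.mul(gamma).mul(minTargetQ.sub(logpNextAction.mul(alpha))))
}

// The policy loss in update() is then mean(alpha * logpPi - min(Q1(s, pi), Q2(s, pi))),
// i.e. maximize expected Q plus policy entropy.

// Polyak averaging of the target critics after each gradient step:
//   thetaTarget <- polyak * thetaTarget + (1 - polyak) * theta
function polyakUpdate(targetWeight: tf.Tensor, onlineWeight: tf.Tensor, polyak: number): tf.Tensor {
    return targetWeight.mul(polyak).add(onlineWeight.mul(1 - polyak))
}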