diff --git a/packages/learning/src/main copy.ts b/packages/learning/src/main copy.ts new file mode 100644 index 00000000..f8feb496 --- /dev/null +++ b/packages/learning/src/main copy.ts @@ -0,0 +1,692 @@ +import * as tf from "@tensorflow/tfjs" +import { Buffer } from "buffer" +import { EntityWith, MessageCollector } from "runtime-framework" +import { WorldModel } from "runtime/proto/world" +import { LevelCapturedMessage } from "runtime/src/core/level-capture/level-captured-message" +import { RocketDeathMessage } from "runtime/src/core/rocket/rocket-death-message" +import { RuntimeComponents } from "runtime/src/core/runtime-components" +import { Runtime, newRuntime } from "runtime/src/runtime" +import { Environment, PPO } from "./ppo/ppo" + +export class PolyburnEnvironment implements Environment { + private runtime: Runtime + private currentRotation: number + private nearestLevel: EntityWith + + private captureMessages: MessageCollector + private deathMessages: MessageCollector + + private bestDistance: number + private maxTime = 60 * 30 + private remainingTime = 60 * 30 + + private worldModel: any + + private touchedFlag = false + + constructor() { + const worldStr = + "CqAJCgZOb3JtYWwSlQkKDw0fhZ3BFR+FB0Id2w/JQBItDR+FtsEVgZUDQh3bD8lAJQAAEMItpHBhQjWuR9lBPR+Fm0FFAAAAQE0AAABAEi0Nrkc/QRVt5wZCHdsPyUAlAAD4QC2kcBZCNezRjUI94KMwP0UAAABATQAAAEASLQ2k8B5CFX9qWEEd2w/JQCUAAP5BLaRwFkI17NG9Qj3gozA/RQAAAEBNAAAAQBItDeyRm0IVPzWGQR3bD8lAJQCAjUItSOHsQTX26AVDPYTr6cBFAAAAQE0AAABAEi0Nw0XwQhUcd4lAHTMeejwlAIDnQi2kcA5CNfboMkM9EK6nv0UAAABATQAAAEASLQ2PYhxDFT813EEd2w/JQCUAAM9CLaRwbEI1AMAmQz0fhbFBRQAAAEBNAAAAQBItDcM15UIVYxBJQh3bD8lAJQAAeUItUrijQjXs0fpCPZDCM0JFAAAAQE0AAABAEi0N9WiFQhXVeIhCHdsPyUAlw7WBQi3sUY9CNcO1kUI9AACBQkUAAABATQAAAEAaTgpMpHA9wXE9ukHAwP8AAEAAPYCA/wAAtIBDAAD/AIDFAEBAQP8AgMgAAICA/wBAxgC+oKD/AABGAMf///8AV0dxQry8+QBSQPHA////ABpOCkyuR3FBSOHKQf/++ABAxgAA//3wAAA/QMT/++AAQEoAQv/3wAAAPkBF/++AAADHAD//3gAAgMYAAP/vgAAAAIDD////AKxGCq////8AGpcCCpQC9qjBQpqZJEL///8AMNEAOv///wDqy9pH////AOzHNML///8AAMIAx////wAAQkDE////AABFAL3///8AAELAx////wCARgBF////AEBGgMb///8AwEYAv////wAgSQBF////AOBIgMP///8A4EjAR////wAARYDE////AAC+oMj///8AAD8AAP///wAAAODK////AGBJAEf///8AwMTASP///wAgSQAA////AEBEwMb///8AAEOAQ////wBASQC/////AAA+wEj///8AwEqAw////wAAvMBL////AODIAAD///8AQMoAQP///wAAPgBI////ACDIAAD///8AgMCARv///wCAyQAA////AEBFgMb///8AGqcCCqQCpHAZQqRwOcH///8AmFgAwP///wCAxwhU////AGDK4E3///8AwM1gyf///wAAv+DI////AKBLAMP///8AADpgyf///wCARgAA////AAA6YMv///8AQMgAAP///wAAvuDJ////AIBFYMj///8AQMyAwf///wAAtMDG////AGDLAL3///8AOMAMSP///wAkxgCu////AADC4Mj///8AAMNARv///wBgyQAA////AEDHgMP///8AwMeAQf///wAAAEBM////ACDJAAD///8AgMMAx////wAAyoBC////AAC9AMb///8AgMTARf///wCAwIDB////AABFAML///8AAMgANP///wBAxEBG////AADHAAD///8AAMFAyP///wBgyEDE////ABomCiSPQopCcT2DQv/AjQAAxAAA/+R0AAAAAMT/kwAAAEQAAP+bAAASEgoGTm9ybWFsEggKBk5vcm1hbA==" + this.worldModel = WorldModel.decode(Buffer.from(worldStr, "base64")) + + this.runtime = newRuntime(this.worldModel, "Normal") + + this.currentRotation = 0 + + const rocket = this.runtime.factoryContext.store.find("rocket", "rigidBody")[0] + const rocketPosition = rocket.components.rigidBody.translation() + + this.captureMessages = this.runtime.factoryContext.messageStore.collect("levelCaptured") + this.deathMessages = this.runtime.factoryContext.messageStore.collect("rocketDeath") + + this.nearestLevel = this.runtime.factoryContext.store + .find("level") + .filter(level => level.components.level.captured === false) + .sort( + (a, b) => + Math.abs(a.components.level.flag.x - rocketPosition.x) - + Math.abs(b.components.level.flag.y - 
rocketPosition.x), + )[0] + + const { distance } = this.state() + this.bestDistance = distance + } + + inputFromAction(action: number[]) { + const input = { + rotation: this.currentRotation + action[0], + thrust: action[1] > 0 ? true : false, + } + + return input + } + + step(action: number | number[]): [number[], number, boolean] { + if (typeof action === "number") { + throw new Error("Wrong action type") + } + + this.remainingTime-- + + const input = this.inputFromAction(action) + this.currentRotation += action[0] + + this.runtime.step(input) + + const { distance, observation, velMag, angDiff } = this.state() + + let newTouch = false + + if (this.nearestLevel.components.level.inCapture) { + if (!this.touchedFlag) { + newTouch = true + } + + this.touchedFlag = true + } + + const captureMessage = [...this.captureMessages].at(-1) + + if (captureMessage) { + const reward = 10000 + (this.maxTime - this.remainingTime) * 100 + return [observation, reward, true] + } + + const deathMessage = [...this.deathMessages].at(-1) + + if (deathMessage) { + const reward = -(velMag + angDiff) + return [observation, reward, true] + } + + const reward = Math.max(0, this.bestDistance - distance) + this.bestDistance = Math.min(this.bestDistance, distance) + + const done = this.remainingTime <= 0 + + return [observation, reward * 10 + (newTouch ? 100 : 0), done] + } + + state() { + const rocket = this.runtime.factoryContext.store.find("rocket", "rigidBody")[0] + + const rocketPosition = rocket.components.rigidBody.translation() + const rocketRotation = rocket.components.rigidBody.rotation() + const rocketVelocity = rocket.components.rigidBody.linvel() + + const dx = this.nearestLevel.components.level.flag.x - rocketPosition.x + const dy = this.nearestLevel.components.level.flag.y - rocketPosition.y + + const distanceToLevel = Math.sqrt(dx * dx + dy * dy) + + const angDiff = + (this.nearestLevel.components.level.flagRotation - + rocket.components.rigidBody.rotation()) % + (Math.PI * 2) + + const velMag = Math.sqrt( + rocketVelocity.x * rocketVelocity.x + rocketVelocity.y * rocketVelocity.y, + ) + + return { + distance: distanceToLevel, + observation: [ + this.nearestLevel.components.level.flag.x - rocketPosition.x, + this.nearestLevel.components.level.flag.y - rocketPosition.y, + rocketRotation, + rocketVelocity.x, + rocketVelocity.y, + ], + touched: this.touchedFlag, + angDiff, + velMag, + } + } + + reset(): number[] { + this.runtime = newRuntime(this.worldModel, "Normal") + + this.currentRotation = 0 + + const rocket = this.runtime.factoryContext.store.find("rocket", "rigidBody")[0] + const rocketPosition = rocket.components.rigidBody.translation() + + this.captureMessages = this.runtime.factoryContext.messageStore.collect("levelCaptured") + this.deathMessages = this.runtime.factoryContext.messageStore.collect("rocketDeath") + + this.nearestLevel = this.runtime.factoryContext.store + .find("level") + .filter(level => level.components.level.captured === false) + .sort( + (a, b) => + Math.abs(a.components.level.flag.x - rocketPosition.x) - + Math.abs(b.components.level.flag.y - rocketPosition.x), + )[0] + + const { distance, observation } = this.state() + + this.bestDistance = distance + this.remainingTime = this.maxTime + this.touchedFlag = false + + return observation + } +} + +export class CartPole implements Environment { + private gravity: number + private massCart: number + private massPole: number + private totalMass: number + private cartWidth: number + private cartHeight: number + private length: number 
+ private poleMoment: number + private forceMag: number + private tau: number + + private xThreshold: number + private thetaThreshold: number + + private x: number = 0 + private xDot: number = 0 + private theta: number = 0 + private thetaDot: number = 0 + + /** + * Constructor of CartPole. + */ + constructor() { + // Constants that characterize the system. + this.gravity = 9.8 + this.massCart = 1.0 + this.massPole = 0.1 + this.totalMass = this.massCart + this.massPole + this.cartWidth = 0.2 + this.cartHeight = 0.1 + this.length = 0.5 + this.poleMoment = this.massPole * this.length + this.forceMag = 10.0 + this.tau = 0.02 // Seconds between state updates. + + // Threshold values, beyond which a simulation will be marked as failed. + this.xThreshold = 2.4 + this.thetaThreshold = (12 / 360) * 2 * Math.PI + + this.reset() + } + + /** + * Get current state as a tf.Tensor of shape [1, 4]. + */ + getStateTensor() { + return [this.x, this.xDot, this.theta, this.thetaDot] + } + + private i = 0 + private max = 0 + + /** + * Update the cart-pole system using an action. + * @param {number} action Only the sign of `action` matters. + * A value > 0 leads to a rightward force of a fixed magnitude. + * A value <= 0 leads to a leftward force of the same fixed magnitude. + */ + step(action: number | number[]): [number[], number, boolean] { + if (Array.isArray(action)) { + action = action[0] + } + + const force = action * this.forceMag + + const cosTheta = Math.cos(this.theta) + const sinTheta = Math.sin(this.theta) + + const temp = + (force + this.poleMoment * this.thetaDot * this.thetaDot * sinTheta) / this.totalMass + const thetaAcc = + (this.gravity * sinTheta - cosTheta * temp) / + (this.length * (4 / 3 - (this.massPole * cosTheta * cosTheta) / this.totalMass)) + const xAcc = temp - (this.poleMoment * thetaAcc * cosTheta) / this.totalMass + + // Update the four state variables, using Euler's method. + this.x += this.tau * this.xDot + this.xDot += this.tau * xAcc + this.theta += this.tau * this.thetaDot + this.thetaDot += this.tau * thetaAcc + + const reward = this.isDone() ? -100 : 1 + return [this.getStateTensor(), reward, this.isDone()] + } + + /** + * Set the state of the cart-pole system randomly. + */ + reset() { + this.i = 0 + // The control-theory state variables of the cart-pole system. + // Cart position, meters. + this.x = Math.random() - 0.5 + // Cart velocity. + this.xDot = (Math.random() - 0.5) * 1 + // Pole angle, radians. + this.theta = (Math.random() - 0.5) * 2 * ((6 / 360) * 2 * Math.PI) + // Pole angle velocity. + this.thetaDot = (Math.random() - 0.5) * 0.5 + + return this.getStateTensor() + } + + /** + * Determine whether this simulation is done. + * + * A simulation is done when `x` (position of the cart) goes out of bound + * or when `theta` (angle of the pole) goes out of bound. + * + * @returns {bool} Whether the simulation is done. 
+ */ + isDone() { + return ( + this.x < -this.xThreshold || + this.x > this.xThreshold || + this.theta < -this.thetaThreshold || + this.theta > this.thetaThreshold + ) + } +} + +import "@tensorflow/tfjs-backend-webgl" +import "@tensorflow/tfjs-backend-webgpu" +import { SoftActorCritic } from "./soft-actor-critic/soft-actor-critic" + +if (false) { + tf.setBackend("cpu").then(() => { + const sac = new SoftActorCritic({ + mlpSpec: { + sizes: [32, 32], + activation: "relu", + outputActivation: "relu", + }, + + actionSize: 1, + observationSize: 4, + + maxEpisodeLength: 1000, + bufferSize: 1e6, + batchSize: 100, + updateAfter: 1000, + updateEvery: 50, + + learningRate: 1e-3, + alpha: 0.2, + gamma: 0.99, + polyak: 0.995, + }) + + sac.test() + + /* + const actor = new Actor(4, 2, { + sizes: [32, 32], + activation: "relu", + outputActivation: "relu", + }) + + actor.trainableWeights.forEach(w => { + w.write(tf.zeros(w.shape, w.dtype)) + }) + + /* + x = torch.tensor([[0.1, 0.2, 0.3, 0.4]], dtype=torch.float32) + x = actor(x, True) + + const x = tf.tensor2d([[0.1, 0.2, 0.3, 0.4]]) + const r = actor.apply(x, { deterministic: true }) as tf.Tensor[] + + console.log(r[0].dataSync()) + console.log(r[1].dataSync()) + */ + }) +} + +if (true) { + tf.setBackend("cpu").then(() => { + const env = new CartPole() + + const sac = new SoftActorCritic({ + mlpSpec: { + sizes: [32, 32], + activation: "relu", + outputActivation: "relu", + }, + + actionSize: 1, + observationSize: 4, + + maxEpisodeLength: 1000, + bufferSize: 1e6, + batchSize: 100, + updateAfter: 1000, + updateEvery: 50, + + learningRate: 1e-3, + alpha: 0.2, + gamma: 0.99, + polyak: 0.995, + }) + + function currentReward() { + const acc = [] + + for (let j = 0; j < 10; ++j) { + env.reset() + + let t = 0 + + while (!env.isDone() && t < 1000) { + env.step(sac.act(env.getStateTensor(), true)) + t++ + } + + acc.push(t) + } + + // average of top 10% lifetimes + acc.sort((a, b) => b - a) + + const best10avg = acc.slice(0, 10).reduce((a, b) => a + b, 0) / 10 + const worst10avg = acc.slice(-10).reduce((a, b) => a + b, 0) / 10 + const avg = acc.reduce((a, b) => a + b, 0) / acc.length + + return { avg, best10avg, worst10avg } + } + + let t = 0 + let updated = false + + function iteration() { + for (let i = 0; i < 16; ++i) { + t++ + + const observation = env.getStateTensor() + + let action: number[] + + if (t < 1000) { + action = [Math.random()] + } else { + action = sac.act(observation, false) + } + + const [nextObservation, reward, done] = env.step(action) + + const thisTimeUpdated = sac.observe({ + observation, + action, + reward, + nextObservation, + done, + }) + + updated ||= thisTimeUpdated + + if (done) { + if (updated) { + const { avg, best10avg, worst10avg } = currentReward() + + console.log(`Leaks: ${tf.memory().numTensors}`) + console.log(`10%: ${best10avg}, 90%: ${worst10avg}, avg: ${avg}`) + } + + env.reset() + + updated = false + } + } + + requestAnimationFrame(iteration) + } + + console.log("Start") + requestAnimationFrame(iteration) + + /* + const ppo = new PPO( + { + steps: 512, + epochs: 15, + policyLearningRate: 1e-3, + valueLearningRate: 1e-3, + clipRatio: 0.1, + targetKL: 0.01, + gamma: 0.99, + lambda: 0.95, + observationDimension: 4, + actionSpace: { + class: "Discrete", + len: 2, + }, + }, + env, + tf.sequential({ + layers: [ + tf.layers.dense({ + inputDim: 4, + units: 32, + activation: "relu", + }), + tf.layers.dense({ + units: 32, + activation: "relu", + }), + ], + }), + tf.sequential({ + layers: [ + tf.layers.dense({ + inputDim: 4, + units: 
32, + activation: "relu", + }), + tf.layers.dense({ + units: 32, + activation: "relu", + }), + ], + }), + ) + + function possibleLifetime() { + const acc = [] + + for (let j = 0; j < 25; ++j) { + env.reset() + + let t = 0 + + while (!env.isDone() && t < 1000) { + env.step(ppo.act(env.getStateTensor()) as number[]) + t++ + } + + acc.push(t) + } + + // average of top 10% lifetimes + acc.sort((a, b) => b - a) + + const best10avg = acc.slice(0, 10).reduce((a, b) => a + b, 0) / 10 + const worst10avg = acc.slice(-10).reduce((a, b) => a + b, 0) / 10 + const avg = acc.reduce((a, b) => a + b, 0) / acc.length + + return { avg, best10avg, worst10avg } + } + + let currentAverage = 0 + let i = 0 + + function iteration() { + ppo.learn(512 * i) + + const { avg, best10avg, worst10avg } = possibleLifetime() + + console.log(`Leaks: ${tf.memory().numTensors}`) + console.log(`10%: ${best10avg}, 90%: ${worst10avg}, avg: ${avg}`) + + if (avg > currentAverage) { + // await ppo.save() + currentAverage = avg + console.log("Saved") + } + + i++ + + requestAnimationFrame(iteration) + } + + console.log("Initial: ", possibleLifetime()) + + console.log("Start") + requestAnimationFrame(iteration) + + */ + }) +} + +if (false) { + tf.setBackend("cpu").then(() => { + const env = new PolyburnEnvironment() + + const inputDim = 5 + + const ppo = new PPO( + { + steps: 512, + epochs: 15, + policyLearningRate: 1e-3, + valueLearningRate: 1e-3, + clipRatio: 0.2, + targetKL: 0.01, + gamma: 0.99, + lambda: 0.95, + observationDimension: inputDim, + actionSpace: { + class: "Box", + len: 2, + low: -1, + high: 1, + }, + }, + env, + tf.sequential({ + layers: [ + tf.layers.dense({ + inputDim: inputDim, + units: 64, + activation: "relu", + }), + tf.layers.dense({ + units: 64, + activation: "relu", + }), + ], + }), + tf.sequential({ + layers: [ + tf.layers.dense({ + inputDim: inputDim, + units: 64, + activation: "relu", + }), + tf.layers.dense({ + units: 64, + activation: "relu", + }), + ], + }), + ) + + function possibleLifetime() { + let observation = env.reset() + + let totalReward = 0 + const inputs = [] + + while (true) { + const action = ppo.act(observation) + inputs.push(env.inputFromAction(action as number[])) + + const [nextObservation, reward, done] = env.step(action) + + totalReward += reward + observation = nextObservation + + if (done) { + break + } + } + + return { + totalReward, + touched: env.state().touched, + distance: env.state().distance, + inputs, + } + } + + let currentAverage = 0 + let i = 0 + + const previousTwenty: number[] = [] + + function iteration() { + ppo.learn(512 * i) + const info = possibleLifetime() + + console.log( + `Reward ${i}: reward(${info.totalReward}), distance(${info.distance}), touched(${info.touched})`, + ) + + if (info.totalReward > currentAverage && previousTwenty.length === 20) { + currentAverage = info.totalReward + console.log("Saved") + ppo.save() + } + + if (previousTwenty.length === 20) { + previousTwenty.shift() + } + + previousTwenty.push(info.totalReward) + + const avgPreviousTwenty = + previousTwenty.reduce((a, b) => a + b, 0) / previousTwenty.length + + ++i + + if ( + avgPreviousTwenty < 50 && + avgPreviousTwenty < Math.max(currentAverage, 10) * 0.5 && + previousTwenty.length === 20 + ) { + console.log("Restoring") + + ppo.restore().finally(() => { + requestAnimationFrame(iteration) + }) + } else { + requestAnimationFrame(iteration) + } + } + + ppo.restore().finally(() => { + const { totalReward, inputs } = possibleLifetime() + currentAverage = totalReward + + 
console.log(JSON.stringify(inputs)) + + console.log("Start with: ", currentAverage) + requestAnimationFrame(iteration) + }) + }) +} diff --git a/packages/learning/src/main.ts b/packages/learning/src/main.ts index 6682083e..71040b7f 100644 --- a/packages/learning/src/main.ts +++ b/packages/learning/src/main.ts @@ -241,6 +241,7 @@ export class CartPole implements Environment { const cosTheta = Math.cos(this.theta) const sinTheta = Math.sin(this.theta) + ++this.i const temp = (force + this.poleMoment * this.thetaDot * this.thetaDot * sinTheta) / this.totalMass @@ -255,7 +256,14 @@ export class CartPole implements Environment { this.theta += this.tau * this.thetaDot this.thetaDot += this.tau * thetaAcc - const reward = this.isDone() ? -100 : 1 + let reward = 0 + + if (this.isDone()) { + reward = -100 + } else { + reward = 1 + } + return [this.getStateTensor(), reward, this.isDone()] } @@ -299,8 +307,10 @@ import "@tensorflow/tfjs-backend-webgl" import "@tensorflow/tfjs-backend-webgpu" import { SoftActorCritic } from "./soft-actor-critic/soft-actor-critic" -if (true) { +if (false) { tf.setBackend("cpu").then(() => { + const env = new CartPole() + const sac = new SoftActorCritic({ mlpSpec: { sizes: [32, 32], @@ -312,18 +322,83 @@ if (true) { observationSize: 4, maxEpisodeLength: 1000, - bufferSize: 10000, + bufferSize: 1e6, batchSize: 100, - updateAfter: 10000, + updateAfter: 1000, updateEvery: 50, - learningRate: 0.01, + learningRate: 1e-3, alpha: 0.2, gamma: 0.99, polyak: 0.995, }) - sac.test() + sac.learn(new CartPole()) + + function iteration() { + requestAnimationFrame(iteration) + } + + requestAnimationFrame(iteration) + + return + fetch("http://localhost:5173/batches.json") + .then(r => + r + .json() + .then(j => { + const batches = JSON.parse(j) + let i = 0 + + function currentReward() { + const acc = [] + + for (let j = 0; j < 100; ++j) { + env.reset() + + let x = 0 + + while (!env.isDone() && x < 1000) { + env.step(sac.act(env.getStateTensor(), true)) + x++ + } + + acc.push(x) + } + + // average of top 10% lifetimes + acc.sort((a, b) => b - a) + + const best10avg = acc.slice(0, 10).reduce((a, b) => a + b, 0) / 10 + const worst10avg = acc.slice(-10).reduce((a, b) => a + b, 0) / 10 + const avg = acc.reduce((a, b) => a + b, 0) / acc.length + + return { avg, best10avg, worst10avg } + } + + for (const batch of batches) { + sac.update({ + observation: tf.tensor2d(batch.observation), + action: tf.tensor2d(batch.action), + reward: tf.tensor1d(batch.reward), + nextObservation: tf.tensor2d(batch.nextObservation), + done: tf.tensor1d(batch.done), + }) + + console.log(`Batch ${i++} done`) + } + + console.log("Reward: ", currentReward()) + + console.log("Done") + }) + .catch(e => { + console.error(e) + }), + ) + .catch(e => { + console.error(e) + }) /* const actor = new Actor(4, 2, { @@ -349,8 +424,8 @@ if (true) { }) } -if (false) { - tf.setBackend("webgpu").then(() => { +if (true) { + tf.setBackend("cpu").then(() => { const env = new CartPole() const sac = new SoftActorCritic({ @@ -364,10 +439,10 @@ if (false) { observationSize: 4, maxEpisodeLength: 1000, - bufferSize: 100000, - batchSize: 8096, - updateAfter: 8096, - updateEvery: 25, + bufferSize: 1e6, + batchSize: 100, + updateAfter: 1000, + updateEvery: 50, learningRate: 1e-3, alpha: 0.2, @@ -378,17 +453,17 @@ if (false) { function currentReward() { const acc = [] - for (let j = 0; j < 25; ++j) { + for (let j = 0; j < 10; ++j) { env.reset() - let t = 0 + let x = 0 - while (!env.isDone() && t < 1000) { - env.step(sac.act(env.getStateTensor(), 
true)) - t++ + while (!env.isDone() && x < 1000) { + env.step(sac.act(env.getStateTensor(), false)) + x++ } - acc.push(t) + acc.push(x) } // average of top 10% lifetimes @@ -402,25 +477,25 @@ if (false) { } let t = 0 - let k = 8 + let updated = false function iteration() { - for (let i = 0; i < 128; ++i) { + for (let i = 0; i < 16; ++i) { t++ const observation = env.getStateTensor() let action: number[] - if (t < 300) { - action = [Math.random()] + if (t < 10_000) { + action = [Math.random() * 2 - 1] } else { action = sac.act(observation, false) } const [nextObservation, reward, done] = env.step(action) - sac.observe({ + const thisTimeUpdated = sac.observe({ observation, action, reward, @@ -428,21 +503,20 @@ if (false) { done, }) - if (done) { - env.reset() - } - } + updated ||= thisTimeUpdated - --k - console.log(k) + if (done) { + if (updated) { + const { avg, best10avg, worst10avg } = currentReward() - if (k < 0) { - k = 10 + console.log(`Leaks: ${tf.memory().numTensors}`) + console.log(`10%: ${best10avg}, 90%: ${worst10avg}, avg: ${avg}`) + } - const { avg, best10avg, worst10avg } = currentReward() + env.reset() - console.log(`Leaks: ${tf.memory().numTensors}`) - console.log(`10%: ${best10avg}, 90%: ${worst10avg}, avg: ${avg}`) + updated = false + } } requestAnimationFrame(iteration) diff --git a/packages/learning/src/soft-actor-critic/replay-buffer.ts b/packages/learning/src/soft-actor-critic/replay-buffer.ts index 5067c235..094389e6 100644 --- a/packages/learning/src/soft-actor-critic/replay-buffer.ts +++ b/packages/learning/src/soft-actor-critic/replay-buffer.ts @@ -17,43 +17,70 @@ export interface ExperienceTensor { } export class ReplayBuffer { - private buffer: Experience[] = [] - private bufferIndex = 0 + private buffer: Experience[] + private ptr: number + private size: number + + private indices: number[] constructor( private capacity: number, private batchSize: number, - private observationSize: number, - private actionSize: number, ) { if (batchSize > capacity) { throw new Error("Batch size must be less than or equal to capacity") } + + this.buffer = [] + this.ptr = 0 + this.size = 0 + + this.indices = [] } - push(experience: Experience) { - if (this.buffer.length >= this.capacity) { - this.buffer[this.bufferIndex] = experience - this.bufferIndex = (this.bufferIndex + 1) % this.capacity - } else { + store(experience: Experience) { + if (this.size < this.capacity) { this.buffer.push({ ...experience }) + this.indices.push(this.size) + } else { + this.buffer[this.ptr] = experience } + + this.ptr = (this.ptr + 1) % this.capacity + this.size = Math.min(this.size + 1, this.capacity) } sample(): ExperienceTensor { - if (this.buffer.length < this.batchSize) { + if (this.size < this.batchSize) { throw new Error("Buffer does not have enough experiences") } - const indices = [ - ...tf.util.createShuffledIndices(Math.min(this.buffer.length, this.batchSize)), - ] + tf.util.shuffle(this.indices) + const indices = this.indices.slice(0, this.batchSize) - const observation = tf.tensor2d(indices.map(x => this.buffer[x].observation)) - const action = tf.tensor2d(indices.map(x => this.buffer[x].action)) - const reward = tf.tensor1d(indices.map(x => this.buffer[x].reward)) - const nextObservation = tf.tensor2d(indices.map(x => this.buffer[x].nextObservation)) - const done = tf.tensor1d(indices.map(x => (this.buffer[x].done ? 
1 : 0)))
+        const observation = tf.tensor2d(
+            indices.map(x => this.buffer[x].observation),
+            undefined,
+            "float32",
+        )
+        const action = tf.tensor2d(
+            indices.map(x => this.buffer[x].action),
+            undefined,
+            "float32",
+        )
+        const reward = tf.tensor1d(
+            indices.map(x => this.buffer[x].reward),
+            "float32",
+        )
+        const nextObservation = tf.tensor2d(
+            indices.map(x => this.buffer[x].nextObservation),
+            undefined,
+            "float32",
+        )
+        const done = tf.tensor1d(
+            indices.map(x => (this.buffer[x].done ? 1 : 0)),
+            "float32",
+        )

         return { observation, action, reward, nextObservation, done }
     }
diff --git a/packages/learning/src/soft-actor-critic/soft-actor-critic.ts b/packages/learning/src/soft-actor-critic/soft-actor-critic.ts
index e414c350..61ef71b7 100644
--- a/packages/learning/src/soft-actor-critic/soft-actor-critic.ts
+++ b/packages/learning/src/soft-actor-critic/soft-actor-critic.ts
@@ -1,4 +1,5 @@
 import * as tf from "@tensorflow/tfjs"
+import { Environment } from "../ppo/ppo"
 import { actor } from "./actor"
 import { critic } from "./critic"
 import { MlpSpecification } from "./mlp"
@@ -22,6 +23,72 @@ export interface Config {
     polyak: number
 }

+/*
+class ReplayBuffer:
+    """
+    A simple FIFO experience replay buffer for SAC agents.
+    """
+
+    def __init__(self, obs_dim, act_dim, size):
+        self.obs_buf = np.zeros(core.combined_shape(size, obs_dim), dtype=np.float32)
+        self.obs2_buf = np.zeros(core.combined_shape(size, obs_dim), dtype=np.float32)
+        self.act_buf = np.zeros(core.combined_shape(size, act_dim), dtype=np.float32)
+        self.rew_buf = np.zeros(size, dtype=np.float32)
+        self.done_buf = np.zeros(size, dtype=np.float32)
+        self.ptr, self.size, self.max_size = 0, 0, size
+
+    def store(self, obs, act, rew, next_obs, done):
+        self.obs_buf[self.ptr] = obs
+        self.obs2_buf[self.ptr] = next_obs
+        self.act_buf[self.ptr] = act
+        self.rew_buf[self.ptr] = rew
+        self.done_buf[self.ptr] = done
+        self.ptr = (self.ptr+1) % self.max_size
+        self.size = min(self.size+1, self.max_size)
+
+    def sample_batch(self, batch_size=32):
+        idxs = np.random.randint(0, self.size, size=batch_size)
+        batch = dict(obs=self.obs_buf[idxs],
+                     obs2=self.obs2_buf[idxs],
+                     act=self.act_buf[idxs],
+                     rew=self.rew_buf[idxs],
+                     done=self.done_buf[idxs])
+
+        batches.append(
+            dict(
+                observation=self.obs_buf[idxs],
+                nextObservation=self.obs2_buf[idxs],
+                action=self.act_buf[idxs],
+                reward=self.rew_buf[idxs],
+                done=self.done_buf[idxs]
+            )
+        )
+
+        return {k: torch.as_tensor(v, dtype=torch.float32) for k,v in batch.items()}
+*/
+
+export interface LearningProgress {
+    episodeReturn: number
+    episodeLength: number
+
+    maxEpisodeReturn: number
+    maxEpisodeLength: number
+
+    lossQ: number
+    lossPolicy: number
+}
+
+export interface SacLearningConfig {
+    startAfter: number
+    renderEvery: number
+
+    onStart?: () => void
+    onStep?: () => void
+    onUpdate?: () => void
+    onEnd?: () => void
+    onEpisodeEnd?: () => void
+}
+
 export class SoftActorCritic {
     private replayBuffer: ReplayBuffer

@@ -40,12 +107,7 @@ export class SoftActorCritic {
     private t: number

     constructor(private config: Config) {
-        this.replayBuffer = new ReplayBuffer(
-            config.bufferSize,
-            config.batchSize,
-            config.observationSize,
-            config.actionSize,
-        )
+        this.replayBuffer = new ReplayBuffer(config.bufferSize, config.batchSize)

         this.policy = actor(config.observationSize, config.actionSize, config.mlpSpec)
         this.q1 = critic(config.observationSize, config.actionSize, config.mlpSpec)
@@ -58,23 +120,63 @@
         this.episodeReturn = 0
         this.episodeLength = 0
         this.t = 0

-        this.observationBuffer = tf.buffer([1, config.observationSize], "float32")
-
         this.policyOptimizer = tf.train.adam(config.learningRate)
         this.qOptimizer = tf.train.adam(config.learningRate)
+    }

-        for (let i = 0; i < this.targetQ1.trainableWeights.length; i++) {
-            const targetWeight = this.targetQ1.trainableWeights[i]
-            const weight = this.q1.trainableWeights[i]
+    async learn(env: Environment, sacConfig: SacLearningConfig) {
+        const buffer = new ReplayBuffer(this.config.bufferSize, this.config.batchSize)

-            targetWeight.write(weight.read())
-        }
+        let observation = env.reset()
+        let episodeLength = 0
+        let episodeReturn = 0

-        for (let i = 0; i < this.targetQ2.trainableWeights.length; i++) {
-            const targetWeight = this.targetQ2.trainableWeights[i]
-            const weight = this.q2.trainableWeights[i]
+        let maxEpisodeReturn = 0
+        let maxEpisodeLength = 0

-            targetWeight.write(weight.read())
+        for (let t = 0; ; ++t) {
+            let action: number[]
+
+            if (t > sacConfig.startAfter) {
+                action = this.act(observation, false)
+            } else {
+                action = [Math.random() * 2 - 1]
+            }
+
+            const [nextObservation, reward, done] = env.step(action)
+            episodeLength += 1
+
+            buffer.store({
+                observation,
+                action,
+                reward,
+                nextObservation,
+                done: episodeLength === this.config.maxEpisodeLength ? false : done,
+            })
+
+            observation = nextObservation
+
+            if (done || episodeLength === this.config.maxEpisodeLength) {
+                observation = env.reset()
+                episodeLength = 0
+            }
+
+            if (t >= this.config.updateAfter && t % this.config.updateEvery === 0) {
+                for (let i = 0; i < this.config.updateEvery; i++) {
+                    const batch = buffer.sample()
+                    this.update(batch, false)
+
+                    tf.dispose(batch.observation)
+                    tf.dispose(batch.action)
+                    tf.dispose(batch.reward)
+                    tf.dispose(batch.nextObservation)
+                    tf.dispose(batch.done)
+                }
+            }
+
+            if (sacConfig.renderEvery > 0 && t % sacConfig.renderEvery === 0) {
+                await tf.nextFrame()
+            }
         }
     }
@@ -95,9 +197,8 @@

         const done = this.episodeLength < this.config.maxEpisodeLength && experience.done

-        this.replayBuffer.push({
+        this.replayBuffer.store({
             ...experience,
-            done,
         })

         if (done || this.episodeLength === this.config.maxEpisodeLength) {
@@ -106,13 +207,16 @@
         }

         if (this.t > this.config.updateAfter && this.t % this.config.updateEvery === 0) {
-            console.log("update")
             for (let i = 0; i < this.config.updateEvery; i++) {
                 tf.tidy(() => {
-                    this.update(this.replayBuffer.sample())
+                    this.update(this.replayBuffer.sample(), i === this.config.updateEvery - 1)
                 })
             }
+
+            return true
         }
+
+        return false
     }

     predictQ1(observation: tf.Tensor2D, action: tf.Tensor2D) {
@@ -129,189 +233,56 @@
         ) as tf.Tensor
     }

-    test() {
-        this.deterministic = true
-
-        this.q1.weights.forEach(w => w.write(tf.ones(w.shape).mul(0.2)))
-        this.q2.weights.forEach(w => w.write(tf.ones(w.shape).mul(0.2)))
-        this.policy.weights.forEach(w => w.write(tf.ones(w.shape).mul(0.2)))
-        this.targetQ1.weights.forEach(w => w.write(tf.ones(w.shape).mul(0.2)))
-        this.targetQ2.weights.forEach(w => w.write(tf.ones(w.shape).mul(0.2)))
-
-        let seedi = 9
-
-        function randomNumber(seed: number, min: number, max: number) {
-            const a = 1103515245
-            const c = 721847
-
-            seed = (a * seed + c) % 2 ** 31
-            return min + (seed % (max - min))
-        }
-
-        /*
-        seedi = 9
-
-        for i in range(10):
-            seedi += 4
-            observation = torch.tensor([[
-                randomNumber(seedi, -10, 10),
-                randomNumber(seedi + 1, -10, 10),
-                randomNumber(seedi + 2, -10, 10),
-                randomNumber(seedi + 3, -10, 10)]
-            ])
-
-            print("R", i, ": ", ac.pi(observation, True))
-        */
-
-
/* - seedi = 9 - - for (let i = 0; i < 10; i++) { - seedi += 4 - const observation = tf.tensor2d([ - [ - randomNumber(seedi, -10, 10), - randomNumber(seedi + 1, -10, 10), - randomNumber(seedi + 2, -10, 10), - randomNumber(seedi + 3, -10, 10), - ], - ]) - - console.log(observation.dataSync()) - const [a, b] = this.policy.apply(observation, { - deterministic: true, - }) - console.log("R", i, ": ", a.dataSync(), b.dataSync()) - } - */ - function randomData() { - /* - def randomNumber(seed, min, max): - a = 1103515245 - c = 721847 - - seed = (a * seed + c) % 2**31 - return (float) (min + (seed % (max - min))) - - seedi = 9 - - def randomData(): - global seedi - - seedi += 5 - data = { - 'obs': torch.tensor([[randomNumber(seedi, -10, 10), randomNumber(2, -10, 10), randomNumber(3, -10, 10), randomNumber(4, -10, 10)]]), - 'act': torch.tensor([[randomNumber(seedi + 1, -1, 1)]]), - 'rew': torch.tensor([[randomNumber(seedi + 2, -100, 100)]]), - 'obs2': torch.tensor([[randomNumber(seedi + 3, -10, 10), randomNumber(8, -10, 10), randomNumber(9, -10, 10), randomNumber(10, -10, 10)]]), - 'done': torch.tensor([[randomNumber(seedi + 4, 0, 1)]]) - } - - return data - */ - - seedi += 5 - - return { - observation: tf.tensor2d([ - [ - randomNumber(seedi, -10, 10), - randomNumber(2, -10, 10), - randomNumber(3, -10, 10), - randomNumber(4, -10, 10), - ], - ]), - action: tf.tensor2d([[randomNumber(seedi + 1, -1, 1)]]), - reward: tf.tensor1d([randomNumber(seedi + 2, -100, 100)]), - nextObservation: tf.tensor2d([ - [ - randomNumber(seedi + 3, -10, 10), - randomNumber(8, -10, 10), - randomNumber(9, -10, 10), - randomNumber(10, -10, 10), - ], - ]), - done: tf.tensor1d([randomNumber(seedi + 4, 0, 1)]), - } - } - - let data = randomData() - for (let i = 0; i < 1000; i++) { - console.log("Action: ", this.act(data.observation.arraySync()[0], true)[0]) - this.update(data) - data = randomData() - } - - this.q1.trainableWeights.forEach(w => { - console.log(w.read().dataSync()) - }) - - console.log("Verify: ", randomNumber(seedi, 0, 1000)) - } - private deterministic = false - private update(batch: ExperienceTensor) { - tf.tidy(() => { - const lossQ = () => { - const q1 = this.predictQ1(batch.observation, batch.action) - const q2 = this.predictQ2(batch.observation, batch.action) - - const backup = tf.tensor1d(this.computeBackup(batch).arraySync()) + update(batch: ExperienceTensor, last: boolean = false) { + const lossQ = () => { + const q1 = this.predictQ1(batch.observation, batch.action) + const q2 = this.predictQ2(batch.observation, batch.action) - const errorQ1 = tf.mean(tf.square(tf.sub(q1, backup))) as tf.Scalar - const errorQ2 = tf.mean(tf.square(tf.sub(q2, backup))) as tf.Scalar - - return tf.add(errorQ1, errorQ2) as tf.Scalar - } + const backup = tf.tensor1d(this.computeBackup(batch).arraySync()) - console.log("lossQ: ", lossQ().arraySync()) + const errorQ1 = tf.mean(tf.square(tf.sub(q1, backup))) as tf.Scalar + const errorQ2 = tf.mean(tf.square(tf.sub(q2, backup))) as tf.Scalar - const gradsQ = tf.variableGrads(lossQ) - this.qOptimizer.applyGradients(gradsQ.grads) + return tf.add(errorQ1, errorQ2) as tf.Scalar + } - const lossPolicy = () => { - const [pi, logpPi] = this.policy.apply(batch.observation, { - deterministic: this.deterministic, - }) as tf.Tensor[] + const gradsQ = tf.variableGrads(lossQ, this.q1.getWeights().concat(this.q2.getWeights())) + this.qOptimizer.applyGradients(gradsQ.grads) + tf.dispose(gradsQ) - const piQ1 = this.predictQ1(batch.observation, pi) - const piQ2 = this.predictQ2(batch.observation, 
pi) + const lossPolicy = () => { + const [pi, logpPi] = this.policy.apply(batch.observation, { + deterministic: this.deterministic, + }) as tf.Tensor[] - const minPiQ = tf.minimum(piQ1, piQ2) + const piQ1 = this.predictQ1(batch.observation, pi) + const piQ2 = this.predictQ2(batch.observation, pi) - return tf.mean(logpPi.mul(this.config.alpha).sub(minPiQ)) as tf.Scalar - } + const minPiQ = tf.minimum(piQ1, piQ2) - console.log("lossPolicy: ", lossPolicy().arraySync()) + return tf.mean(logpPi.mul(this.config.alpha).sub(minPiQ)) as tf.Scalar + } - const gradsPolicy = tf.variableGrads(lossPolicy, this.policy.getWeights()) - this.policyOptimizer.applyGradients(gradsPolicy.grads) + const gradsPolicy = tf.variableGrads(lossPolicy, this.policy.getWeights()) + this.policyOptimizer.applyGradients(gradsPolicy.grads) + tf.dispose(gradsPolicy) - for (let i = 0; i < this.targetQ1.trainableWeights.length; i++) { - const targetWeight = this.targetQ1.trainableWeights[i] - const weight = this.q1.trainableWeights[i] + // do polyak averaging + const tau = tf.scalar(this.config.polyak) + const oneMinusTau = tf.scalar(1).sub(tau) - targetWeight.write( - tf.add( - tf.mul(this.config.polyak, targetWeight.read()), - tf.mul(1 - this.config.polyak, weight.read()), - ), - ) - } + const updateTargetQ1 = this.targetQ1 + .getWeights() + .map((w, i) => w.mul(tau).add(this.q1.getWeights()[i].mul(oneMinusTau))) - for (let i = 0; i < this.targetQ2.trainableWeights.length; i++) { - const targetWeight = this.targetQ2.trainableWeights[i] - const weight = this.q2.trainableWeights[i] + const updateTargetQ2 = this.targetQ2 + .getWeights() + .map((w, i) => w.mul(tau).add(this.q2.getWeights()[i].mul(oneMinusTau))) - targetWeight.write( - tf.add( - tf.mul(this.config.polyak, targetWeight.read()), - tf.mul(1 - this.config.polyak, weight.read()), - ), - ) - } - }) + this.targetQ1.setWeights(updateTargetQ1) + this.targetQ2.setWeights(updateTargetQ2) } private computeBackup(batch: ExperienceTensor) {