Skip to content

Commit

Permalink
Continue prototyping learning
Browse files Browse the repository at this point in the history
  • Loading branch information
phisn committed Apr 27, 2024
1 parent e8d5127 commit 782fee8
Show file tree
Hide file tree
Showing 11 changed files with 676 additions and 148 deletions.
11 changes: 8 additions & 3 deletions packages/learning/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
"name": "learning",
"private": true,
"version": "1.0.0",
"type": "commonjs",
"moduleResolution": "node",
"scripts": {
"rl-dev": "vite",
"rl-local": "npx ts-node ./src/main.ts",
"rl-local": "npx tsx ./src/main.ts",
"rl-build": "tsc && vite build",
"rl-preview": "vite preview"
},
Expand All @@ -18,17 +19,21 @@
"runtime-framework": "*",
"shared": "*",
"tailwindcss": "^3.4.3",
"ts-node": "^10.9.2",
"tslib": "^2.6.2",
"tsx": "^4.7.3",
"typescript": "^5.4.2",
"vite": "^5.1.6"
},
"dependencies": {
"@tensorflow/tfjs": "^4.18.0",
"@tensorflow/tfjs-node": "^4.18.0",
"@tensorflow/tfjs-node-gpu": "^4.18.0",
"@types/prompts": "^2.4.9",
"@types/sat": "^0.0.35",
"eslint-config-custom": "*",
"lil-gui": "^0.19.2",
"poly-decomp-es": "^0.4.2",
"ppo-tfjs": "^0.0.2",
"prompts": "^2.4.2",
"sat": "^0.9.0",
"three": "^0.162.0",
"tsconfig": "*",
Expand Down
208 changes: 174 additions & 34 deletions packages/learning/src/main.ts
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
// 1 -> 0
// 0 -> -1
// -1 -> 1

import { SoftActorCritic } from "./soft-actor-critic/soft-actor-critic"

function getReward(got: number, expected: number) {
const gotRounded = Math.round(got)
function f() {
const gotRounded = Math.round(got)

if (gotRounded === expected) {
return 0
}
if (gotRounded === expected) {
return 0
}

if (gotRounded === 0) {
return expected === -1 ? 1 : -1
}
if (gotRounded === 0) {
return expected === -1 ? 1 : -1
}

if (gotRounded === 1) {
return expected === 0 ? 1 : -1
}

if (gotRounded === 1) {
return expected === 0 ? 1 : -1
return expected === 1 ? 1 : -1
}

return expected === 1 ? 1 : -1
return (f() + 1) / 2
}

const observationSize = 8
Expand All @@ -36,26 +34,168 @@ const observations = [
[[1, -1, 1, -1, 1, -1, 1, -1], [1]],
]

const sac = new SoftActorCritic({
mlpSpec: {
sizes: [64, 64],
activation: "relu",
outputActivation: "relu",
},
actionSize,
observationSize,
maxEpisodeLength: 1000,
bufferSize: 10000,
batchSize: 64,
updateAfter: 1000,
updateEvery: 50,
learningRate: 0.001,
alpha: 0.2,
gamma: 0.99,
// NOTE(review): loaded with require() rather than import — presumably because
// ppo-tfjs ships no TypeScript declarations; confirm before converting to ESM.
const PPO = require("ppo-tfjs")

export class CartPole {
    // Gym-style space descriptors consumed by the ppo-tfjs agent.
    actionSpace = {
        class: "Box",
        shape: [1],
        low: -1,
        high: 1,
    }

    observationSpace = {
        class: "Box",
        shape: [4],
        low: [-4.8, -Infinity, -0.418, -Infinity],
        high: [4.8, Infinity, 0.418, Infinity],
    }

    // Physical constants characterizing the system.
    private gravity = 9.8
    private massCart = 1.0
    private massPole = 0.1
    private totalMass = this.massCart + this.massPole
    private cartWidth = 0.2
    private cartHeight = 0.1
    private length = 0.5
    private poleMoment = this.massPole * this.length
    private forceMag = 10.0
    private tau = 0.02 // Seconds between state updates.

    // Bounds beyond which an episode counts as failed.
    private xThreshold = 2.4
    private thetaThreshold = (12 / 360) * 2 * Math.PI

    // Control-theory state: cart position/velocity, pole angle/angular velocity.
    private x = 0
    private xDot = 0
    private theta = 0
    private thetaDot = 0

    /**
     * Creates a cart-pole simulation in a randomized start state.
     */
    constructor() {
        this.reset()
    }

    /**
     * Returns the current state as a plain number array
     * [x, xDot, theta, thetaDot] — note: not an actual tf.Tensor.
     */
    getStateTensor() {
        return [this.x, this.xDot, this.theta, this.thetaDot]
    }

    /**
     * Advances the simulation by one time step.
     *
     * Only the sign of `action` matters: a value > 0 applies a rightward
     * force of fixed magnitude, anything else a leftward force.
     *
     * @param {number} action Control input; only its sign is used.
     * @returns A [state, reward, done] triple, with reward -100 on
     * failure and 1 otherwise.
     */
    step(action: number) {
        const appliedForce = action > 0 ? this.forceMag : -this.forceMag

        const cos = Math.cos(this.theta)
        const sin = Math.sin(this.theta)

        const common =
            (appliedForce + this.poleMoment * this.thetaDot * this.thetaDot * sin) / this.totalMass
        const angularAcc =
            (this.gravity * sin - cos * common) /
            (this.length * (4 / 3 - (this.massPole * cos * cos) / this.totalMass))
        const linearAcc = common - (this.poleMoment * angularAcc * cos) / this.totalMass

        // Euler integration: positions advance using the old velocities,
        // then velocities pick up the freshly computed accelerations.
        this.x += this.tau * this.xDot
        this.xDot += this.tau * linearAcc
        this.theta += this.tau * this.thetaDot
        this.thetaDot += this.tau * angularAcc

        const reward = this.isDone() ? -100 : 1
        return [this.getStateTensor(), reward, this.isDone()]
    }

    /**
     * Re-randomizes the state close to the upright equilibrium.
     *
     * @returns The fresh state array.
     */
    reset() {
        // Cart position (meters), then velocity, pole angle (radians),
        // then angular velocity — each drawn from a small symmetric range.
        this.x = Math.random() - 0.5
        this.xDot = (Math.random() - 0.5) * 1
        this.theta = (Math.random() - 0.5) * 2 * ((6 / 360) * 2 * Math.PI)
        this.thetaDot = (Math.random() - 0.5) * 0.5

        return this.getStateTensor()
    }

    /**
     * An episode is over once the cart position or the pole angle leaves
     * its allowed band.
     *
     * @returns {bool} Whether the simulation is done.
     */
    isDone() {
        return Math.abs(this.x) > this.xThreshold || Math.abs(this.theta) > this.thetaThreshold
    }
}

// NOTE(review): tfjs-node is pulled in via require() here while other files
// use ESM imports — presumably to match ppo-tfjs's CommonJS usage; confirm.
const tf = require("@tensorflow/tfjs-node")
const env = new CartPole()

// PPO agent over the cart-pole environment: 1024-step rollouts, 50 epochs
// per update, a single hidden layer of 32 units, verbose logging on.
const ppo = new PPO(env, {
    nSteps: 1024,
    nEpochs: 50,
    verbose: 1,
    netArch: [32],
})

const x = sac.act([0, 0, 0, 0, 0, 0, 0, 0])
x.print()
/**
 * Measures how many steps the current policy keeps the pole balanced,
 * capped at 1000. Resets the environment first, so it is safe to call
 * before and after training for comparison.
 */
function possibleLifetime() {
    env.reset()

    let steps = 0

    for (; steps < 1000 && !env.isDone(); steps++) {
        const observation = tf.tensor([env.getStateTensor()])
        // Second argument presumably selects deterministic actions —
        // confirm against the ppo-tfjs predict() API.
        const action = ppo.predict(observation, true).arraySync()[0][0]
        env.step(action)
    }

    return steps
}

console.log("Lifetime before training:", possibleLifetime())

// Train for a fixed budget of environment steps, then report how the
// policy's survival time changed.
;(async () => {
    await ppo.learn({
        totalTimesteps: 20000,
    })

    console.log("Lifetime after training:", possibleLifetime())
})().catch(error => {
    // Without this handler a failing ppo.learn() would surface only as an
    // unhandled promise rejection; report it explicitly instead.
    console.error("Training failed:", error)
})

/*
import { WorldModel } from "runtime/proto/world"
Expand Down
4 changes: 3 additions & 1 deletion packages/learning/src/soft-actor-critic/actor.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import * as tf from "@tensorflow/tfjs"
import * as tf from "@tensorflow/tfjs-node-gpu"
import { GaussianLikelihood } from "./gaussian-likelihood"
import { MlpSpecification, mlp } from "./mlp"

Expand All @@ -20,6 +20,8 @@ export class Actor extends tf.layers.Layer {
sizes: [observationSize, ...mlpSpec.sizes],
})

this.net.predict

this.meanLayer = tf.layers.dense({
units: actionSize,
})
Expand Down
4 changes: 2 additions & 2 deletions packages/learning/src/soft-actor-critic/critic.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import * as tf from "@tensorflow/tfjs"
import * as tf from "@tensorflow/tfjs-node-gpu"
import { MlpSpecification, mlp } from "./mlp"

export class Critic extends tf.layers.Layer {
Expand All @@ -9,7 +9,7 @@ export class Critic extends tf.layers.Layer {

this.q = mlp({
...mlpSpec,
sizes: [observationSize + actionSize, ...mlpSpec.sizes],
sizes: [observationSize + actionSize, ...mlpSpec.sizes, 1],
outputActivation: undefined,
})
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import * as tf from "@tensorflow/tfjs"
import * as tf from "@tensorflow/tfjs-node-gpu"

export class GaussianLikelihood extends tf.layers.Layer {
computeOutputShape(inputShape: tf.Shape[]): tf.Shape | tf.Shape[] {
Expand Down
2 changes: 1 addition & 1 deletion packages/learning/src/soft-actor-critic/mlp.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import * as tf from "@tensorflow/tfjs"
import { ActivationIdentifier } from "@tensorflow/tfjs-layers/dist/keras_format/activation_config"
import * as tf from "@tensorflow/tfjs-node-gpu"

export interface MlpSpecification {
sizes: number[]
Expand Down
10 changes: 9 additions & 1 deletion packages/learning/src/soft-actor-critic/replay-buffer.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import * as tf from "@tensorflow/tfjs"
import * as tf from "@tensorflow/tfjs-node-gpu"

export interface Experience {
observation: number[]
Expand Down Expand Up @@ -32,6 +32,10 @@ export class ReplayBuffer {
private observationSize: number,
private actionSize: number,
) {
if (batchSize > capacity) {
throw new Error("Batch size must be less than or equal to capacity")
}

this.tensorObservation = tf.buffer([batchSize, observationSize], "float32")
this.tensorAction = tf.buffer([batchSize, actionSize], "float32")
this.tensorReward = tf.buffer([batchSize], "float32")
Expand All @@ -49,6 +53,10 @@ export class ReplayBuffer {
}

sample(): ExperienceTensor {
if (this.buffer.length < this.batchSize) {
throw new Error("Buffer does not have enough experiences")
}

const indices = tf.util.createShuffledIndices(this.buffer.length)

for (let i = 0; i < this.batchSize; i++) {
Expand Down
Loading

0 comments on commit 782fee8

Please sign in to comment.