Skip to content

Commit

Permalink
Experiment with tfjs
Browse files Browse the repository at this point in the history
  • Loading branch information
phisn committed Apr 29, 2024
1 parent 782fee8 commit 0e74a48
Show file tree
Hide file tree
Showing 4 changed files with 595 additions and 27 deletions.
2 changes: 2 additions & 0 deletions packages/learning/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
"vite": "^5.1.6"
},
"dependencies": {
"@tensorflow/tfjs-backend-webgl": "^4.18.0",
"@tensorflow/tfjs-backend-webgpu": "^4.18.0",
"@tensorflow/tfjs-node": "^4.18.0",
"@tensorflow/tfjs-node-gpu": "^4.18.0",
"@types/prompts": "^2.4.9",
Expand Down
58 changes: 31 additions & 27 deletions packages/learning/src/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ const observations = [
[[1, -1, 1, -1, 1, -1, 1, -1], [1]],
]

const PPO = require("ppo-tfjs")

export class CartPole {
actionSpace = {
class: "Box",
Expand Down Expand Up @@ -165,39 +163,44 @@ export class CartPole {
}

const tf = require("@tensorflow/tfjs-node")
const env = new CartPole()
require("@tensorflow/tfjs-backend-webgpu")

const ppo = new PPO(env, {
nSteps: 1024,
nEpochs: 50,
verbose: 1,
netArch: [32],
})
tf.setBackend("tensorflow").then(() => {
const env = new CartPole()

function possibleLifetime() {
env.reset()
const PPO = require("./ppo/base-ppo.js")

let t = 0
const ppo = new PPO(env, {
nSteps: 1024,
nEpochs: 50,
verbose: 1,
netArch: [16],
})

while (!env.isDone() && t < 1000) {
const action = ppo.predict(tf.tensor([env.getStateTensor()]), true).arraySync()[0][0]
env.step(action)
t++
}
function possibleLifetime() {
env.reset()

return t
}
let t = 0

console.log("Lifetime before training:", possibleLifetime())
;(async () => {
await ppo.learn({
totalTimesteps: 20000,
while (!env.isDone() && t < 1000) {
const action = ppo.predict(tf.tensor([env.getStateTensor()]), true).arraySync()[0][0]
env.step(action)
t++
}

return t
}

console.log("Lifetime before training:", possibleLifetime())
;(async () => {
await ppo.learn({
totalTimesteps: 5000,
})
})().then(() => {
console.log("Lifetime after training:", possibleLifetime())
})
})().then(() => {
console.log("Lifetime after training:", possibleLifetime())
})

/*
/*
import { WorldModel } from "runtime/proto/world"
import { Game } from "./game/game"
import { GameLoop } from "./game/game-loop"
Expand Down Expand Up @@ -225,3 +228,4 @@ try {
console.error(e)
}
*/
})
Loading

0 comments on commit 0e74a48

Please sign in to comment.