Continue refactoring ppo-tfjs
phisn committed Apr 30, 2024
1 parent cbbf808 commit 23991e7
Showing 2 changed files with 210 additions and 11 deletions.
2 changes: 1 addition & 1 deletion packages/learning/src/ppo/base-ppo.ts
@@ -240,7 +240,7 @@ class PPO {
this.lastObservation = null

// Initialize buffer
- this.buffer = new ReplayBuffer(config)
+ this.buffer = new Buffer(config)

// Initialize models for actor and critic
this.actor = this.createActor()
219 changes: 209 additions & 10 deletions packages/learning/src/ppo/ppo.ts
@@ -109,15 +109,18 @@ class ReplayBuffer {
interface DiscreteSpace {
class: "Discrete"
dtype?: "int32"
- n: number
+
+ len: number
}

interface BoxSpace {
class: "Box"
dtype?: "float32"
- shape: number[]
- low: number[]
- high: number[]
+ low: number | number[]
+ high: number | number[]
+
+ len: number
}

type Space = DiscreteSpace | BoxSpace
@@ -134,10 +137,36 @@ interface PPOConfig {
clipRatio: number
targetKL: number

- observationSpace: Space
+ observationDimension: number
actionSpace: Space
}

interface Environment {
reset(): number[]
step(action: number | number[]): [number[], number, boolean]
}
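
// Illustrative sketch, not part of this diff: a minimal Environment
// implementation satisfying the interface above. The dynamics are a toy
// 1-D "move toward the origin" task chosen purely for demonstration.
class ToyEnvironment implements Environment {
    private position = 0

    reset(): number[] {
        // Start at a random position in [-1, 1]
        this.position = Math.random() * 2 - 1
        return [this.position]
    }

    step(action: number | number[]): [number[], number, boolean] {
        const a = Array.isArray(action) ? action[0] : action

        this.position += a * 0.1

        // Reward increases as the agent approaches the origin
        const reward = -Math.abs(this.position)
        const done = Math.abs(this.position) < 0.05

        return [[this.position], reward, done]
    }
}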

const ppo = new PPO(
{} as PPOConfig,
{} as Space,
[
{
class: "Box",
len: 2,
low: [0, 0],
high: [1, 1],
},
{
class: "Discrete",
len: 2,
},
],
{} as tf.LayersModel,
{} as tf.LayersModel,
)

ppo.act([1, 2, 3])

class PPO {
private numTimeSteps: number
private lastObservation: number[]
@@ -154,6 +183,9 @@ class PPO {

constructor(
private config: PPOConfig,

private env: Environment,

private actorModel: tf.LayersModel,
private criticModel: tf.LayersModel,
) {
@@ -167,7 +199,7 @@
layers: [
actorModel,
tf.layers.dense({
- units: config.actionSpace.n,
+ units: config.actionSpace.len,
}),
],
})
@@ -176,7 +208,7 @@
layers: [
actorModel,
tf.layers.dense({
- units: config.actionSpace.shape[0],
+ units: config.actionSpace.len,
}),
],
})
@@ -195,14 +227,181 @@
})

if (config.actionSpace.class === "Box") {
- this.logStd = tf.variable(tf.zeros([config.actionSpace.shape[0]]), true, "logStd")
+ this.logStd = tf.variable(tf.zeros([config.actionSpace.len]), true, "logStd")
}

this.optimizerPolicy = tf.train.adam(config.policyLearningRate)
this.optimizerValue = tf.train.adam(config.valueLearningRate)
}
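
    // Illustrative note, not part of this diff: the dense head appended to the
    // actor outputs one unit per action dimension -- unnormalized logits for a
    // Discrete space, per-dimension means for a Box space (whose trainable log
    // standard deviations live in the separate logStd variable).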

- act(observation: number[]) {}
+ act(observation: number[]): GetPPOSpaceType<ActionSpaces, "actionType"> {}

private collectRollouts() {
this.buffer.reset()

let sumReturn = 0
let sumReward = 0
let numEpisodes = 0

for (let step = 0; step < this.config.steps; ++step) {
tf.tidy(() => {
const observation = tf.tensor2d([this.lastObservation])

const [predictions, action, actionSynced] = this.sampleAction(observation)
const value = this.critic.predict(observation) as tf.Tensor1D

// TODO verify types
const logProbability = this.logProb(predictions as any, action as any)

const [nextObservation, reward, done] = this.env.step(actionSynced)

sumReturn += reward
sumReward += reward
})
}
}
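
    // Illustrative sketch, not part of this diff: the advantage estimation a
    // rollout buffer typically performs when a trajectory ends. PPO commonly
    // uses Generalized Advantage Estimation (GAE); gamma and lambda are the
    // usual discount and smoothing coefficients, assumed to live in the
    // config, and lastValue is 0 for a terminal state or the critic's
    // bootstrap value if the rollout was cut off mid-episode.
    private gaeAdvantages(
        rewards: number[],
        values: number[],
        lastValue: number,
        gamma: number,
        lambda: number,
    ): number[] {
        const advantages = new Array<number>(rewards.length).fill(0)
        let gae = 0

        for (let t = rewards.length - 1; t >= 0; --t) {
            const nextValue = t === rewards.length - 1 ? lastValue : values[t + 1]

            // TD residual: r_t + gamma * V(s_{t+1}) - V(s_t)
            const delta = rewards[t] + gamma * nextValue - values[t]

            gae = delta + gamma * lambda * gae
            advantages[t] = gae
        }

        return advantages
    }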

private trainValue(observationBuffer: tf.Tensor2D, returnBuffer: tf.Tensor1D) {
const optimize = () => {
const valuesPredictions = this.critic.predict(observationBuffer) as tf.Tensor1D
return tf.losses.meanSquaredError(returnBuffer, valuesPredictions) as tf.Scalar
}

tf.tidy(() => {
const { grads } = this.optimizerValue.computeGradients(optimize)
this.optimizerValue.applyGradients(grads)
})
}
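
    // Illustrative sketch, not part of this diff: the discounted returns that
    // trainValue regresses the critic towards are usually computed per
    // trajectory as rewards-to-go.
    private discountedReturns(rewards: number[], gamma: number): number[] {
        const returns = new Array<number>(rewards.length).fill(0)
        let acc = 0

        for (let t = rewards.length - 1; t >= 0; --t) {
            acc = rewards[t] + gamma * acc
            returns[t] = acc
        }

        return returns
    }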

private trainPolicy(
observationBuffer: tf.Tensor2D,
actionBuffer: tf.Tensor2D,
logProbabilityBuffer: tf.Tensor1D,
advantageBuffer: tf.Tensor1D,
) {
const optimize = () => {
const predictions = this.actor.predict(observationBuffer) as tf.Tensor2D

const logProbDiff = tf.sub(
this.logProb(predictions, actionBuffer),
logProbabilityBuffer,
)

const ratio = tf.exp(logProbDiff)

const minAdvantage = tf.where(
tf.greater(advantageBuffer, 0),
tf.mul(tf.add(1, this.config.clipRatio), advantageBuffer),
tf.mul(tf.sub(1, this.config.clipRatio), advantageBuffer),
)

const policyLoss = tf.neg(
tf.mean(tf.minimum(tf.mul(ratio, advantageBuffer), minAdvantage)),
)

return policyLoss as tf.Scalar
}

return tf.tidy(() => {
const { grads } = this.optimizerPolicy.computeGradients(optimize)
this.optimizerPolicy.applyGradients(grads)

const kl = tf.mean(
tf.sub(
logProbabilityBuffer,
this.logProb(
this.actor.predict(observationBuffer) as tf.Tensor2D,
actionBuffer,
),
),
)

return kl.arraySync()
})
}
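
    // Illustrative note, not part of this diff: the optimize closure above is
    // the clipped PPO surrogate
    //     L = -E[ min(r * A, clip(r, 1 - eps, 1 + eps) * A) ]
    // with ratio r = pi(a|s) / pi_old(a|s) and eps = clipRatio. Because the
    // sign of the advantage A decides which clip bound can bind, the tf.where
    // over advantageBuffer is equivalent to the usual clip formulation. The
    // returned value approximates the KL divergence E[log pi_old - log pi]
    // and lets the caller stop updating once the policy drifts too far.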

private logProb(predictions: tf.Tensor2D, actions: tf.Tensor2D) {
if (this.config.actionSpace.class === "Discrete") {
return this.logProbCategorical(predictions, actions)
} else if (this.config.actionSpace.class === "Box") {
return this.logProbNormal(predictions, actions)
} else {
throw new Error("Unsupported action space")
}
}

private logProbCategorical(predictions: tf.Tensor2D, actions: tf.Tensor2D) {
return tf.tidy(() => {
const numActions = predictions.shape[predictions.shape.length - 1]
const logprobabilitiesAll = tf.logSoftmax(predictions)

return tf.sum(
tf.mul(tf.oneHot(actions, numActions), logprobabilitiesAll),
logprobabilitiesAll.shape.length - 1,
)
})
}
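
    // Worked example (illustrative, not part of this diff): for logits
    // [2.0, 1.0, 0.1], logSoftmax yields roughly [-0.42, -1.42, -2.32].
    // One-hot selecting action 0 and summing over the last axis then gives a
    // log probability of about -0.42, i.e. P(action 0) ~ 0.66.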

private logProbNormal(predictions: tf.Tensor2D, actions: tf.Tensor2D) {
return tf.tidy(() => {
if (this.logStd === undefined) {
throw new Error("logStd is not initialized")
}

const scale = tf.exp(this.logStd)

const logUnnormalized = tf.mul(
-0.5,
tf.square(tf.sub(tf.div(actions, scale), tf.div(predictions, scale))),
)

- private sampleAction(observation: tf.Tensor2D): [tf.Tensor2D, tf.Tensor2D] {}
const logNormalization = tf.add(tf.scalar(0.5 * Math.log(2.0 * Math.PI)), tf.log(scale))

return tf.sum(
tf.sub(logUnnormalized, logNormalization),
logUnnormalized.shape.length - 1,
)
})
}
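
    // Illustrative note, not part of this diff: the method above evaluates the
    // diagonal Gaussian log-density
    //     log N(a | mu, sigma) = -0.5 * ((a - mu) / sigma)^2
    //                            - log(sigma) - 0.5 * log(2 * pi)
    // summed over action dimensions, with mu = predictions and
    // sigma = exp(logStd).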

private sampleAction(observation: tf.Tensor2D) {
return tf.tidy(() => {
const predictions = tf.squeeze(
this.actor.predict(observation) as tf.Tensor2D,
) as tf.Tensor1D

const actionSpace = this.config.actionSpace

if (actionSpace.class === "Discrete") {
const action = tf.squeeze(tf.multinomial(predictions, 1)) as tf.Scalar
const actionSynced = action.arraySync()

return [predictions, action, actionSynced] as const
} else if (actionSpace.class === "Box") {
if (this.logStd === undefined) {
throw new Error("logStd is not initialized")
}

const action = tf.add(
tf.mul(tf.randomNormal([actionSpace.len]), tf.exp(this.logStd)),
predictions,
) as tf.Tensor1D

const actionClipped = action.arraySync().map((x, i) => {
const low =
typeof actionSpace.low === "number" ? actionSpace.low : actionSpace.low[i]
const high =
typeof actionSpace.high === "number"
? actionSpace.high
: actionSpace.high[i]

return Math.min(Math.max(x, low), high)
})

return [predictions, action, actionClipped] as const
} else {
throw new Error("Unsupported action space")
}
})
}
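
    // Illustrative sketch, not part of this diff: a learn() driver in the
    // style of other PPO implementations, wiring the private pieces above
    // together. The buffer accessor (get()) and the trainPolicyIterations /
    // trainValueIterations config fields are assumed names, and the returned
    // KL estimate is treated as a plain number.
    learn(epochs: number) {
        for (let epoch = 0; epoch < epochs; ++epoch) {
            // Gather config.steps transitions with the current policy
            this.collectRollouts()

            const [observations, actions, advantages, returns, logProbabilities] =
                this.buffer.get()

            // Several clipped-surrogate policy epochs, early-stopping once the
            // approximate KL divergence drifts past 1.5 * targetKL
            for (let i = 0; i < this.config.trainPolicyIterations; ++i) {
                const kl = this.trainPolicy(
                    observations,
                    actions,
                    logProbabilities,
                    advantages,
                ) as number

                if (kl > 1.5 * this.config.targetKL) {
                    break
                }
            }

            // Fit the critic to the discounted returns
            for (let i = 0; i < this.config.trainValueIterations; ++i) {
                this.trainValue(observations, returns)
            }
        }
    }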
}
