Skip to content

Commit

Permalink
Finish ppo with convolution prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
phisn committed May 7, 2024
1 parent 9a0b2fb commit a15dfff
Show file tree
Hide file tree
Showing 22 changed files with 620 additions and 529 deletions.
246 changes: 246 additions & 0 deletions packages/learning-gym/src/game-environment.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
import RAPIER from "custom-rapier2d-node/rapier"
import * as gl from "gl"
import { PNG } from "pngjs"
import { EntityWith, MessageCollector } from "runtime-framework"
import { WorldModel } from "runtime/proto/world"
import { LevelCapturedMessage } from "runtime/src/core/level-capture/level-captured-message"
import { RuntimeComponents } from "runtime/src/core/runtime-components"
import { RuntimeSystemContext } from "runtime/src/core/runtime-system-stack"
import { Runtime, newRuntime } from "runtime/src/runtime"
import * as THREE from "three"
import { GameAgentWrapper } from "web-game/src/game/game-agent-wrapper"
import { Reward, RewardFactory } from "../../web-game/src/game/reward/default-reward"

/**
 * Rendering and stepping options for {@link GameEnvironment}.
 */
export interface GameEnvironmentConfig {
/** When true the image observation holds one byte per pixel, otherwise three (RGB). */
grayScale: boolean
/** Width and height, in pixels, of the square render surface and observation image. */
size: number
/**
 * Divides half of `size` to produce the value handed to `GameAgentWrapper`.
 * NOTE(review): presumably the number of rendered pixels per world unit — confirm.
 */
pixelsPerUnit: number
/** Number of game steps executed for every call to `GameEnvironment.step`. */
stepsPerFrame: number
}

/**
 * Environment that drives the game runtime one action at a time and exposes two raw byte
 * observations after every transition:
 *
 * - an image of `size * size` pixels with 1 (grayscale) or 3 (RGB) bytes per pixel, stored
 *   top row first, and
 * - a feature vector of six little-endian float32 values.
 *
 * Rendering goes through a `THREE.WebGLRenderer` that is given a stub canvas object and a
 * context created by the `gl` package.
 *
 * Both observation buffers are allocated once and overwritten in place, so the buffers
 * returned by `reset` and `step` are the same objects on every call. Copy them if an
 * observation has to outlive the next transition.
 */
export class GameEnvironment {
    /** Number of float32 entries written into the feature observation. */
    private static readonly featureCount = 6

    /** Amount added to / subtracted from the commanded rotation by a single turn action. */
    private static readonly rotationStep = 0.1

    private readonly observationImageBuffer: Buffer
    private readonly observationFeatureBuffer: Buffer
    private readonly imageBuffer: Buffer
    private readonly imageChannels: number

    private runtime!: Runtime
    private reward!: Reward
    private game!: GameAgentWrapper
    private readonly renderer: THREE.WebGLRenderer

    private rotation!: number
    private rocket!: EntityWith<RuntimeComponents, "rocket" | "rigidBody">
    private targetFlag!: EntityWith<RuntimeComponents, "level">
    private capturedCollector!: MessageCollector<LevelCapturedMessage>

    private readonly png: PNG

    /**
     * Allocates the observation buffers, creates the renderer and starts the first episode.
     *
     * @param world world model handed to `newRuntime` on every reset
     * @param gamemode candidate gamemode names; one is picked at random on every reset
     * @param config rendering and stepping options
     * @param rewardFactory creates the reward tracker for each new runtime
     */
    constructor(
        private world: WorldModel,
        private gamemode: string[],
        private config: GameEnvironmentConfig,
        private rewardFactory: RewardFactory,
    ) {
        const pixelCount = config.size * config.size

        this.imageChannels = config.grayScale ? 1 : 3

        // Feature layout, one float32 (4 bytes) each:
        //   0: velocity x        4: velocity y        8: rotation
        //  12: distance to flag x   16: distance to flag y   20: flag in capture (0 or 1)
        this.observationFeatureBuffer = Buffer.alloc(4 * GameEnvironment.featureCount)

        // Observation image: 1 channel when grayscale, 3 channels otherwise.
        this.observationImageBuffer = Buffer.alloc(pixelCount * this.imageChannels)

        // The pixels read back from the renderer are RGBA, hence 4 bytes per pixel.
        this.imageBuffer = Buffer.alloc(pixelCount * 4)

        this.png = new PNG({
            width: config.size,
            height: config.size,
        })

        // Minimal stand-in for a canvas element: dimensions plus no-op event registration.
        const canvas = {
            width: config.size,
            height: config.size,
            addEventListener: () => {},
            removeEventListener: () => {},
        }

        this.renderer = new THREE.WebGLRenderer({
            canvas: canvas as any,
            antialias: false,
            powerPreference: "high-performance",
            context: gl.default(config.size, config.size, {
                preserveDrawingBuffer: true,
            }),
            depth: false,
        })

        const renderTarget = new THREE.WebGLRenderTarget(config.size, config.size)
        this.renderer.setRenderTarget(renderTarget)

        this.reset()
    }

    /**
     * Starts a new episode: builds a fresh runtime with a randomly chosen gamemode, resets the
     * commanded rotation to 0 and recreates the reward tracker.
     *
     * The scene is rendered before the pixels are read back, so the returned image shows the
     * new episode instead of whatever the previous render left behind.
     *
     * @returns `[image observation, feature observation]`
     */
    reset(): [Buffer, Buffer] {
        const gamemodeIndex = Math.floor(Math.random() * this.gamemode.length)

        this.runtime = newRuntime(RAPIER as any, this.world, this.gamemode[gamemodeIndex])

        this.game = new GameAgentWrapper(
            this.runtime,
            new THREE.Scene() as any,
            this.config.grayScale,
            (0.5 * this.config.size) / this.config.pixelsPerUnit,
        )

        this.rocket = this.runtime.factoryContext.store.find("rocket", "rigidBody")[0]
        this.capturedCollector = this.runtime.factoryContext.messageStore.collect("levelCaptured")
        this.targetFlag = nextFlag(this.runtime, this.rocket)
        this.rotation = 0
        this.reward = this.rewardFactory(this.runtime)

        this.renderScene()
        this.extractPixelsToObservationBuffer()
        this.prepareFeatureBuffer()

        return [this.observationImageBuffer, this.observationFeatureBuffer]
    }

    /**
     * Applies one action for `config.stepsPerFrame` game steps, renders the result and
     * refreshes both observations.
     *
     * @param action buffer whose first byte, read as a signed 8-bit integer, is the action id
     *   (see `stepWithActionToInput`)
     * @returns `[reward, done, image observation, feature observation]`
     * @throws Error if the action id is not in the range 0 to 5
     */
    step(action: Buffer): [number, boolean, Buffer, Buffer] {
        const input = this.stepWithActionToInput(action.readInt8(0))

        // The reward tracker runs the simulation itself through this callback.
        const [reward, done] = this.reward.next(() => {
            for (let i = 0; i < this.config.stepsPerFrame; ++i) {
                this.game.step(input)
            }
        })

        this.renderScene()
        this.extractPixelsToObservationBuffer()
        this.prepareFeatureBuffer()

        return [reward, done, this.observationImageBuffer, this.observationFeatureBuffer]
    }

    /**
     * Translates a discrete action id into a runtime input and updates the commanded rotation.
     *
     * | id | thrust | rotation change |
     * |----|--------|-----------------|
     * | 0  | off    | none            |
     * | 1  | off    | +0.1            |
     * | 2  | off    | -0.1            |
     * | 3  | on     | none            |
     * | 4  | on     | +0.1            |
     * | 5  | on     | -0.1            |
     *
     * @throws Error for any other id; the rotation is left untouched in that case
     */
    stepWithActionToInput(action: number): RuntimeSystemContext {
        if (!Number.isInteger(action) || action < 0 || action > 5) {
            throw new Error(`Invalid action: ${action}`)
        }

        // Ids 0-2 and 3-5 share the same turn pattern; only the thrust flag differs.
        const turn = action % 3

        if (turn === 1) {
            this.rotation += GameEnvironment.rotationStep
        } else if (turn === 2) {
            this.rotation -= GameEnvironment.rotationStep
        }

        return { thrust: action >= 3, rotation: this.rotation }
    }

    /**
     * Reads the rendered RGBA pixels back from the renderer's context and converts them into
     * the observation image: rows are flipped vertically, the alpha channel is dropped and,
     * in grayscale mode, the three colour channels are collapsed into one byte.
     */
    extractPixelsToObservationBuffer() {
        const context = this.renderer.getContext()

        context.readPixels(
            0,
            0,
            context.drawingBufferWidth,
            context.drawingBufferHeight,
            context.RGBA,
            context.UNSIGNED_BYTE,
            this.imageBuffer,
        )

        const { size, grayScale } = this.config

        // The framebuffer coordinate space has (0, 0) in the bottom left, whereas images
        // usually have (0, 0) at the top left, so target row `row` comes from source row
        // `size - row - 1`.
        for (let row = 0; row < size; row++) {
            const sourceRowStart = (size - row - 1) * size
            const targetRowStart = row * size

            for (let column = 0; column < size; column++) {
                const source = (sourceRowStart + column) * 4
                const red = this.imageBuffer[source]
                const green = this.imageBuffer[source + 1]
                const blue = this.imageBuffer[source + 2]

                if (grayScale) {
                    // Cheap grayscale conversion: a bitwise OR of the channels, not a
                    // weighted luminance.
                    this.observationImageBuffer[targetRowStart + column] = red | green | blue
                } else {
                    const target = (targetRowStart + column) * 3

                    this.observationImageBuffer[target] = red
                    this.observationImageBuffer[target + 1] = green
                    this.observationImageBuffer[target + 2] = blue
                }
            }
        }
    }

    /**
     * Refreshes the target flag if a level was captured since the last call and writes the
     * six float32 features (velocity x/y, rotation, distance to flag x/y, flag in capture)
     * into the feature observation.
     */
    prepareFeatureBuffer() {
        // Walk the whole collector, exactly as before, but only remember whether anything
        // was captured: the next flag depends on the current world state, not on the
        // individual messages, so one lookup is enough.
        let capturedCount = 0
        for (const _message of this.capturedCollector) {
            capturedCount += 1
        }

        // `nextFlag` cannot pick a flag when every level is already captured, so in that
        // case the last target is kept and the features stay relative to it.
        if (capturedCount > 0 && this.hasUncapturedLevel()) {
            this.targetFlag = nextFlag(this.runtime, this.rocket)
        }

        const rigidBody = this.rocket.components.rigidBody
        const position = rigidBody.translation()
        const velocity = rigidBody.linvel()
        const level = this.targetFlag.components.level

        const features = this.observationFeatureBuffer
        features.writeFloatLE(velocity.x, 0)
        features.writeFloatLE(velocity.y, 4)
        features.writeFloatLE(this.rotation, 8)
        features.writeFloatLE(position.x - level.flag.x, 12)
        features.writeFloatLE(position.y - level.flag.y, 16)
        features.writeFloatLE(level.inCapture ? 1 : 0, 20)
    }

    /**
     * Encodes the current image observation as a PNG.
     *
     * The observation is copied to the start of the PNG's data buffer and the encoder is
     * told, through `inputColorType` and `inputHasAlpha`, that the input is grayscale (0) or
     * RGB (2) without an alpha channel.
     *
     * @returns the encoded PNG file contents
     */
    generatePng(): Buffer {
        this.png.data.set(this.observationImageBuffer)

        return PNG.sync.write(this.png, {
            inputColorType: this.config.grayScale ? 0 : 2,
            inputHasAlpha: false,
        })
    }

    /** Renders the game's current scene with the game's camera. */
    private renderScene() {
        this.renderer.render(this.game.sceneModule.getScene() as any, this.game.camera as any)
    }

    /** Tells whether the runtime still contains a level that has not been captured. */
    private hasUncapturedLevel() {
        return this.runtime.factoryContext.store
            .find("level")
            .some(level => !level.components.level.captured)
    }
}

/**
 * Finds the level that has not been captured yet and whose flag is closest (straight-line
 * distance) to the rocket's current position.
 *
 * When several flags are equally close, the one that comes first in the store wins.
 *
 * @param runtime runtime whose store is searched for "level" entities
 * @param rocket rocket entity whose rigid body position is the reference point
 * @returns the closest uncaptured level entity
 * @throws Error when the store contains no uncaptured level
 */
function nextFlag(runtime: Runtime, rocket: EntityWith<RuntimeComponents, "rocket" | "rigidBody">) {
    // The position does not change during the search, so read it once.
    const rocketPosition = rocket.components.rigidBody.translation()

    let closest: EntityWith<RuntimeComponents, "level"> | undefined
    let closestDistance = 0

    for (const level of runtime.factoryContext.store.find("level")) {
        if (level.components.level.captured) {
            continue
        }

        const dx = rocketPosition.x - level.components.level.flag.x
        const dy = rocketPosition.y - level.components.level.flag.y
        const distance = Math.sqrt(dx * dx + dy * dy)

        // The first candidate is always accepted; later ones only when strictly closer.
        if (closest === undefined || distance < closestDistance) {
            closest = level
            closestDistance = distance
        }
    }

    if (closest === undefined) {
        throw new Error("nextFlag: no uncaptured level left to target")
    }

    return closest
}

// Install a minimal `navigator` global whose only property is a user agent of "node".
// NOTE(review): presumably required because a dependency reads `navigator.userAgent`
// outside of a browser — confirm which one. Runs once, when this module is evaluated.
global.navigator = { userAgent: "node" } as any
Loading

0 comments on commit a15dfff

Please sign in to comment.