From 00a0b3b2ea5ef82e7dbad3204aef656219f1c0f4 Mon Sep 17 00:00:00 2001 From: Marcus Ramse Date: Sun, 28 Apr 2024 18:46:11 +0000 Subject: [PATCH] netplay: patch export globals and sync them --- runtimes/web/src/runtime.ts | 4 +- runtimes/web/src/state.ts | 56 ++++++++++++---- runtimes/web/src/wasm-patch.ts | 117 +++++++++++++++++++++++++++++++++ 3 files changed, 165 insertions(+), 12 deletions(-) create mode 100644 runtimes/web/src/wasm-patch.ts diff --git a/runtimes/web/src/runtime.ts b/runtimes/web/src/runtime.ts index 043c9ec3..515cc151 100644 --- a/runtimes/web/src/runtime.ts +++ b/runtimes/web/src/runtime.ts @@ -4,6 +4,7 @@ import { APU } from "./apu"; import { Framebuffer } from "./framebuffer"; import { WebGLCompositor } from "./compositor"; import * as devkit from "./devkit"; +import { wasmPatchExportGlobals } from "./wasm-patch"; export class Runtime { canvas: HTMLCanvasElement; @@ -158,7 +159,8 @@ export class Runtime { }; await this.bluescreenOnError(async () => { - const module = await WebAssembly.instantiate(wasmBuffer, { env }); + const patchedWasmBuffer = wasmPatchExportGlobals(wasmBuffer); + const module = await WebAssembly.instantiate(patchedWasmBuffer, { env }); this.wasm = module.instance; // Call the WASI _start/_initialize function (different from WASM-4's start callback!) diff --git a/runtimes/web/src/state.ts b/runtimes/web/src/state.ts index 286f9979..91bfd96d 100644 --- a/runtimes/web/src/state.ts +++ b/runtimes/web/src/state.ts @@ -3,6 +3,7 @@ import { Runtime } from "./runtime"; export class State { memory: ArrayBuffer; + globals: {[name: string]: string}; diskSize: number; diskBuffer: ArrayBuffer; @@ -11,6 +12,7 @@ export class State { constructor () { this.memory = new ArrayBuffer(1 << 16); + this.globals = {}; this.diskBuffer = new ArrayBuffer(constants.STORAGE_SIZE); this.diskSize = 0; } @@ -18,6 +20,14 @@ export class State { read (runtime: Runtime) { new Uint8Array(this.memory).set(new Uint8Array(runtime.memory.buffer)); + this.globals = {}; + for (const exName in runtime.wasm!.exports) { + const exInst = runtime.wasm!.exports[exName] + if (exInst instanceof WebAssembly.Global) { + this.globals[exName] = exInst.value.toString(); // believe it or not, `toString()` seems to be safe + } + } + this.diskSize = runtime.diskSize; new Uint8Array(this.diskBuffer).set(new Uint8Array(runtime.diskBuffer, 0, runtime.diskSize)); } @@ -25,32 +35,56 @@ export class State { write (runtime: Runtime) { new Uint8Array(runtime.memory.buffer).set(new Uint8Array(this.memory)); + for (const exName in runtime.wasm!.exports) { + const exInst = runtime.wasm!.exports[exName] + if (exInst instanceof WebAssembly.Global && exName in this.globals) { + exInst.value = this.globals[exName]; + } + } + runtime.diskSize = this.diskSize; new Uint8Array(runtime.diskBuffer).set(new Uint8Array(this.diskBuffer, 0, this.diskSize)); } - toBytes (dest?: Uint8Array): Uint8Array { - if (!dest) { - dest = new Uint8Array((1<<16) + 4 + this.diskSize); - } + toBytes (): Uint8Array { + // Serialize globals + const globalBytes = new TextEncoder().encode(JSON.stringify(this.globals)); + // Perpare output buffer + const dest = new Uint8Array((1<<16) + 8 + globalBytes.byteLength + this.diskSize); + const dataView = new DataView(dest.buffer, dest.byteOffset, dest.byteLength); + + // Write memory dest.set(new Uint8Array(this.memory), 0); + let offset = 1<<16; - const dataView = new DataView(dest.buffer, dest.byteOffset, dest.byteLength); - dataView.setUint32(1<<16, this.diskSize); + // Write globals + dataView.setUint32(offset, globalBytes.byteLength); + dest.set(globalBytes, offset + 4); + offset += 4 + globalBytes.byteLength; - dest.set(new Uint8Array(this.diskBuffer, 0, this.diskSize), (1<<16) + 4); + // Write disk + dataView.setUint32(offset, this.diskSize); + dest.set(new Uint8Array(this.diskBuffer, 0, this.diskSize), offset + 4); return dest; } fromBytes (src: Uint8Array) { + const dataView = new DataView(src.buffer, src.byteOffset, src.byteLength); + + // Read memory new Uint8Array(this.memory).set(src.subarray(0, 1<<16)); + let offset = 1<<16; - const dataView = new DataView(src.buffer, src.byteOffset, src.byteLength); - this.diskSize = dataView.getUint32(1<<16); + // Read globals + const globalBytesSize = dataView.getUint32(offset); + const globalBytes = src.slice(offset + 4, offset + 4 + globalBytesSize) + this.globals = JSON.parse(new TextDecoder().decode(globalBytes)); + offset += 4 + globalBytesSize; - const offset = (1<<16) + 4; - new Uint8Array(this.diskBuffer).set(src.subarray(offset, offset + this.diskSize)); + // Read disk + this.diskSize = dataView.getUint32(offset); + new Uint8Array(this.diskBuffer).set(src.subarray(offset + 4, offset + 4 + this.diskSize)); } } diff --git a/runtimes/web/src/wasm-patch.ts b/runtimes/web/src/wasm-patch.ts new file mode 100644 index 00000000..e29e9d26 --- /dev/null +++ b/runtimes/web/src/wasm-patch.ts @@ -0,0 +1,117 @@ +type PatchSlice = [0, number, number] | [1, Uint8Array]; + +const WASM_SECTION_GLOBAL = 6; +const WASM_SECTION_EXPORTS = 7; +const WASM_EXPORT_GLOBAL = 3; + +export function wasmPatchExportGlobals(data: Uint8Array): Uint8Array { + // Make sure binary is valid WASM + const view = new DataView(data.buffer); + const magic = view.getUint32(0); + const version = view.getUint32(4, true); + if (magic !== 0x0061736d || version !== 0x1) { + throw new Error('Invalid WASM binary'); + } + let dataI = 8; + + // Iterate all sections and begin patching + const outputSlices: PatchSlice[] = []; + let globalCount = 0, lastCut = 0, secSize; + while (dataI < data.byteLength) { + const secType = data[dataI]; + const secRawStartI = dataI; + [secSize, dataI] = uleb128Decode(data, dataI + 1); + const secRawEndI = dataI + secSize; + let secI = dataI; + dataI += secSize; + + if (secType === WASM_SECTION_GLOBAL) { + globalCount += uleb128Decode(data, secI)[0]; + } else if (secType === WASM_SECTION_EXPORTS) { + // Push everything up until this section into output and "ignore" this section + outputSlices.push([0, lastCut, secRawStartI]); + lastCut = secRawEndI; + + // Iterate all current exports and see which globals are missing + const exportedGlobals = new Set(); + const exportSlices: PatchSlice[] = []; + let exCount, exNameLen, exIdx; + [exCount, secI] = uleb128Decode(data, secI); + for (let exI = 0; exI < exCount; exI++) { + const exStart = secI; + [exNameLen, secI] = uleb128Decode(data, secI); + secI += exNameLen; + const exType = data[secI++]; + [exIdx, secI] = uleb128Decode(data, secI); + if (exType === WASM_EXPORT_GLOBAL) { + exportedGlobals.add(exIdx); + } + exportSlices.push([0, exStart, secI]); + } + + // Add exports for missing globals + for (let glI = 0; glI < globalCount; glI++) { + if (!exportedGlobals.has(glI)) { + const nameBytes = new TextEncoder().encode(`__global_${glI}`); + const nameLenBytes = uleb128Encode(nameBytes.length); + const exIdxBytes = uleb128Encode(glI); + const exBytes = new Uint8Array(nameBytes.length + nameLenBytes.length + exIdxBytes.length + 1); + exBytes.set(nameLenBytes); + exBytes.set(nameBytes, nameLenBytes.length); + exBytes[nameLenBytes.length + nameBytes.length] = WASM_EXPORT_GLOBAL; + exBytes.set(exIdxBytes, nameLenBytes.length + nameBytes.length + 1); + exportSlices.push([1, exBytes]); + } + } + + // Push new export section + const newExCountBytes = uleb128Encode(exportSlices.length); + const newExListBytes = joinPatchSlices(data, exportSlices); + const newExSize = uleb128Encode(newExCountBytes.length + newExListBytes.length); + outputSlices.push([1, new Uint8Array([WASM_SECTION_EXPORTS])]); + outputSlices.push([1, newExSize]); + outputSlices.push([1, newExCountBytes]); + outputSlices.push([1, newExListBytes]); + } + } + // Push leftovers into output + outputSlices.push([0, lastCut, dataI]); + + return joinPatchSlices(data, outputSlices); +} + +function joinPatchSlices(source: Uint8Array, slices: PatchSlice[]): Uint8Array { + const totalSize = slices.reduce((a, v) => a + (v[0] === 0 ? v[2] - v[1] : v[1].length), 0); + const outBuf = new Uint8Array(totalSize); + let outBufI = 0; + for (const slice of slices) { + if (slice[0] === 0) { + outBuf.set(source.slice(slice[1], slice[2]), outBufI); + outBufI += slice[2] - slice[1]; + } else { + outBuf.set(slice[1], outBufI); + outBufI += slice[1].length; + } + } + return outBuf; +} + +function uleb128Encode(num: number): Uint8Array { + const output: number[] = []; + do { + const low = num & 0x7f; + num >>= 7; + output.push(num ? (low | 0x80) : low); + } while (num); + return new Uint8Array(output); +} + +function uleb128Decode(view: Uint8Array, offset: number = 0): [number, number] { + let byte = 0, result = 0, shift = 0; + do { + byte = view[offset++]; + result |= (byte & 0x7f) << shift; + shift += 7; + } while (byte & 0x80); + return [result, offset]; +}