From 450e7a60f250f5ab5dec2799230483f8ad758b04 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 16 May 2024 19:49:34 +0000 Subject: [PATCH] Basic support for calling Pathfinder --- docker/Dockerfile | 2 +- docker/make/local | 5 +- gui/src/app/StanSampler/StanSampler.ts | 2 +- gui/src/app/tinystan/Worker.ts | 23 +++- gui/src/app/tinystan/index.ts | 183 ++++++++++++++++++++++++- 5 files changed, 199 insertions(+), 16 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index cf2c31c1..fd1fc660 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -42,7 +42,7 @@ COPY make/js /app/tinystan/make/js # Build a test model RUN cd tinystan && \ echo 'include make/js' >> Makefile && \ - emmake make test_models/bernoulli/bernoulli.js -j2 && \ + emmake make test_models/bernoulli/bernoulli.js -j4 && \ emstrip test_models/bernoulli/bernoulli.wasm RUN pip install fastapi diff --git a/docker/make/local b/docker/make/local index 23cdc483..d215153b 100644 --- a/docker/make/local +++ b/docker/make/local @@ -11,11 +11,10 @@ LDLIBS_TBB ?= -ltbb # could also uses -fexceptions which is more compatible, but slower CXXFLAGS+=-fwasm-exceptions -CXXFLAGS+=-g -LDFLAGS+=-sMODULARIZE -sEXPORT_NAME=createModule -sEXPORT_ES6 -sENVIRONMENT=web +LDFLAGS+=-sMODULARIZE -sEXPORT_NAME=createModule -sEXPORT_ES6 -sENVIRONMENT=web -sINCOMING_MODULE_JS_API=print,printErr LDFLAGS+=-sEXIT_RUNTIME=1 -sALLOW_MEMORY_GROWTH=1 # Functions we want. Can add more, with a prepended _, from tinystan.h -EXPORTS=_malloc,_free,_tinystan_api_version,_tinystan_create_model,_tinystan_destroy_error,_tinystan_destroy_model,_tinystan_get_error_message,_tinystan_get_error_type,_tinystan_model_num_free_params,_tinystan_model_param_names,_tinystan_sample,_tinystan_separator_char,_tinystan_stan_version +EXPORTS=_malloc,_free,_tinystan_api_version,_tinystan_create_model,_tinystan_destroy_error,_tinystan_destroy_model,_tinystan_get_error_message,_tinystan_get_error_type,_tinystan_model_num_free_params,_tinystan_model_param_names,_tinystan_sample,_tinystan_pathfinder,_tinystan_separator_char,_tinystan_stan_version LDFLAGS+=-sEXPORTED_FUNCTIONS=$(EXPORTS) -sEXPORTED_RUNTIME_METHODS=stringToUTF8,getValue,UTF8ToString,lengthBytesUTF8 diff --git a/gui/src/app/StanSampler/StanSampler.ts b/gui/src/app/StanSampler/StanSampler.ts index e679c8cb..b4d3037d 100644 --- a/gui/src/app/StanSampler/StanSampler.ts +++ b/gui/src/app/StanSampler/StanSampler.ts @@ -41,7 +41,7 @@ class StanSampler { this.#onStatusChangedCallbacks.forEach(cb => cb()) break; } - case Replies.SampleReturn: { + case Replies.StanReturn: { if (e.data.error) { this.#errorMessage = e.data.error; this.#status = 'failed'; diff --git a/gui/src/app/tinystan/Worker.ts b/gui/src/app/tinystan/Worker.ts index c87706eb..80b8aba7 100644 --- a/gui/src/app/tinystan/Worker.ts +++ b/gui/src/app/tinystan/Worker.ts @@ -3,11 +3,12 @@ import StanModel from "."; export enum Requests { Load = "load", Sample = "sample", + Pathfinder = "pathfinder", } export enum Replies { ModelLoaded = "modelLoaded", - SampleReturn = "sampleReturn", + StanReturn = "stanReturn", Progress = "progress", } @@ -62,15 +63,29 @@ onmessage = function (e) { } case Requests.Sample: { if (!model) { - postMessage({ purpose: Replies.SampleReturn, error: "Model not loaded yet!" }) + postMessage({ purpose: Replies.StanReturn, error: "Model not loaded yet!" }) return; } try { const { paramNames, draws } = model.sample(e.data.sampleConfig); // TODO? use an ArrayBuffer so we can transfer without serialization cost - postMessage({ purpose: Replies.SampleReturn, draws, paramNames, error: null }); + postMessage({ purpose: Replies.StanReturn, draws, paramNames, error: null }); } catch (e: any) { - postMessage({ purpose: Replies.SampleReturn, error: e.toString() }) + postMessage({ purpose: Replies.StanReturn, error: e.toString() }) + } + break; + } + case Requests.Pathfinder: { + if (!model) { + postMessage({ purpose: Replies.StanReturn, error: "Model not loaded yet!" }) + return; + } + try { + const { draws, paramNames } = model.pathfinder(e.data.pathfinderConfig); + // TODO? use an ArrayBuffer so we can transfer without serialization cost + postMessage({ purpose: Replies.StanReturn, draws, paramNames, error: null }); + } catch (e: any) { + postMessage({ purpose: Replies.StanReturn, error: e.toString() }) } break; } diff --git a/gui/src/app/tinystan/index.ts b/gui/src/app/tinystan/index.ts index ccf19b6b..e2a6dc80 100644 --- a/gui/src/app/tinystan/index.ts +++ b/gui/src/app/tinystan/index.ts @@ -24,6 +24,12 @@ interface WasmModule { term_buffer: number, window: number, save_warmup: number, stepsize: number, stepsize_jitter: number, max_depth: number, refresh: number, num_threads: number, out: ptr, out_size: number, metric_out: ptr, err_ptr: ptr): number; + // prettier-ignore + _tinystan_pathfinder(model: model_ptr, num_paths: number, inits: cstr, seed: number, id: number, + init_radius: number, num_draws: number, max_history_size: number, init_alpha: number, tol_obj: number, + tol_rel_obj: number, tol_grad: number, tol_rel_grad: number, tol_param: number, num_iterations: number, + num_elbo_draws: number, num_multi_draws: number, calculate_lp: number, psis_resample: number, + refresh: number, num_threads: number, out: ptr, out_size: number, err_ptr: ptr): number; _tinystan_get_error_message(err_ptr: error_ptr): cstr; _tinystan_get_error_type(err_ptr: error_ptr): number; _tinystan_destroy_error(err_ptr: error_ptr): void; @@ -38,6 +44,8 @@ interface WasmModule { const NULL = 0 as ptr; +const PTR_SIZE = 4; + const HMC_SAMPLER_VARIABLES = [ "lp__", "accept_stat__", @@ -54,6 +62,8 @@ export enum HMCMetric { DIAGONAL = 2, } +const PATHFINDER_VARIABLES = ["lp_approx__", "lp__"]; + export type PrintCallback = (s: string) => void; export type StanDraws = { @@ -119,6 +129,59 @@ const defaultSamplerParams: SamplerParams = { num_threads: -1, }; +interface LBFGSConfig { + max_history_size: number; + init_alpha: number; + tol_obj: number; + tol_rel_obj: number; + tol_grad: number; + tol_rel_grad: number; + tol_param: number; + num_iterations: number; +} + +interface PathfinderUniqueParams { + data: string | StanVariableInputs; + num_paths: number; + inits: string | StanVariableInputs | string[] | StanVariableInputs[]; + seed: number | null; + id: number; + init_radius: number; + num_draws: number; + num_elbo_draws: number; + num_multi_draws: number; + calculate_lp: boolean; + psis_resample: boolean; + refresh: number; + num_threads: number; +} + +export type PathfinderParams = LBFGSConfig & PathfinderUniqueParams; + +const defaultPathfinderParams: PathfinderParams = { + data: "", + num_paths: 4, + inits: "", + seed: null, + id: 1, + init_radius: 2.0, + num_draws: 1000, + max_history_size: 5, + init_alpha: 0.001, + tol_obj: 1e-12, + tol_rel_obj: 1e4, + tol_grad: 1e-8, + tol_rel_grad: 1e7, + tol_param: 1e-8, + num_iterations: 1000, + num_elbo_draws: 25, + num_multi_draws: 1000, + calculate_lp: true, + psis_resample: true, + refresh: 100, + num_threads: -1, +}; + export default class StanModel { private m: WasmModule; private printCallback: PrintCallback | null; @@ -190,7 +253,7 @@ export default class StanModel { f: (model: model_ptr, deferredFree: (p: ptr | cstr) => void) => T, ): T { const data_ptr = this.encodeString(string_safe_jsonify(data)); - const err_ptr = this.m._malloc(4); + const err_ptr = this.m._malloc(PTR_SIZE); const model = this.m._tinystan_create_model(data_ptr, seed, err_ptr); this.m._free(data_ptr); @@ -248,10 +311,7 @@ export default class StanModel { throw new Error("num_samples must be at least 1"); } - let seed_ = seed; - if (seed_ === null) { - seed_ = Math.floor(Math.random() * (2 ^ 32)); - } + const seed_ = seed !== null ? seed : Math.floor(Math.random() * (2 ^ 32)); return this.withModel(data, seed_, (model, deferredFree) => { // Get the parameter names @@ -264,7 +324,7 @@ export default class StanModel { const free_params = this.m._tinystan_model_num_free_params(model); if (free_params === 0) { - throw new Error("No parameters to sample"); + throw new Error("Model has no parameters to sample."); } // TODO: allow init_inv_metric to be specified @@ -297,7 +357,7 @@ export default class StanModel { const out_ptr = this.m._malloc(n_out * Float64Array.BYTES_PER_ELEMENT); deferredFree(out_ptr); - const err_ptr = this.m._malloc(4); + const err_ptr = this.m._malloc(PTR_SIZE); deferredFree(err_ptr); // Sample from the model @@ -386,6 +446,115 @@ export default class StanModel { }); } + public pathfinder(p: Partial): StanDraws { + const { + data, + num_paths, + inits, + seed, + id, + init_radius, + num_draws, + max_history_size, + init_alpha, + tol_obj, + tol_rel_obj, + tol_grad, + tol_rel_grad, + tol_param, + num_iterations, + num_elbo_draws, + num_multi_draws, + calculate_lp, + psis_resample, + refresh, + num_threads, + } = { ...defaultPathfinderParams, ...p }; + + if (num_paths < 1) { + throw new Error("num_paths must be at least 1"); + } + if (num_draws < 1) { + throw new Error("num_draws must be at least 1"); + } + if (num_multi_draws < 1) { + throw new Error("num_multi_draws must be at least 1"); + } + + const output_rows = + calculate_lp && psis_resample ? num_multi_draws : num_draws * num_paths; + + const seed_ = seed !== null ? seed : Math.floor(Math.random() * (2 ^ 32)); + + return this.withModel(data, seed_, (model, deferredFree) => { + const rawParamNames = this.m.UTF8ToString( + this.m._tinystan_model_param_names(model), + ); + const paramNames = PATHFINDER_VARIABLES.concat(rawParamNames.split(",")); + + const n_params = paramNames.length; + + const free_params = this.m._tinystan_model_num_free_params(model); + if (free_params === 0) { + throw new Error("Model has no parameters."); + } + + const inits_ptr = this.encodeInits(inits); + deferredFree(inits_ptr); + + const n_out = output_rows * n_params; + const out = this.m._malloc(n_out * Float64Array.BYTES_PER_ELEMENT); + deferredFree(out); + const err_ptr = this.m._malloc(PTR_SIZE); + deferredFree(err_ptr); + + const result = this.m._tinystan_pathfinder( + model, + num_paths, + inits_ptr, + seed_ || 0, + id, + init_radius, + num_draws, + max_history_size, + init_alpha, + tol_obj, + tol_rel_obj, + tol_grad, + tol_rel_grad, + tol_param, + num_iterations, + num_elbo_draws, + num_multi_draws, + calculate_lp ? 1 : 0, + psis_resample ? 1 : 0, + refresh, + num_threads, + out, + n_out, + err_ptr, + ); + + if (result != 0) { + this.handleError(err_ptr); + } + + const out_buffer = this.m.HEAPF64.subarray( + out / Float64Array.BYTES_PER_ELEMENT, + out / Float64Array.BYTES_PER_ELEMENT + n_out, + ); + + const draws: number[][] = Array.from({ length: n_params }, (_, i) => + Array.from( + { length: output_rows }, + (_, j) => out_buffer[i + n_params * j], + ), + ); + + return { paramNames, draws }; + }); + } + public stanVersion(): string { const major = this.m._malloc(4); const minor = this.m._malloc(4);