Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic support for calling Pathfinder #19

Merged
merged 1 commit into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ COPY make/js /app/tinystan/make/js
# Build a test model
RUN cd tinystan && \
echo 'include make/js' >> Makefile && \
emmake make test_models/bernoulli/bernoulli.js -j2 && \
emmake make test_models/bernoulli/bernoulli.js -j4 && \
emstrip test_models/bernoulli/bernoulli.wasm

RUN pip install fastapi
Expand Down
5 changes: 2 additions & 3 deletions docker/make/local
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@ LDLIBS_TBB ?= -ltbb

# could also uses -fexceptions which is more compatible, but slower
CXXFLAGS+=-fwasm-exceptions
CXXFLAGS+=-g

LDFLAGS+=-sMODULARIZE -sEXPORT_NAME=createModule -sEXPORT_ES6 -sENVIRONMENT=web
LDFLAGS+=-sMODULARIZE -sEXPORT_NAME=createModule -sEXPORT_ES6 -sENVIRONMENT=web -sINCOMING_MODULE_JS_API=print,printErr
WardBrian marked this conversation as resolved.
Show resolved Hide resolved
LDFLAGS+=-sEXIT_RUNTIME=1 -sALLOW_MEMORY_GROWTH=1
# Functions we want. Can add more, with a prepended _, from tinystan.h
EXPORTS=_malloc,_free,_tinystan_api_version,_tinystan_create_model,_tinystan_destroy_error,_tinystan_destroy_model,_tinystan_get_error_message,_tinystan_get_error_type,_tinystan_model_num_free_params,_tinystan_model_param_names,_tinystan_sample,_tinystan_separator_char,_tinystan_stan_version
EXPORTS=_malloc,_free,_tinystan_api_version,_tinystan_create_model,_tinystan_destroy_error,_tinystan_destroy_model,_tinystan_get_error_message,_tinystan_get_error_type,_tinystan_model_num_free_params,_tinystan_model_param_names,_tinystan_sample,_tinystan_pathfinder,_tinystan_separator_char,_tinystan_stan_version
LDFLAGS+=-sEXPORTED_FUNCTIONS=$(EXPORTS) -sEXPORTED_RUNTIME_METHODS=stringToUTF8,getValue,UTF8ToString,lengthBytesUTF8

2 changes: 1 addition & 1 deletion gui/src/app/StanSampler/StanSampler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class StanSampler {
this.#onStatusChangedCallbacks.forEach(cb => cb())
break;
}
case Replies.SampleReturn: {
case Replies.StanReturn: {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my own clarification, what's the distinction being made here? Is it "pathfinder doesn't technically do sampling but will still use this verb so we don't want to call it Sample"?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, whether or not you can still call the result of an approximation "draws" or "samples" is a bit sticky.

I also think the name previously was incorrect for when the return is actually an error

if (e.data.error) {
this.#errorMessage = e.data.error;
this.#status = 'failed';
Expand Down
23 changes: 19 additions & 4 deletions gui/src/app/tinystan/Worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ import StanModel from ".";
export enum Requests {
Load = "load",
Sample = "sample",
Pathfinder = "pathfinder",
}

export enum Replies {
ModelLoaded = "modelLoaded",
SampleReturn = "sampleReturn",
StanReturn = "stanReturn",
Progress = "progress",
}

Expand Down Expand Up @@ -62,15 +63,29 @@ onmessage = function (e) {
}
case Requests.Sample: {
if (!model) {
postMessage({ purpose: Replies.SampleReturn, error: "Model not loaded yet!" })
postMessage({ purpose: Replies.StanReturn, error: "Model not loaded yet!" })
return;
}
try {
const { paramNames, draws } = model.sample(e.data.sampleConfig);
// TODO? use an ArrayBuffer so we can transfer without serialization cost
postMessage({ purpose: Replies.SampleReturn, draws, paramNames, error: null });
postMessage({ purpose: Replies.StanReturn, draws, paramNames, error: null });
} catch (e: any) {
postMessage({ purpose: Replies.SampleReturn, error: e.toString() })
postMessage({ purpose: Replies.StanReturn, error: e.toString() })
}
break;
}
case Requests.Pathfinder: {
if (!model) {
postMessage({ purpose: Replies.StanReturn, error: "Model not loaded yet!" })
return;
}
try {
const { draws, paramNames } = model.pathfinder(e.data.pathfinderConfig);
// TODO? use an ArrayBuffer so we can transfer without serialization cost
postMessage({ purpose: Replies.StanReturn, draws, paramNames, error: null });
} catch (e: any) {
postMessage({ purpose: Replies.StanReturn, error: e.toString() })
}
break;
}
Expand Down
183 changes: 176 additions & 7 deletions gui/src/app/tinystan/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ interface WasmModule {
term_buffer: number, window: number, save_warmup: number, stepsize: number, stepsize_jitter: number,
max_depth: number, refresh: number, num_threads: number, out: ptr, out_size: number, metric_out: ptr,
err_ptr: ptr): number;
// prettier-ignore
_tinystan_pathfinder(model: model_ptr, num_paths: number, inits: cstr, seed: number, id: number,
init_radius: number, num_draws: number, max_history_size: number, init_alpha: number, tol_obj: number,
tol_rel_obj: number, tol_grad: number, tol_rel_grad: number, tol_param: number, num_iterations: number,
num_elbo_draws: number, num_multi_draws: number, calculate_lp: number, psis_resample: number,
refresh: number, num_threads: number, out: ptr, out_size: number, err_ptr: ptr): number;
_tinystan_get_error_message(err_ptr: error_ptr): cstr;
_tinystan_get_error_type(err_ptr: error_ptr): number;
_tinystan_destroy_error(err_ptr: error_ptr): void;
Expand All @@ -38,6 +44,8 @@ interface WasmModule {

const NULL = 0 as ptr;

const PTR_SIZE = 4;

const HMC_SAMPLER_VARIABLES = [
"lp__",
"accept_stat__",
Expand All @@ -54,6 +62,8 @@ export enum HMCMetric {
DIAGONAL = 2,
}

const PATHFINDER_VARIABLES = ["lp_approx__", "lp__"];

export type PrintCallback = (s: string) => void;

export type StanDraws = {
Expand Down Expand Up @@ -119,6 +129,59 @@ const defaultSamplerParams: SamplerParams = {
num_threads: -1,
};

interface LBFGSConfig {
max_history_size: number;
init_alpha: number;
tol_obj: number;
tol_rel_obj: number;
tol_grad: number;
tol_rel_grad: number;
tol_param: number;
num_iterations: number;
}

interface PathfinderUniqueParams {
data: string | StanVariableInputs;
num_paths: number;
inits: string | StanVariableInputs | string[] | StanVariableInputs[];
seed: number | null;
id: number;
init_radius: number;
num_draws: number;
num_elbo_draws: number;
num_multi_draws: number;
calculate_lp: boolean;
psis_resample: boolean;
refresh: number;
num_threads: number;
}

export type PathfinderParams = LBFGSConfig & PathfinderUniqueParams;

const defaultPathfinderParams: PathfinderParams = {
data: "",
num_paths: 4,
inits: "",
seed: null,
id: 1,
init_radius: 2.0,
num_draws: 1000,
max_history_size: 5,
init_alpha: 0.001,
tol_obj: 1e-12,
tol_rel_obj: 1e4,
tol_grad: 1e-8,
tol_rel_grad: 1e7,
tol_param: 1e-8,
num_iterations: 1000,
num_elbo_draws: 25,
num_multi_draws: 1000,
calculate_lp: true,
psis_resample: true,
refresh: 100,
num_threads: -1,
};

export default class StanModel {
private m: WasmModule;
private printCallback: PrintCallback | null;
Expand Down Expand Up @@ -190,7 +253,7 @@ export default class StanModel {
f: (model: model_ptr, deferredFree: (p: ptr | cstr) => void) => T,
): T {
const data_ptr = this.encodeString(string_safe_jsonify(data));
const err_ptr = this.m._malloc(4);
const err_ptr = this.m._malloc(PTR_SIZE);
const model = this.m._tinystan_create_model(data_ptr, seed, err_ptr);
this.m._free(data_ptr);

Expand Down Expand Up @@ -248,10 +311,7 @@ export default class StanModel {
throw new Error("num_samples must be at least 1");
}

let seed_ = seed;
if (seed_ === null) {
seed_ = Math.floor(Math.random() * (2 ^ 32));
}
const seed_ = seed !== null ? seed : Math.floor(Math.random() * (2 ^ 32));

return this.withModel(data, seed_, (model, deferredFree) => {
// Get the parameter names
Expand All @@ -264,7 +324,7 @@ export default class StanModel {

const free_params = this.m._tinystan_model_num_free_params(model);
if (free_params === 0) {
throw new Error("No parameters to sample");
throw new Error("Model has no parameters to sample.");
}

// TODO: allow init_inv_metric to be specified
Expand Down Expand Up @@ -297,7 +357,7 @@ export default class StanModel {
const out_ptr = this.m._malloc(n_out * Float64Array.BYTES_PER_ELEMENT);
deferredFree(out_ptr);

const err_ptr = this.m._malloc(4);
const err_ptr = this.m._malloc(PTR_SIZE);
deferredFree(err_ptr);

// Sample from the model
Expand Down Expand Up @@ -386,6 +446,115 @@ export default class StanModel {
});
}

public pathfinder(p: Partial<PathfinderParams>): StanDraws {
const {
data,
num_paths,
inits,
seed,
id,
init_radius,
num_draws,
max_history_size,
init_alpha,
tol_obj,
tol_rel_obj,
tol_grad,
tol_rel_grad,
tol_param,
num_iterations,
num_elbo_draws,
num_multi_draws,
calculate_lp,
psis_resample,
refresh,
num_threads,
} = { ...defaultPathfinderParams, ...p };

if (num_paths < 1) {
throw new Error("num_paths must be at least 1");
}
if (num_draws < 1) {
throw new Error("num_draws must be at least 1");
}
if (num_multi_draws < 1) {
throw new Error("num_multi_draws must be at least 1");
}

const output_rows =
calculate_lp && psis_resample ? num_multi_draws : num_draws * num_paths;

const seed_ = seed !== null ? seed : Math.floor(Math.random() * (2 ^ 32));

return this.withModel(data, seed_, (model, deferredFree) => {
const rawParamNames = this.m.UTF8ToString(
this.m._tinystan_model_param_names(model),
);
const paramNames = PATHFINDER_VARIABLES.concat(rawParamNames.split(","));

const n_params = paramNames.length;

const free_params = this.m._tinystan_model_num_free_params(model);
if (free_params === 0) {
throw new Error("Model has no parameters.");
}

const inits_ptr = this.encodeInits(inits);
deferredFree(inits_ptr);

const n_out = output_rows * n_params;
const out = this.m._malloc(n_out * Float64Array.BYTES_PER_ELEMENT);
deferredFree(out);
const err_ptr = this.m._malloc(PTR_SIZE);
deferredFree(err_ptr);

const result = this.m._tinystan_pathfinder(
model,
num_paths,
inits_ptr,
seed_ || 0,
id,
init_radius,
num_draws,
max_history_size,
init_alpha,
tol_obj,
tol_rel_obj,
tol_grad,
tol_rel_grad,
tol_param,
num_iterations,
num_elbo_draws,
num_multi_draws,
calculate_lp ? 1 : 0,
psis_resample ? 1 : 0,
refresh,
num_threads,
out,
n_out,
err_ptr,
);

if (result != 0) {
this.handleError(err_ptr);
}

const out_buffer = this.m.HEAPF64.subarray(
out / Float64Array.BYTES_PER_ELEMENT,
out / Float64Array.BYTES_PER_ELEMENT + n_out,
);

const draws: number[][] = Array.from({ length: n_params }, (_, i) =>
Array.from(
{ length: output_rows },
(_, j) => out_buffer[i + n_params * j],
),
);

return { paramNames, draws };
});
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might want to convert the _malloc(4) calls below to use PTR_SIZE instead, since you added it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the 4 in those mallocs is a different 4 than the ones I changed here (it is not necessarily true that sizeof(int*) == sizeof(int), in fact on x64 systems it is false)

So I could define an INT_SIZE const variable, but those are also the only places I think the API will ever need to malloc an int (in retrospect, I could use Int32Array.BYTES_PER_ELEMENT as something more self-documenting, but also the fact that Int32 is in the name does pretty strongly imply the answer will be 4...)

public stanVersion(): string {
const major = this.m._malloc(4);
const minor = this.m._malloc(4);
Expand Down