-
Notifications
You must be signed in to change notification settings - Fork 1
Basic support for calling Pathfinder #19
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,7 +41,7 @@ class StanSampler { | |
this.#onStatusChangedCallbacks.forEach(cb => cb()) | ||
break; | ||
} | ||
case Replies.SampleReturn: { | ||
case Replies.StanReturn: { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For my own clarification, what's the distinction being made here? Is it "pathfinder doesn't technically do sampling but will still use this verb so we don't want to call it There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, whether or not you can still call the result of an approximation "draws" or "samples" is a bit sticky. I also think the name previously was incorrect for when the return is actually an error |
||
if (e.data.error) { | ||
this.#errorMessage = e.data.error; | ||
this.#status = 'failed'; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,6 +24,12 @@ interface WasmModule { | |
term_buffer: number, window: number, save_warmup: number, stepsize: number, stepsize_jitter: number, | ||
max_depth: number, refresh: number, num_threads: number, out: ptr, out_size: number, metric_out: ptr, | ||
err_ptr: ptr): number; | ||
// prettier-ignore | ||
_tinystan_pathfinder(model: model_ptr, num_paths: number, inits: cstr, seed: number, id: number, | ||
init_radius: number, num_draws: number, max_history_size: number, init_alpha: number, tol_obj: number, | ||
tol_rel_obj: number, tol_grad: number, tol_rel_grad: number, tol_param: number, num_iterations: number, | ||
num_elbo_draws: number, num_multi_draws: number, calculate_lp: number, psis_resample: number, | ||
refresh: number, num_threads: number, out: ptr, out_size: number, err_ptr: ptr): number; | ||
_tinystan_get_error_message(err_ptr: error_ptr): cstr; | ||
_tinystan_get_error_type(err_ptr: error_ptr): number; | ||
_tinystan_destroy_error(err_ptr: error_ptr): void; | ||
|
@@ -38,6 +44,8 @@ interface WasmModule { | |
|
||
const NULL = 0 as ptr; | ||
|
||
const PTR_SIZE = 4; | ||
|
||
const HMC_SAMPLER_VARIABLES = [ | ||
"lp__", | ||
"accept_stat__", | ||
|
@@ -54,6 +62,8 @@ export enum HMCMetric { | |
DIAGONAL = 2, | ||
} | ||
|
||
const PATHFINDER_VARIABLES = ["lp_approx__", "lp__"]; | ||
|
||
export type PrintCallback = (s: string) => void; | ||
|
||
export type StanDraws = { | ||
|
@@ -119,6 +129,59 @@ const defaultSamplerParams: SamplerParams = { | |
num_threads: -1, | ||
}; | ||
|
||
interface LBFGSConfig { | ||
max_history_size: number; | ||
init_alpha: number; | ||
tol_obj: number; | ||
tol_rel_obj: number; | ||
tol_grad: number; | ||
tol_rel_grad: number; | ||
tol_param: number; | ||
num_iterations: number; | ||
} | ||
|
||
interface PathfinderUniqueParams { | ||
data: string | StanVariableInputs; | ||
num_paths: number; | ||
inits: string | StanVariableInputs | string[] | StanVariableInputs[]; | ||
seed: number | null; | ||
id: number; | ||
init_radius: number; | ||
num_draws: number; | ||
num_elbo_draws: number; | ||
num_multi_draws: number; | ||
calculate_lp: boolean; | ||
psis_resample: boolean; | ||
refresh: number; | ||
num_threads: number; | ||
} | ||
|
||
export type PathfinderParams = LBFGSConfig & PathfinderUniqueParams; | ||
|
||
const defaultPathfinderParams: PathfinderParams = { | ||
data: "", | ||
num_paths: 4, | ||
inits: "", | ||
seed: null, | ||
id: 1, | ||
init_radius: 2.0, | ||
num_draws: 1000, | ||
max_history_size: 5, | ||
init_alpha: 0.001, | ||
tol_obj: 1e-12, | ||
tol_rel_obj: 1e4, | ||
tol_grad: 1e-8, | ||
tol_rel_grad: 1e7, | ||
tol_param: 1e-8, | ||
num_iterations: 1000, | ||
num_elbo_draws: 25, | ||
num_multi_draws: 1000, | ||
calculate_lp: true, | ||
psis_resample: true, | ||
refresh: 100, | ||
num_threads: -1, | ||
}; | ||
|
||
export default class StanModel { | ||
private m: WasmModule; | ||
private printCallback: PrintCallback | null; | ||
|
@@ -190,7 +253,7 @@ export default class StanModel { | |
f: (model: model_ptr, deferredFree: (p: ptr | cstr) => void) => T, | ||
): T { | ||
const data_ptr = this.encodeString(string_safe_jsonify(data)); | ||
const err_ptr = this.m._malloc(4); | ||
const err_ptr = this.m._malloc(PTR_SIZE); | ||
const model = this.m._tinystan_create_model(data_ptr, seed, err_ptr); | ||
this.m._free(data_ptr); | ||
|
||
|
@@ -248,10 +311,7 @@ export default class StanModel { | |
throw new Error("num_samples must be at least 1"); | ||
} | ||
|
||
let seed_ = seed; | ||
if (seed_ === null) { | ||
seed_ = Math.floor(Math.random() * (2 ^ 32)); | ||
} | ||
const seed_ = seed !== null ? seed : Math.floor(Math.random() * (2 ^ 32)); | ||
|
||
return this.withModel(data, seed_, (model, deferredFree) => { | ||
// Get the parameter names | ||
|
@@ -264,7 +324,7 @@ export default class StanModel { | |
|
||
const free_params = this.m._tinystan_model_num_free_params(model); | ||
if (free_params === 0) { | ||
throw new Error("No parameters to sample"); | ||
throw new Error("Model has no parameters to sample."); | ||
} | ||
|
||
// TODO: allow init_inv_metric to be specified | ||
|
@@ -297,7 +357,7 @@ export default class StanModel { | |
const out_ptr = this.m._malloc(n_out * Float64Array.BYTES_PER_ELEMENT); | ||
deferredFree(out_ptr); | ||
|
||
const err_ptr = this.m._malloc(4); | ||
const err_ptr = this.m._malloc(PTR_SIZE); | ||
deferredFree(err_ptr); | ||
|
||
// Sample from the model | ||
|
@@ -386,6 +446,115 @@ export default class StanModel { | |
}); | ||
} | ||
|
||
public pathfinder(p: Partial<PathfinderParams>): StanDraws { | ||
const { | ||
data, | ||
num_paths, | ||
inits, | ||
seed, | ||
id, | ||
init_radius, | ||
num_draws, | ||
max_history_size, | ||
init_alpha, | ||
tol_obj, | ||
tol_rel_obj, | ||
tol_grad, | ||
tol_rel_grad, | ||
tol_param, | ||
num_iterations, | ||
num_elbo_draws, | ||
num_multi_draws, | ||
calculate_lp, | ||
psis_resample, | ||
refresh, | ||
num_threads, | ||
} = { ...defaultPathfinderParams, ...p }; | ||
|
||
if (num_paths < 1) { | ||
throw new Error("num_paths must be at least 1"); | ||
} | ||
if (num_draws < 1) { | ||
throw new Error("num_draws must be at least 1"); | ||
} | ||
if (num_multi_draws < 1) { | ||
throw new Error("num_multi_draws must be at least 1"); | ||
} | ||
|
||
const output_rows = | ||
calculate_lp && psis_resample ? num_multi_draws : num_draws * num_paths; | ||
|
||
const seed_ = seed !== null ? seed : Math.floor(Math.random() * (2 ^ 32)); | ||
|
||
return this.withModel(data, seed_, (model, deferredFree) => { | ||
const rawParamNames = this.m.UTF8ToString( | ||
this.m._tinystan_model_param_names(model), | ||
); | ||
const paramNames = PATHFINDER_VARIABLES.concat(rawParamNames.split(",")); | ||
|
||
const n_params = paramNames.length; | ||
|
||
const free_params = this.m._tinystan_model_num_free_params(model); | ||
if (free_params === 0) { | ||
throw new Error("Model has no parameters."); | ||
} | ||
|
||
const inits_ptr = this.encodeInits(inits); | ||
deferredFree(inits_ptr); | ||
|
||
const n_out = output_rows * n_params; | ||
const out = this.m._malloc(n_out * Float64Array.BYTES_PER_ELEMENT); | ||
deferredFree(out); | ||
const err_ptr = this.m._malloc(PTR_SIZE); | ||
deferredFree(err_ptr); | ||
|
||
const result = this.m._tinystan_pathfinder( | ||
model, | ||
num_paths, | ||
inits_ptr, | ||
seed_ || 0, | ||
id, | ||
init_radius, | ||
num_draws, | ||
max_history_size, | ||
init_alpha, | ||
tol_obj, | ||
tol_rel_obj, | ||
tol_grad, | ||
tol_rel_grad, | ||
tol_param, | ||
num_iterations, | ||
num_elbo_draws, | ||
num_multi_draws, | ||
calculate_lp ? 1 : 0, | ||
psis_resample ? 1 : 0, | ||
refresh, | ||
num_threads, | ||
out, | ||
n_out, | ||
err_ptr, | ||
); | ||
|
||
if (result != 0) { | ||
this.handleError(err_ptr); | ||
} | ||
|
||
const out_buffer = this.m.HEAPF64.subarray( | ||
out / Float64Array.BYTES_PER_ELEMENT, | ||
out / Float64Array.BYTES_PER_ELEMENT + n_out, | ||
); | ||
|
||
const draws: number[][] = Array.from({ length: n_params }, (_, i) => | ||
Array.from( | ||
{ length: output_rows }, | ||
(_, j) => out_buffer[i + n_params * j], | ||
), | ||
); | ||
|
||
return { paramNames, draws }; | ||
}); | ||
} | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Might want to convert the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the 4 in those mallocs is a different 4 than the ones I changed here (it is not necessarily true that So I could define an INT_SIZE const variable, but those are also the only places I think the API will ever need to malloc an int (in retrospect, I could use |
||
public stanVersion(): string { | ||
const major = this.m._malloc(4); | ||
const minor = this.m._malloc(4); | ||
|
Uh oh!
There was an error while loading. Please reload this page.