diff --git a/src/pdl/pdl_infer.py b/src/pdl/pdl_infer.py index a0a3ca030..5b7b4ac66 100644 --- a/src/pdl/pdl_infer.py +++ b/src/pdl/pdl_infer.py @@ -4,12 +4,17 @@ import yaml from matplotlib import pyplot as plt -from mu_ppl import viz from ._version import version from .pdl import InterpreterConfig -from .pdl_ast import PdlLocationType, Program, ScopeType, get_default_model_parameters -from .pdl_distributions import Categorical +from .pdl_ast import ( + PdlLocationType, + PdlUsage, + Program, + ScopeType, + get_default_model_parameters, +) +from .pdl_distributions import Categorical, viz from .pdl_inference import ( infer_importance_sampling, infer_importance_sampling_parallel, @@ -42,8 +47,8 @@ def exec_program( # pylint: disable=too-many-arguments, too-many-positional-arg ppdl_config: Optional[PpdlConfig] = None, scope: Optional[ScopeType | dict[str, Any]] = None, loc: Optional[PdlLocationType] = None, - # output: Literal["result", "all"] = "result", -) -> Categorical[Any]: + output: Literal["result", "all"] = "result", +) -> Categorical[Any] | dict: ppdl_config = ppdl_config or PpdlConfig() algo = ppdl_config.get("algo", "parallel-smc") @@ -54,7 +59,8 @@ def exec_program( # pylint: disable=too-many-arguments, too-many-positional-arg config["yield_result"] = False config["yield_background"] = False config["batch"] = 1 - config["event_loop"] = _LOOP + config["event_loop"] = config.get("event_loop", _LOOP) + config["llm_usage"] = config.get("llm_usage", PdlUsage()) dist: Categorical[Any] match algo: @@ -97,7 +103,21 @@ def exec_program( # pylint: disable=too-many-arguments, too-many-positional-arg ) case _: assert False, f"Unexpected algo: {algo}" - return dist + match output: + case "result": + return dist + case "all": + result_all = { + "result": dist, + # "scope": future_scope.result(), + # "trace": trace, + # "replay": state.replay, + # "score": state.score.ref, + "usage": config["llm_usage"], + } + return result_all + case _: + assert False, 'The `output` variable should be "result" or "all"' def exec_dict( # pylint: disable=too-many-arguments, too-many-positional-arguments @@ -106,10 +126,10 @@ def exec_dict( # pylint: disable=too-many-arguments, too-many-positional-argume ppdl_config: Optional[PpdlConfig] = None, scope: Optional[ScopeType | dict[str, Any]] = None, loc: Optional[PdlLocationType] = None, - # output: Literal["result", "all"] = "result", -) -> Any: + output: Literal["result", "all"] = "result", +) -> Categorical[Any] | dict: program = parse_dict(prog) - result = exec_program(program, config, ppdl_config, scope, loc) + result = exec_program(program, config, ppdl_config, scope, loc, output) return result @@ -118,10 +138,10 @@ def exec_str( config: Optional[InterpreterConfig] = None, ppdl_config: Optional[PpdlConfig] = None, scope: Optional[ScopeType | dict[str, Any]] = None, - # output: Literal["result", "all"] = "result", -) -> Any: + output: Literal["result", "all"] = "result", +) -> Categorical[Any] | dict: program, loc = parse_str(prog) - result = exec_program(program, config, ppdl_config, scope, loc) + result = exec_program(program, config, ppdl_config, scope, loc, output) return result @@ -130,14 +150,14 @@ def exec_file( config: Optional[InterpreterConfig] = None, ppdl_config: Optional[PpdlConfig] = None, scope: Optional[ScopeType | dict[str, Any]] = None, - # output: Literal["result", "all"] = "result", -) -> Any: + output: Literal["result", "all"] = "result", +) -> Categorical[Any] | dict: program, loc = parse_file(prog) if config is None: config = InterpreterConfig() if config.get("cwd") is None: config["cwd"] = Path(prog).parent - result = exec_program(program, config, ppdl_config, scope, loc) + result = exec_program(program, config, ppdl_config, scope, loc, output) return result @@ -224,8 +244,9 @@ def main(): algo=args.algo, num_particles=args.num_particles, max_workers=args.workers ) - dist = exec_file(args.pdl, config, ppdl_config, initial_scope) + dist = exec_file(args.pdl, config, ppdl_config, initial_scope, output="result") + assert isinstance(dist, Categorical) if args.viz: viz(dist) plt.show()