Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 38 additions & 17 deletions src/pdl/pdl_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,17 @@

import yaml
from matplotlib import pyplot as plt
from mu_ppl import viz

from ._version import version
from .pdl import InterpreterConfig
from .pdl_ast import PdlLocationType, Program, ScopeType, get_default_model_parameters
from .pdl_distributions import Categorical
from .pdl_ast import (
PdlLocationType,
PdlUsage,
Program,
ScopeType,
get_default_model_parameters,
)
from .pdl_distributions import Categorical, viz
from .pdl_inference import (
infer_importance_sampling,
infer_importance_sampling_parallel,
Expand Down Expand Up @@ -42,8 +47,8 @@ def exec_program( # pylint: disable=too-many-arguments, too-many-positional-arg
ppdl_config: Optional[PpdlConfig] = None,
scope: Optional[ScopeType | dict[str, Any]] = None,
loc: Optional[PdlLocationType] = None,
# output: Literal["result", "all"] = "result",
) -> Categorical[Any]:
output: Literal["result", "all"] = "result",
) -> Categorical[Any] | dict:
ppdl_config = ppdl_config or PpdlConfig()

algo = ppdl_config.get("algo", "parallel-smc")
Expand All @@ -54,7 +59,8 @@ def exec_program( # pylint: disable=too-many-arguments, too-many-positional-arg
config["yield_result"] = False
config["yield_background"] = False
config["batch"] = 1
config["event_loop"] = _LOOP
config["event_loop"] = config.get("event_loop", _LOOP)
config["llm_usage"] = config.get("llm_usage", PdlUsage())

dist: Categorical[Any]
match algo:
Expand Down Expand Up @@ -97,7 +103,21 @@ def exec_program( # pylint: disable=too-many-arguments, too-many-positional-arg
)
case _:
assert False, f"Unexpected algo: {algo}"
return dist
match output:
case "result":
return dist
case "all":
result_all = {
"result": dist,
# "scope": future_scope.result(),
# "trace": trace,
# "replay": state.replay,
# "score": state.score.ref,
"usage": config["llm_usage"],
}
return result_all
case _:
assert False, 'The `output` variable should be "result" or "all"'


def exec_dict( # pylint: disable=too-many-arguments, too-many-positional-arguments
Expand All @@ -106,10 +126,10 @@ def exec_dict( # pylint: disable=too-many-arguments, too-many-positional-argume
ppdl_config: Optional[PpdlConfig] = None,
scope: Optional[ScopeType | dict[str, Any]] = None,
loc: Optional[PdlLocationType] = None,
# output: Literal["result", "all"] = "result",
) -> Any:
output: Literal["result", "all"] = "result",
) -> Categorical[Any] | dict:
program = parse_dict(prog)
result = exec_program(program, config, ppdl_config, scope, loc)
result = exec_program(program, config, ppdl_config, scope, loc, output)
return result


Expand All @@ -118,10 +138,10 @@ def exec_str(
config: Optional[InterpreterConfig] = None,
ppdl_config: Optional[PpdlConfig] = None,
scope: Optional[ScopeType | dict[str, Any]] = None,
# output: Literal["result", "all"] = "result",
) -> Any:
output: Literal["result", "all"] = "result",
) -> Categorical[Any] | dict:
program, loc = parse_str(prog)
result = exec_program(program, config, ppdl_config, scope, loc)
result = exec_program(program, config, ppdl_config, scope, loc, output)
return result


Expand All @@ -130,14 +150,14 @@ def exec_file(
config: Optional[InterpreterConfig] = None,
ppdl_config: Optional[PpdlConfig] = None,
scope: Optional[ScopeType | dict[str, Any]] = None,
# output: Literal["result", "all"] = "result",
) -> Any:
output: Literal["result", "all"] = "result",
) -> Categorical[Any] | dict:
program, loc = parse_file(prog)
if config is None:
config = InterpreterConfig()
if config.get("cwd") is None:
config["cwd"] = Path(prog).parent
result = exec_program(program, config, ppdl_config, scope, loc)
result = exec_program(program, config, ppdl_config, scope, loc, output)
return result


Expand Down Expand Up @@ -224,8 +244,9 @@ def main():
algo=args.algo, num_particles=args.num_particles, max_workers=args.workers
)

dist = exec_file(args.pdl, config, ppdl_config, initial_scope)
dist = exec_file(args.pdl, config, ppdl_config, initial_scope, output="result")

assert isinstance(dist, Categorical)
if args.viz:
viz(dist)
plt.show()
Expand Down