Skip to content

Commit 20021c2

Browse files
committed
feat: ppdl can retrun usage stats
Signed-off-by: Louis Mandel <[email protected]>
1 parent 9b9bbfc commit 20021c2

File tree

1 file changed

+38
-17
lines changed

1 file changed

+38
-17
lines changed

src/pdl/pdl_infer.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,17 @@
44

55
import yaml
66
from matplotlib import pyplot as plt
7-
from mu_ppl import viz
87

98
from ._version import version
109
from .pdl import InterpreterConfig
11-
from .pdl_ast import PdlLocationType, Program, ScopeType, get_default_model_parameters
12-
from .pdl_distributions import Categorical
10+
from .pdl_ast import (
11+
PdlLocationType,
12+
PdlUsage,
13+
Program,
14+
ScopeType,
15+
get_default_model_parameters,
16+
)
17+
from .pdl_distributions import Categorical, viz
1318
from .pdl_inference import (
1419
infer_importance_sampling,
1520
infer_importance_sampling_parallel,
@@ -42,8 +47,8 @@ def exec_program( # pylint: disable=too-many-arguments, too-many-positional-arg
4247
ppdl_config: Optional[PpdlConfig] = None,
4348
scope: Optional[ScopeType | dict[str, Any]] = None,
4449
loc: Optional[PdlLocationType] = None,
45-
# output: Literal["result", "all"] = "result",
46-
) -> Categorical[Any]:
50+
output: Literal["result", "all"] = "result",
51+
) -> Categorical[Any] | dict:
4752
ppdl_config = ppdl_config or PpdlConfig()
4853

4954
algo = ppdl_config.get("algo", "parallel-smc")
@@ -54,7 +59,8 @@ def exec_program( # pylint: disable=too-many-arguments, too-many-positional-arg
5459
config["yield_result"] = False
5560
config["yield_background"] = False
5661
config["batch"] = 1
57-
config["event_loop"] = _LOOP
62+
config["event_loop"] = config.get("event_loop", _LOOP)
63+
config["llm_usage"] = config.get("llm_usage", PdlUsage())
5864

5965
dist: Categorical[Any]
6066
match algo:
@@ -97,7 +103,21 @@ def exec_program( # pylint: disable=too-many-arguments, too-many-positional-arg
97103
)
98104
case _:
99105
assert False, f"Unexpected algo: {algo}"
100-
return dist
106+
match output:
107+
case "result":
108+
return dist
109+
case "all":
110+
result_all = {
111+
"result": dist,
112+
# "scope": future_scope.result(),
113+
# "trace": trace,
114+
# "replay": state.replay,
115+
# "score": state.score.ref,
116+
"usage": config["llm_usage"],
117+
}
118+
return result_all
119+
case _:
120+
assert False, 'The `output` variable should be "result" or "all"'
101121

102122

103123
def exec_dict( # pylint: disable=too-many-arguments, too-many-positional-arguments
@@ -106,10 +126,10 @@ def exec_dict( # pylint: disable=too-many-arguments, too-many-positional-argume
106126
ppdl_config: Optional[PpdlConfig] = None,
107127
scope: Optional[ScopeType | dict[str, Any]] = None,
108128
loc: Optional[PdlLocationType] = None,
109-
# output: Literal["result", "all"] = "result",
110-
) -> Any:
129+
output: Literal["result", "all"] = "result",
130+
) -> Categorical[Any] | dict:
111131
program = parse_dict(prog)
112-
result = exec_program(program, config, ppdl_config, scope, loc)
132+
result = exec_program(program, config, ppdl_config, scope, loc, output)
113133
return result
114134

115135

@@ -118,10 +138,10 @@ def exec_str(
118138
config: Optional[InterpreterConfig] = None,
119139
ppdl_config: Optional[PpdlConfig] = None,
120140
scope: Optional[ScopeType | dict[str, Any]] = None,
121-
# output: Literal["result", "all"] = "result",
122-
) -> Any:
141+
output: Literal["result", "all"] = "result",
142+
) -> Categorical[Any] | dict:
123143
program, loc = parse_str(prog)
124-
result = exec_program(program, config, ppdl_config, scope, loc)
144+
result = exec_program(program, config, ppdl_config, scope, loc, output)
125145
return result
126146

127147

@@ -130,14 +150,14 @@ def exec_file(
130150
config: Optional[InterpreterConfig] = None,
131151
ppdl_config: Optional[PpdlConfig] = None,
132152
scope: Optional[ScopeType | dict[str, Any]] = None,
133-
# output: Literal["result", "all"] = "result",
134-
) -> Any:
153+
output: Literal["result", "all"] = "result",
154+
) -> Categorical[Any] | dict:
135155
program, loc = parse_file(prog)
136156
if config is None:
137157
config = InterpreterConfig()
138158
if config.get("cwd") is None:
139159
config["cwd"] = Path(prog).parent
140-
result = exec_program(program, config, ppdl_config, scope, loc)
160+
result = exec_program(program, config, ppdl_config, scope, loc, output)
141161
return result
142162

143163

@@ -224,8 +244,9 @@ def main():
224244
algo=args.algo, num_particles=args.num_particles, max_workers=args.workers
225245
)
226246

227-
dist = exec_file(args.pdl, config, ppdl_config, initial_scope)
247+
dist = exec_file(args.pdl, config, ppdl_config, initial_scope, output="result")
228248

249+
assert isinstance(dist, Categorical)
229250
if args.viz:
230251
viz(dist)
231252
plt.show()

0 commit comments

Comments
 (0)