44
55import yaml
66from matplotlib import pyplot as plt
7- from mu_ppl import viz
87
98from ._version import version
109from .pdl import InterpreterConfig
11- from .pdl_ast import PdlLocationType , Program , ScopeType , get_default_model_parameters
12- from .pdl_distributions import Categorical
10+ from .pdl_ast import (
11+ PdlLocationType ,
12+ PdlUsage ,
13+ Program ,
14+ ScopeType ,
15+ get_default_model_parameters ,
16+ )
17+ from .pdl_distributions import Categorical , viz
1318from .pdl_inference import (
1419 infer_importance_sampling ,
1520 infer_importance_sampling_parallel ,
@@ -42,8 +47,8 @@ def exec_program( # pylint: disable=too-many-arguments, too-many-positional-arg
4247 ppdl_config : Optional [PpdlConfig ] = None ,
4348 scope : Optional [ScopeType | dict [str , Any ]] = None ,
4449 loc : Optional [PdlLocationType ] = None ,
45- # output: Literal["result", "all"] = "result",
46- ) -> Categorical [Any ]:
50+ output : Literal ["result" , "all" ] = "result" ,
51+ ) -> Categorical [Any ] | dict :
4752 ppdl_config = ppdl_config or PpdlConfig ()
4853
4954 algo = ppdl_config .get ("algo" , "parallel-smc" )
@@ -54,7 +59,8 @@ def exec_program( # pylint: disable=too-many-arguments, too-many-positional-arg
5459 config ["yield_result" ] = False
5560 config ["yield_background" ] = False
5661 config ["batch" ] = 1
57- config ["event_loop" ] = _LOOP
62+ config ["event_loop" ] = config .get ("event_loop" , _LOOP )
63+ config ["llm_usage" ] = config .get ("llm_usage" , PdlUsage ())
5864
5965 dist : Categorical [Any ]
6066 match algo :
@@ -97,7 +103,21 @@ def exec_program( # pylint: disable=too-many-arguments, too-many-positional-arg
97103 )
98104 case _:
99105 assert False , f"Unexpected algo: { algo } "
100- return dist
106+ match output :
107+ case "result" :
108+ return dist
109+ case "all" :
110+ result_all = {
111+ "result" : dist ,
112+ # "scope": future_scope.result(),
113+ # "trace": trace,
114+ # "replay": state.replay,
115+ # "score": state.score.ref,
116+ "usage" : config ["llm_usage" ],
117+ }
118+ return result_all
119+ case _:
120+ assert False , 'The `output` variable should be "result" or "all"'
101121
102122
103123def exec_dict ( # pylint: disable=too-many-arguments, too-many-positional-arguments
@@ -106,10 +126,10 @@ def exec_dict( # pylint: disable=too-many-arguments, too-many-positional-argume
106126 ppdl_config : Optional [PpdlConfig ] = None ,
107127 scope : Optional [ScopeType | dict [str , Any ]] = None ,
108128 loc : Optional [PdlLocationType ] = None ,
109- # output: Literal["result", "all"] = "result",
110- ) -> Any :
129+ output : Literal ["result" , "all" ] = "result" ,
130+ ) -> Categorical [ Any ] | dict :
111131 program = parse_dict (prog )
112- result = exec_program (program , config , ppdl_config , scope , loc )
132+ result = exec_program (program , config , ppdl_config , scope , loc , output )
113133 return result
114134
115135
@@ -118,10 +138,10 @@ def exec_str(
118138 config : Optional [InterpreterConfig ] = None ,
119139 ppdl_config : Optional [PpdlConfig ] = None ,
120140 scope : Optional [ScopeType | dict [str , Any ]] = None ,
121- # output: Literal["result", "all"] = "result",
122- ) -> Any :
141+ output : Literal ["result" , "all" ] = "result" ,
142+ ) -> Categorical [ Any ] | dict :
123143 program , loc = parse_str (prog )
124- result = exec_program (program , config , ppdl_config , scope , loc )
144+ result = exec_program (program , config , ppdl_config , scope , loc , output )
125145 return result
126146
127147
@@ -130,14 +150,14 @@ def exec_file(
130150 config : Optional [InterpreterConfig ] = None ,
131151 ppdl_config : Optional [PpdlConfig ] = None ,
132152 scope : Optional [ScopeType | dict [str , Any ]] = None ,
133- # output: Literal["result", "all"] = "result",
134- ) -> Any :
153+ output : Literal ["result" , "all" ] = "result" ,
154+ ) -> Categorical [ Any ] | dict :
135155 program , loc = parse_file (prog )
136156 if config is None :
137157 config = InterpreterConfig ()
138158 if config .get ("cwd" ) is None :
139159 config ["cwd" ] = Path (prog ).parent
140- result = exec_program (program , config , ppdl_config , scope , loc )
160+ result = exec_program (program , config , ppdl_config , scope , loc , output )
141161 return result
142162
143163
@@ -224,8 +244,9 @@ def main():
224244 algo = args .algo , num_particles = args .num_particles , max_workers = args .workers
225245 )
226246
227- dist = exec_file (args .pdl , config , ppdl_config , initial_scope )
247+ dist = exec_file (args .pdl , config , ppdl_config , initial_scope , output = "result" )
228248
249+ assert isinstance (dist , Categorical )
229250 if args .viz :
230251 viz (dist )
231252 plt .show ()
0 commit comments