@@ -61,11 +61,22 @@ fn reachable_funcs<'a, H: HugrView>(
6161 } ) )
6262}
6363
64- #[ derive( Debug , Clone , Default ) ]
64+ #[ derive( Debug , Clone ) ]
6565/// A configuration for the Dead Function Removal pass.
6666pub struct RemoveDeadFuncsPass {
6767 validation : ValidationLevel ,
6868 entry_points : Vec < Node > ,
69+ include_exports : bool ,
70+ }
71+
72+ impl Default for RemoveDeadFuncsPass {
73+ fn default ( ) -> Self {
74+ Self {
75+ validation : Default :: default ( ) ,
76+ entry_points : Default :: default ( ) ,
77+ include_exports : true ,
78+ }
79+ }
6980}
7081
7182impl RemoveDeadFuncsPass {
@@ -88,10 +99,28 @@ impl RemoveDeadFuncsPass {
8899 self
89100 }
90101
102+ /// Sets whether the exported [FuncDefn](hugr_core::ops::FuncDefn) children of a
103+ /// [Module](hugr_core::ops::Module) are included as entry points (yes by default)
104+ pub fn include_module_exports ( mut self , include : bool ) -> Self {
105+ self . include_exports = include;
106+ self
107+ }
108+
91109 /// Runs the pass (see [remove_dead_funcs]) with this configuration
92110 pub fn run < H : HugrMut > ( & self , hugr : & mut H ) -> Result < ( ) , RemoveDeadFuncsError > {
93111 self . validation . run_validated_pass ( hugr, |hugr : & mut H , _| {
94- remove_dead_funcs ( hugr, self . entry_points . iter ( ) . cloned ( ) )
112+ let exports = if hugr. root_type ( ) . is_module ( ) && self . include_exports {
113+ hugr. children ( hugr. root ( ) )
114+ . filter ( |ch| {
115+ hugr. get_optype ( * ch)
116+ . as_func_defn ( )
117+ . is_some_and ( |fd| fd. public )
118+ } )
119+ . collect ( )
120+ } else {
121+ vec ! [ ]
122+ } ;
123+ remove_dead_funcs ( hugr, self . entry_points . iter ( ) . cloned ( ) . chain ( exports) )
95124 } )
96125 }
97126}
@@ -145,26 +174,29 @@ mod test {
145174 use super :: RemoveDeadFuncsPass ;
146175
147176 #[ rstest]
148- #[ case( [ ] , vec![ ] ) ] // No entry_points removes everything!
149- #[ case( [ "main" ] , vec![ "from_main" , "main" ] ) ]
150- #[ case( [ "from_main" ] , vec![ "from_main" ] ) ]
151- #[ case( [ "other1" ] , vec![ "other1" , "other2" ] ) ]
152- #[ case( [ "other2" ] , vec![ "other2" ] ) ]
153- #[ case( [ "other1" , "other2" ] , vec![ "other1" , "other2" ] ) ]
177+ #[ case( false , [ ] , vec![ ] ) ] // No entry_points removes everything!
178+ #[ case( false , [ "main" ] , vec![ "from_main" , "main" ] ) ]
179+ #[ case( false , [ "from_main" ] , vec![ "from_main" ] ) ]
180+ #[ case( false , [ "other1" ] , vec![ "other1" , "other2" ] ) ]
181+ #[ case( false , [ "other2" ] , vec![ "other2" ] ) ]
182+ #[ case( false , [ "other1" , "other2" ] , vec![ "other1" , "other2" ] ) ]
183+ #[ case( true , [ ] , vec![ "from_main" , "main" , "other2" ] ) ]
184+ #[ case( true , [ "other1" ] , vec![ "from_main" , "main" , "other1" , "other2" ] ) ]
154185 fn remove_dead_funcs_entry_points (
186+ #[ case] include_exports : bool ,
155187 #[ case] entry_points : impl IntoIterator < Item = & ' static str > ,
156188 #[ case] retained_funcs : Vec < & ' static str > ,
157189 ) -> Result < ( ) , Box < dyn std:: error:: Error > > {
158190 let mut hb = ModuleBuilder :: new ( ) ;
159191 let o2 = hb. define_function ( "other2" , Signature :: new_endo ( usize_t ( ) ) ) ?;
160192 let o2inp = o2. input_wires ( ) ;
161193 let o2 = o2. finish_with_outputs ( o2inp) ?;
162- let mut o1 = hb. define_function ( "other1" , Signature :: new_endo ( usize_t ( ) ) ) ?;
194+ let mut o1 = hb. define_function_vis ( "other1" , Signature :: new_endo ( usize_t ( ) ) , false ) ?;
163195
164196 let o1c = o1. call ( o2. handle ( ) , & [ ] , o1. input_wires ( ) ) ?;
165197 o1. finish_with_outputs ( o1c. outputs ( ) ) ?;
166198
167- let fm = hb. define_function ( "from_main" , Signature :: new_endo ( usize_t ( ) ) ) ?;
199+ let fm = hb. define_function_vis ( "from_main" , Signature :: new_endo ( usize_t ( ) ) , false ) ?;
168200 let f_inp = fm. input_wires ( ) ;
169201 let fm = fm. finish_with_outputs ( f_inp) ?;
170202 let mut m = hb. define_function ( "main" , Signature :: new_endo ( usize_t ( ) ) ) ?;
@@ -183,6 +215,7 @@ mod test {
183215 . collect :: < HashMap < _ , _ > > ( ) ;
184216
185217 RemoveDeadFuncsPass :: default ( )
218+ . include_module_exports ( include_exports)
186219 . with_module_entry_points (
187220 entry_points
188221 . into_iter ( )
0 commit comments