diff --git a/src/solver.py b/src/solver.py index 8347664..b462c00 100644 --- a/src/solver.py +++ b/src/solver.py @@ -45,6 +45,8 @@ LoadLabel, StoreLabel, Lattice, + MaybeVar, + maybe_to_var, ) from .sketches import LabelNode, SketchNode, Sketch from .loggable import Loggable, LogLevel, show_progress @@ -853,6 +855,49 @@ def _solve_top_down( else: actuals_sketch_map[proc] = sketch + def get_type_of_variables( + self, + sketches_map: Dict[DerivedTypeVariable, Sketch], + type_schemes: Dict[DerivedTypeVariable, ConstraintSet], + proc: MaybeVar, + vars: Set[DerivedTypeVariable], + ): + """ + Solving sketches for an additional set of derived type variables. + """ + proc = maybe_to_var(proc) + constraints = self.program.proc_constraints.get( + proc, ConstraintSet() + ) + callees = set(networkx.DiGraph(self.program.callgraph).successors(proc)) + fresh_var_factory = FreshVarFactory() + constraints |= Solver.instantiate_calls( + callees, + constraints, + sketches_map, + type_schemes, + fresh_var_factory, + ) + constraints |= sketches_map[proc].instantiate_sketch( + proc, fresh_var_factory + ) + + var_sketches = self.infer_shapes( + vars, + self.program.types, + constraints, + ) + for var in vars: + primitive_constraints = self._generate_primitive_constraints( + constraints, + frozenset({var}), + self.program.types.internal_types) + + var_sketches[var].add_constraints( + primitive_constraints + ) + return var_sketches + def __call__( self, ) -> Tuple[