diff --git a/blendsql/_sqlglot.py b/blendsql/_sqlglot.py index 6d63c926..f9cb0812 100644 --- a/blendsql/_sqlglot.py +++ b/blendsql/_sqlglot.py @@ -724,6 +724,9 @@ def infer_gen_constraints(self, start: int, end: int) -> dict: - pattern: regular expression pattern lambda to use in constrained decoding with Model - See `create_pattern` for more info on these pattern lambdas + + - options: Optional str default to pass to `options` argument in a QAIngredient + - Will have the form '{table}::{column}' """ def create_pattern( @@ -763,6 +766,12 @@ def create_pattern( predicate_literals: List[str] = [] if start_node is not None: predicate_literals = get_predicate_literals(start_node) + if isinstance(start_node, exp.EQ): + if isinstance(start_node.args["this"], exp.Column): + # This is valid for a default `options` set + added_kwargs[ + "options" + ] = f"{start_node.args['this'].args['table'].name}::{start_node.args['this'].args['this'].name}" if len(predicate_literals) > 0: if all(isinstance(x, bool) for x in predicate_literals): output_type = "boolean"