forked from princeton-vl/CoqGym
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgallina.py
106 lines (86 loc) · 3.16 KB
/
gallina.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# Utilities for reconstructing Gallina terms from their serialized S-expressions in CoqGym
from io import StringIO
from vernac_types import Constr__constr
from lark import Lark, Transformer, Visitor, Discard
from lark.lexer import Token
from lark.tree import Tree
from lark.tree import pydot__tree_to_png
import logging
logging.basicConfig(level=logging.DEBUG)
from collections import defaultdict
import re
import pdb
def traverse_postorder(node, callback):
    """Call ``callback`` on every ``Tree`` reachable from ``node``, children first.

    Children are visited left to right and a node is passed to ``callback``
    only after all of its ``Tree`` children have been. Children that are not
    ``Tree`` instances are skipped; ``callback`` is never invoked on them.

    The walk uses an explicit stack instead of recursion, so the nesting depth
    of the tree is not bounded by the interpreter's recursion limit.

    Args:
        node: Root of the traversal; must expose a ``children`` iterable.
        callback: One-argument callable invoked with each visited node. It may
            rebind ``node.children`` on the node it receives, because that
            node's children have already been fully iterated by then.
    """
    # Each entry pairs a node with a live iterator over its children, which
    # lets the walk resume a parent exactly where it left off after a descent.
    stack = [(node, iter(node.children))]
    while stack:
        current, remaining = stack[-1]
        for child in remaining:
            if isinstance(child, Tree):
                # Descend first; the parent's iterator keeps its position.
                stack.append((child, iter(child.children)))
                break
        else:
            # All children are done, so the node itself is visited last.
            stack.pop()
            callback(current)
class GallinaTermParser:
    """Parse serialized Gallina terms into Lark trees.

    The grammar is the EBNF produced by ``Constr__constr().to_ebnf`` followed
    by a few Lark directives importing common terminals and ignoring
    whitespace. Parsing uses Lark's LALR parser with ``constr__constr`` as the
    start rule.

    Attributes set in ``__init__``:
        caching: whether ``parse`` memoizes results per input string.
        grammar: the full grammar text handed to Lark.
        parser: the ``Lark`` instance built from ``grammar``.
        cache: dict from term string to parsed tree; created only when
            ``caching`` is truthy at construction time.
    """

    def __init__(self, caching=True):
        """Build the grammar and parser; allocate the cache if requested."""
        self.caching = caching
        lark_directives = """
%import common.STRING_INNER
%import common.ESCAPED_STRING
%import common.SIGNED_INT
%import common.WS
%ignore WS
"""
        self.grammar = Constr__constr().to_ebnf(recursive=True) + lark_directives
        self.parser = Lark(
            StringIO(self.grammar),
            start="constr__constr",
            parser="lalr",
        )
        if caching:
            self.cache = {}

    def parse_no_cache(self, term_str):
        """Parse ``term_str`` and annotate the resulting tree.

        After parsing, the tree is post-processed in two passes:

        1. ``quantified_idents`` is stored on the root as a list of the
           distinct identifiers found under ``constructor_prod`` nodes whose
           first child is a ``constructor_name`` node (surrounding double
           quotes are stripped). Presumably these are the names bound by
           product binders -- confirm against the grammar.
        2. Every node gets a ``height`` attribute (0 for a node without
           ``Tree`` children) and has all non-``Tree`` children removed.

        The passes cannot be merged: the first reads a token that the second
        removes.

        Returns:
            The annotated root ``Tree``.
        """
        ast = self.parser.parse(term_str)

        seen_idents = set()

        def collect_quantified_ident(node):
            if node.data != "constructor_prod":
                return
            if node.children == []:
                return
            head = node.children[0]
            if head.data != "constructor_name":
                return
            name = head.children[0].value
            if name.startswith('"') and name.endswith('"'):
                name = name[1:-1]
            seen_idents.add(name)

        traverse_postorder(ast, collect_quantified_ident)
        ast.quantified_idents = list(seen_idents)

        def prune_tokens_and_set_height(node):
            # Post-order guarantees every subtree already carries its height.
            subtrees = [child for child in node.children if isinstance(child, Tree)]
            node.height = max((child.height + 1 for child in subtrees), default=0)
            node.children = subtrees

        traverse_postorder(ast, prune_tokens_and_set_height)
        return ast

    def parse(self, term_str):
        """Return the annotated tree for ``term_str``.

        With caching enabled, the same tree object is returned for repeated
        inputs, so callers share (and can observe mutations of) one instance.
        """
        if not self.caching:
            return self.parse_no_cache(term_str)
        if term_str not in self.cache:
            self.cache[term_str] = self.parse_no_cache(term_str)
        return self.cache[term_str]

    def print_grammar(self):
        """Write the grammar text to standard output."""
        print(self.grammar)
class Counter(Visitor):
    """Lark visitor that tallies rule names and token values in a tree.

    ``counts_nonterminal`` maps each visited node's ``data`` to the number of
    nodes carrying it; ``counts_terminal`` maps each ``Token`` value found
    directly under a visited node to its number of occurrences.
    """

    def __init__(self):
        super().__init__()
        self.counts_nonterminal = defaultdict(int)
        self.counts_terminal = defaultdict(int)

    def __default__(self, tree):
        """Record ``tree`` itself and every token among its direct children."""
        self.counts_nonterminal[tree.data] += 1
        terminal_counts = self.counts_terminal
        for token in (child for child in tree.children if isinstance(child, Token)):
            terminal_counts[token.value] += 1
class TreeHeight(Transformer):
    """Lark transformer that reduces a tree to a single height number.

    Each node becomes one plus the largest value among its children, where a
    ``Token`` counts as 0 and any other child contributes its own
    already-computed value. A node with no children evaluates to 0.
    """

    def __default__(self, symbol, children, meta):
        # Starting at -1 makes a childless node come out as height 0.
        tallest = -1
        for child in children:
            child_height = 0 if isinstance(child, Token) else child
            if child_height > tallest:
                tallest = child_height
        return 1 + tallest
class TreeNumTokens(Transformer):
    """Lark transformer that reduces a tree to its total number of tokens.

    Each node becomes the sum over its children, where a ``Token`` counts as 1
    and any other child contributes its own already-computed count. A node
    with no children evaluates to 0.
    """

    def __default__(self, symbol, children, meta):
        total = 0
        for child in children:
            if isinstance(child, Token):
                total += 1
            else:
                total += child
        return total