-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathscript.py
41 lines (30 loc) · 1.38 KB
/
script.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import re
import sys
def generate_impls(input_file, output_file):
    """Generate ``impl_format_tag!`` invocations from a Rust source file.

    Reads ``input_file``, locates the ``pub mod dnnl_format_tag_t { ... }``
    module (one level of nested braces is tolerated inside it), and for every
    ``pub const NAME:`` that follows a single-line ``#[doc = "..."]``
    attribute writes one line to ``output_file``::

        impl_format_tag!(<NAME without leading dnnl_>, <NAME>, <dims>, "<doc>");

    ``<dims>`` is the first integer that is immediately followed by ``D`` in
    the doc text (``4`` for "plain 4D tensor"); when the doc text contains no
    such number, 6 is used. ``<doc>`` is the doc text with backslashes
    doubled and every run of whitespace collapsed to a single space.

    If the module cannot be found, an error is printed to stderr and
    ``output_file`` is not opened at all.

    Args:
        input_file: Path of the Rust source file to scan.
        output_file: Path of the file to (over)write with the generated lines.

    Returns:
        None.
    """
    with open(input_file, "r", encoding="utf-8") as src:
        content = src.read()

    # Find the content within the specified mod.
    mod_pattern = r"pub mod dnnl_format_tag_t \{\s*((?:[^{}]|\{[^{}]*\})*)\s*\}"
    mod_match = re.search(mod_pattern, content, re.DOTALL)
    if not mod_match:
        print(
            "Error: Could not find mod dnnl_format_tag_t in the input file.",
            file=sys.stderr,
        )
        return
    mod_content = mod_match.group(1)

    # A doc attribute followed (after optional whitespace) by a `pub const`.
    # `.` does not match newlines here, so only single-line doc attributes
    # are picked up.
    const_pattern = re.compile(r"#\[doc = \"(.*?)\"\]\n\s*pub const (\w+):")
    # First "<digits>D" occurrence in the doc text, e.g. the 4 in "4D".
    dims_pattern = re.compile(r"([0-9]+)D")
    prefix = "dnnl_"
    default_dims = 6

    lines = []
    for comment, variant in const_pattern.findall(mod_content):
        # Strip only the leading prefix; a later "dnnl_" inside the name is
        # part of the identifier and must be kept.
        if variant.startswith(prefix):
            tag_struct = variant[len(prefix):]
        else:
            tag_struct = variant

        # Escape any backslashes in the comment.
        escaped_comment = comment.replace("\\", "\\\\")
        # Remove newlines and extra whitespace from the comment.
        cleaned_comment = " ".join(escaped_comment.split())

        dims_match = dims_pattern.search(cleaned_comment)
        num = int(dims_match.group(1)) if dims_match else default_dims

        lines.append(
            f'impl_format_tag!({tag_struct}, {variant}, {num}, "{cleaned_comment}");\n'
        )

    with open(output_file, "w", encoding="utf-8") as dst:
        dst.writelines(lines)
if __name__ == "__main__":
    # Usage: script.py INPUT_FILE [OUTPUT_FILE]
    # OUTPUT_FILE defaults to "output.rs" in the current directory.
    if len(sys.argv) < 2:
        # sys.exit with a string prints it to stderr and exits with status 1,
        # instead of dying with a bare IndexError on sys.argv[1].
        sys.exit(f"Usage: {sys.argv[0]} INPUT_FILE [OUTPUT_FILE]")
    output_path = sys.argv[2] if len(sys.argv) > 2 else "output.rs"
    generate_impls(sys.argv[1], output_path)