forked from huggingface/transformers
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a script to check inits are consistent (huggingface#11024)
- Loading branch information
Showing
7 changed files
with
237 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
# coding=utf-8 | ||
# Copyright 2020 The HuggingFace Inc. team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
import re | ||
|
||
|
||
# Relative path of the package whose `__init__.py` files are checked (walked by `check_all_inits`).
PATH_TO_TRANSFORMERS = "src/transformers"
# Backends whose `if is_xxx_available():` sections are parsed; sections for any other backend are skipped.
BACKENDS = ["torch", "tf", "flax", "sentencepiece", "tokenizers", "vision"]

# Catches a line with a key-values pattern: "bla": ["foo", "bar"]
_re_import_struct_key_value = re.compile(r'\s+"\S*":\s+\[([^\]]*)\]')
# Catches a line if is_foo_available():
_re_test_backend = re.compile(r"^\s*if\s+is\_([a-z]*)\_available\(\):\s*$")
# Catches a line _import_structure["bla"].append("foo")
_re_import_struct_add_one = re.compile(r'^\s*_import_structure\["\S*"\]\.append\("(\S*)"\)')
# Catches a line _import_structure["bla"].extend(["foo", "bar"]) or _import_structure["bla"] = ["foo", "bar"]
_re_import_struct_add_many = re.compile(r"^\s*_import_structure\[\S*\](?:\.extend\(|\s*=\s+)\[([^\]]*)\]")
# Catches a line with an object between quotes and a comma: "MyModel",
# (raw string: `\s` is not a valid escape sequence in a regular string literal)
_re_quote_object = re.compile(r'^\s+"([^"]+)",')
# Catches a line with objects between brackets only: ["foo", "bar"],
# (raw string for the same reason as above)
_re_between_brackets = re.compile(r"^\s+\[([^\]]+)\]")
# Catches a line with from foo import bar, bla, boo
_re_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
|
||
|
||
def _split_quoted_names(names):
    """Split a `"foo", "bar"` string on `", "` and strip the first and last character (the quotes) of each item."""
    return [obj[1:-1] for obj in names.split(", ") if len(obj) > 0]


def _backend_import_names(line):
    """Return the object names declared on one line of a backend section of `_import_structure`."""
    match = _re_import_struct_add_one.search(line)
    if match is not None:
        return [match.groups()[0]]
    match = _re_import_struct_add_many.search(line)
    if match is not None:
        return _split_quoted_names(match.groups()[0])
    match = _re_between_brackets.search(line)
    if match is not None:
        return _split_quoted_names(match.groups()[0])
    match = _re_quote_object.search(line)
    if match is not None:
        return [match.groups()[0]]
    # Fallbacks on raw slicing: drop the indent + opening quote and the last three characters of the line.
    if line.startswith(" " * 8 + '"'):
        return [line[9:-3]]
    if line.startswith(" " * 12 + '"'):
        return [line[13:-3]]
    return []


def _type_hint_names(line, indent):
    """Return the object names imported on one line of the TYPE_CHECKING half, for the given indent width."""
    match = _re_import.search(line)
    if match is not None:
        return match.groups()[0].split(", ")
    if line.startswith(" " * indent):
        # One name per line inside a parenthesized import: drop the indent and the last two characters.
        return [line[indent:-2]]
    return []


def parse_init(init_file):
    """
    Read an init_file and parse (per backend) the _import_structure objects defined and the TYPE_CHECKING objects
    defined.

    Args:
        init_file: Path of the `__init__.py` file to parse.

    Returns:
        `None` if no line of the file starts with `_import_structure = {` (traditional init). Otherwise a tuple
        `(import_dict_objects, type_hint_objects)` of two dicts mapping a backend name (`"none"` for the objects
        found outside of any `if is_xxx_available():` test) to the list of object names found in that half. Only
        backends listed in `BACKENDS` get an entry.
    """
    with open(init_file, "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()
    num_lines = len(lines)

    line_index = 0
    while line_index < num_lines and not lines[line_index].startswith("_import_structure = {"):
        line_index += 1

    # If this is a traditional init, just return.
    if line_index >= num_lines:
        return None

    # First grab the objects without a specific backend in _import_structure.
    # Every loop below is bounded by `num_lines` so an init without an `if TYPE_CHECKING` line (or one that ends
    # inside a backend section) is parsed as far as possible instead of raising an IndexError.
    objects = []
    while (
        line_index < num_lines
        and not lines[line_index].startswith("if TYPE_CHECKING")
        and _re_test_backend.search(lines[line_index]) is None
    ):
        line = lines[line_index]
        key_value_match = _re_import_struct_key_value.search(line)
        if key_value_match is not None:
            objects.extend(_split_quoted_names(key_value_match.groups()[0]))
        elif line.startswith(" " * 8 + '"'):
            # Drop the indent + opening quote and the last three characters of the line.
            objects.append(line[9:-3])
        line_index += 1

    import_dict_objects = {"none": objects}
    # Let's continue with backend-specific objects in _import_structure
    while line_index < num_lines and not lines[line_index].startswith("if TYPE_CHECKING"):
        # If the line is an if is_backend_available, we grab all objects associated.
        backend_match = _re_test_backend.search(lines[line_index])
        line_index += 1
        if backend_match is None:
            continue

        # Ignore if backend isn't tracked for dummies.
        backend = backend_match.groups()[0]
        if backend not in BACKENDS:
            continue

        objects = []
        # Until we unindent, add backend objects to the list
        while line_index < num_lines and (len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 4)):
            objects.extend(_backend_import_names(lines[line_index]))
            line_index += 1

        import_dict_objects[backend] = objects

    # At this stage we are in the TYPE_CHECKING part, first grab the objects without a specific backend
    objects = []
    while (
        line_index < num_lines
        and _re_test_backend.search(lines[line_index]) is None
        and not lines[line_index].startswith("else")
    ):
        objects.extend(_type_hint_names(lines[line_index], 8))
        line_index += 1

    type_hint_objects = {"none": objects}
    # Let's continue with backend-specific objects
    while line_index < num_lines:
        # If the line is an if is_backend_available, we grab all objects associated.
        backend_match = _re_test_backend.search(lines[line_index])
        line_index += 1
        if backend_match is None:
            continue

        # Ignore if backend isn't tracked for dummies.
        backend = backend_match.groups()[0]
        if backend not in BACKENDS:
            continue

        objects = []
        # Until we unindent, add backend objects to the list
        while line_index < num_lines and (len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8)):
            objects.extend(_type_hint_names(lines[line_index], 12))
            line_index += 1

        type_hint_objects[backend] = objects

    return import_dict_objects, type_hint_objects
|
||
|
||
def analyze_results(import_dict_objects, type_hint_objects):
    """
    Analyze the differences between _import_structure objects and TYPE_CHECKING objects found in an init.

    Args:
        import_dict_objects: Dict mapping a backend name (`"none"` for base imports) to the list of object names
            found in `_import_structure`.
        type_hint_objects: Dict with the same layout for the objects found in the TYPE_CHECKING half.

    Returns:
        A list of error messages, empty when both halves define the same objects for every backend.
    """
    from collections import Counter

    # Note: this comparison is order-sensitive, the backends must appear in the same order in both dicts.
    if list(import_dict_objects.keys()) != list(type_hint_objects.keys()):
        return ["Both sides of the init do not have the same backends!"]

    errors = []
    for key, import_names in import_dict_objects.items():
        hint_names = type_hint_objects[key]
        if sorted(import_names) == sorted(hint_names):
            continue

        # Counters give constant-time membership tests and let us explain multiplicity mismatches.
        import_counts = Counter(import_names)
        hint_counts = Counter(hint_names)

        name = "base imports" if key == "none" else f"{key} backend"
        errors.append(f"Differences for {name}:")
        for a in hint_names:
            if a not in import_counts:
                errors.append(f" {a} in TYPE_HINT but not in _import_structure.")
        for a in import_names:
            if a not in hint_counts:
                errors.append(f" {a} in _import_structure but not in TYPE_HINT.")
        # Names present on both sides but a different number of times (e.g. a duplicated entry) would otherwise
        # leave the "Differences" header without any explanation.
        for a, count in import_counts.items():
            if a in hint_counts and hint_counts[a] != count:
                errors.append(
                    f" {a} defined {count} time(s) in _import_structure but {hint_counts[a]} time(s) in TYPE_HINT."
                )
    return errors
|
||
|
||
def check_all_inits():
    """
    Walk `PATH_TO_TRANSFORMERS`, parse every `__init__.py` found with `parse_init` and compare both halves with
    `analyze_results`.

    Raises:
        ValueError: If at least one init does not define the same objects in both halves. The message lists the
            problems of every failing init, separated by blank lines.
    """
    problems = []
    for directory, _, filenames in os.walk(PATH_TO_TRANSFORMERS):
        if "__init__.py" not in filenames:
            continue

        init_path = os.path.join(directory, "__init__.py")
        parsed = parse_init(init_path)
        # `parse_init` returns None for inits that do not define `_import_structure`: nothing to compare.
        if parsed is None:
            continue

        messages = analyze_results(*parsed)
        if not messages:
            continue

        # Prefix the first message with the file it is about, then keep one report per init.
        messages[0] = f"Problem in {init_path}, both halves do not define the same objects.\n{messages[0]}"
        problems.append("\n".join(messages))

    if problems:
        raise ValueError("\n\n".join(problems))


if __name__ == "__main__":
    check_all_inits()