Skip to content

Commit

Permalink
fix: updated test_dependencies.py to correctly test requirements (#25698
Browse files Browse the repository at this point in the history
)

For supporting ==, <=, <, >= or > as it broke codespace creation with the change to urllib<2.0 in the requirements.txt
  • Loading branch information
vedpatwardhan authored Sep 20, 2023
1 parent 6f56ff0 commit 84e6867
Showing 1 changed file with 27 additions and 6 deletions.
33 changes: 27 additions & 6 deletions run_tests_CLI/test_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import termcolor
import importlib
import faulthandler
from packaging import version

faulthandler.enable()
ERROR = False
Expand All @@ -17,15 +18,33 @@

def parse(str_in):
str_in = str_in.replace("\n", "")
import_ops = ["==", "<", "<=", ">", ">="]
if "mod_name=" in str_in:
mod_name = str_in.split("mod_name=")[-1].split(" ")[0].split(",")[0]
else:
mod_name = str_in.split("=")[0].split(" ")[0]
if "==" in str_in:
version = str_in.split("==")[-1].split(" ")[0].split(",")[0]
expected_version, expected_op = None, None
for import_op in import_ops:
if import_op in str_in:
mod_name, expected_version = str_in.split(import_op)
expected_version = expected_version.split(" ")[0].split(",")[0]
expected_op = import_op
return mod_name, expected_version, expected_op


def compare(version1, version2, operator):
version1 = version.parse(version1)
version2 = version.parse(version2)
if operator == "==":
return version1 == version2
elif "<" in operator:
if operator == "<=":
return version1 <= version2
return version1 < version2
else:
version = None
return mod_name, version
if operator == ">=":
return version1 >= version2
return version1 > version2


def test_imports(fname, assert_version, update_versions):
Expand All @@ -42,7 +61,9 @@ def test_imports(fname, assert_version, update_versions):
with open(fname, "r") as f:
file_lines = f.readlines()
mod_names_n_versions = [parse(req) for req in file_lines]
for line_num, (mod_name, expected_version) in enumerate(mod_names_n_versions):
for line_num, (mod_name, expected_version, expected_op) in enumerate(
mod_names_n_versions
):
# noinspection PyBroadException
try:
mod = importlib.import_module(mod_name)
Expand All @@ -64,7 +85,7 @@ def test_imports(fname, assert_version, update_versions):
except Exception:
detected_version = None
if detected_version and expected_version:
if detected_version == expected_version:
if compare(detected_version, expected_version, expected_op):
msg = f"{mod_name} detected correct version: {detected_version}\n"
else:
msg = (
Expand Down

0 comments on commit 84e6867

Please sign in to comment.