From 84e68676008334ac217ef1f5f243af715bbc03f1 Mon Sep 17 00:00:00 2001 From: Ved Patwardhan <54766411+vedpatwardhan@users.noreply.github.com> Date: Wed, 20 Sep 2023 19:09:56 +0530 Subject: [PATCH] fix: updated test_dependencies.py to correctly test requirements (#25698) For supporting ==, <=, <, >= or > as it broke codespace creation with the change to urllib<2.0 in the requirements.txt --- run_tests_CLI/test_dependencies.py | 33 ++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/run_tests_CLI/test_dependencies.py b/run_tests_CLI/test_dependencies.py index 348a63be15e6c..02bb80e9c2d0c 100644 --- a/run_tests_CLI/test_dependencies.py +++ b/run_tests_CLI/test_dependencies.py @@ -6,6 +6,7 @@ import termcolor import importlib import faulthandler +from packaging import version faulthandler.enable() ERROR = False @@ -17,15 +18,33 @@ def parse(str_in): str_in = str_in.replace("\n", "") + import_ops = ["==", "<", "<=", ">", ">="] if "mod_name=" in str_in: mod_name = str_in.split("mod_name=")[-1].split(" ")[0].split(",")[0] else: mod_name = str_in.split("=")[0].split(" ")[0] - if "==" in str_in: - version = str_in.split("==")[-1].split(" ")[0].split(",")[0] + expected_version, expected_op = None, None + for import_op in import_ops: + if import_op in str_in: + mod_name, expected_version = str_in.split(import_op) + expected_version = expected_version.split(" ")[0].split(",")[0] + expected_op = import_op + return mod_name, expected_version, expected_op + + +def compare(version1, version2, operator): + version1 = version.parse(version1) + version2 = version.parse(version2) + if operator == "==": + return version1 == version2 + elif "<" in operator: + if operator == "<=": + return version1 <= version2 + return version1 < version2 else: - version = None - return mod_name, version + if operator == ">=": + return version1 >= version2 + return version1 > version2 def test_imports(fname, assert_version, update_versions): @@ -42,7 +61,9 @@ def test_imports(fname, assert_version, update_versions): with open(fname, "r") as f: file_lines = f.readlines() mod_names_n_versions = [parse(req) for req in file_lines] - for line_num, (mod_name, expected_version) in enumerate(mod_names_n_versions): + for line_num, (mod_name, expected_version, expected_op) in enumerate( + mod_names_n_versions + ): # noinspection PyBroadException try: mod = importlib.import_module(mod_name) @@ -64,7 +85,7 @@ def test_imports(fname, assert_version, update_versions): except Exception: detected_version = None if detected_version and expected_version: - if detected_version == expected_version: + if compare(detected_version, expected_version, expected_op): msg = f"{mod_name} detected correct version: {detected_version}\n" else: msg = (