# Windows device names that cannot be used as file or folder names, optionally
# followed by a dot and any suffix (e.g. "NUL.txt"). Matching is case-insensitive.
# Compiled once at import time instead of on every call.
_WIN_RESERVED_NAME_PATTERN = re.compile(
    r"^(CON|PRN|AUX|CLOCK\$|NUL|COM[1-9]|LPT[1-9])(\..*)?$", re.IGNORECASE
)


def check_unique_on_case_insensitive(names: List[str]):
    """
    Ensure that the names are unique when compared case-insensitively.

    Names that differ only by letter case are treated as duplicates. Any
    iterable of strings is accepted (for example a ``dict.keys()`` view).

    Parameters
    ----------
    names : List[str]
        Names to check.

    Returns
    -------
    bool
        True if no two names are equal after lower-casing.

    Raises
    ------
    TypeError
        A name is not a string.
    ValueError
        Two names are equal after lower-casing.
    """
    seen_lowered = set()
    for n in names:
        # Fail with a clear error instead of an AttributeError from `.lower()`.
        if not isinstance(n, str):
            raise TypeError(f"tag: {n} is not a string")
        n_lower = n.lower()
        if n_lower in seen_lowered:
            raise ValueError(f'tag: Windows is case insensitive, find same name: "{n}"')
        seen_lowered.add(n_lower)

    return True


def check_win_reserved_folder_name(folder_name: str, auto_fix=True) -> str:
    """
    Check whether a folder name is a reserved device name on Windows.

    A name is considered reserved when it equals, ignoring case, one of
    CON, PRN, AUX, CLOCK$, NUL, COM1-COM9 or LPT1-LPT9, optionally followed
    by a dot and any suffix (e.g. ``"nul.txt"``).
    see detail: https://learn.microsoft.com/zh-cn/biztalk/core/restrictions-when-configuring-the-file-adapter

    Parameters
    ----------
    folder_name : str
        Name of the folder to check.
    auto_fix : bool, optional
        If True (default), a reserved name is repaired by prefixing it with
        an underscore. If False, a reserved name raises ValueError.

    Returns
    -------
    str
        ``folder_name`` unchanged when it is not reserved, otherwise the
        repaired name ``"_" + folder_name``.

    Raises
    ------
    ValueError
        ``folder_name`` is reserved and ``auto_fix`` is False.
    """
    # Guard clause: ordinary names pass through untouched.
    if _WIN_RESERVED_NAME_PATTERN.match(folder_name) is None:
        return folder_name
    if not auto_fix:
        raise ValueError(f"tag: {folder_name} is reserved names in windows")
    # A leading underscore is enough to stop the name matching a device name.
    return "_" + folder_name
""" if self.__state != SwanLabRunState.RUNNING: raise RuntimeError("After experiment finished, you can no longer log data to the current experiment") @@ -278,6 +283,9 @@ def log(self, data: dict, step: int = None): step = None log_return = {} + if is_windows(): + # 在windows不区分大小写的情况下需要检查key是否大小写不敏感独一 + check_unique_on_case_insensitive(data.keys()) # 遍历data,记录data for k, v in data.items(): _k = k @@ -285,6 +293,21 @@ def log(self, data: dict, step: int = None): if k != _k: # 超过255字符,截断 swanlog.warning(f"Key {_k} is too long, cut to 255 characters.") + if k in data.keys(): + raise ValueError(f'tag: Unsupport too long Key "{_k}" and auto cut failed') + if is_windows(): + # windows 中要增加保留文件夹名的判断 + _k = k + k = check_win_reserved_folder_name(k) + if k != _k: + # key名为windows保留文件名 + lower_key_name = [k.lower() in data.keys()] + if k.lower() in lower_key_name: + # 修复后又和原先key重名(大小写不区分情况下) + raise ValueError( + f"Key {_k} unsupport on windows and auto fix it failed. Please change a key name instead" + ) + # swanlog.warning(f"Key {_k} unsupport on windows, auto used {k} instead") # todo: 每次都打很烦,暂时先注释掉,回头最好弄成只打一次warning # ---------------------------------- 包装数据 ---------------------------------- # 输入为可转换为float的数据类型 if isinstance(v, (int, float, FloatConvertible)): diff --git a/test/unit/data/pytest_fomater.py b/test/unit/data/pytest_fomater.py index da0d4460c..d583d5866 100644 --- a/test/unit/data/pytest_fomater.py +++ b/test/unit/data/pytest_fomater.py @@ -15,6 +15,8 @@ check_proj_name_format, _auto_cut, check_key_format, + check_unique_on_case_insensitive, + check_win_reserved_folder_name, ) @@ -113,8 +115,9 @@ def test_proj_name_no_cut(self, value: str): class TestTag: + @pytest.mark.parametrize( - "value", [generate(size=255), generate(size=100), generate(size=1), "12/", "-", "_", "👾👾👾👾👾👾"] + "value", [generate(size=255), generate(size=100), generate(size=1), "12", "-", "_", "👾👾👾👾👾👾"] ) def test_tag_common(self, value): """ @@ -147,7 +150,7 @@ def 
class TestWinTag:
    """Tests for the Windows-specific tag name helpers."""

    # Device names that check_win_reserved_folder_name must treat as reserved.
    __reserved_name__ = [
        "CON",
        "PRN",
        "AUX",
        "CLOCK$",
        "NUL",
        "COM1",
        "COM2",
        "COM3",
        "COM4",
        "COM5",
        "COM6",
        "COM7",
        "COM8",
        "COM9",
        "LPT1",
        "LPT2",
        "LPT3",
        "LPT4",
        "LPT5",
        "LPT6",
        "LPT7",
        "LPT8",
        "LPT9",
    ]

    @pytest.mark.parametrize(
        "value",
        ["abc", "def", "ghi"],
    )
    def test_normal_name_in_win(self, value: str):
        """
        Ordinary names are returned unchanged.
        """
        assert value == check_win_reserved_folder_name(value, auto_fix=True)

    @pytest.mark.parametrize(
        "value",
        __reserved_name__,
    )
    def test_check_reserved_name(self, value: str):
        """
        Reserved names are detected: with auto_fix disabled they raise ValueError.
        """
        with pytest.raises(ValueError):
            check_win_reserved_folder_name(value, auto_fix=False)

    @pytest.mark.parametrize(
        "value",
        __reserved_name__,
    )
    def test_fix_reserved_name(self, value: str):
        """
        Reserved names are repaired by prefixing an underscore.
        """
        assert "_" + value == check_win_reserved_folder_name(value, auto_fix=True)

    @pytest.mark.parametrize(
        "value",
        ["con", "Nul", "aux.txt", "COM1.log"],
    )
    def test_fix_reserved_name_variants(self, value: str):
        """
        Reserved names are matched regardless of case and with a dotted suffix.
        """
        with pytest.raises(ValueError):
            check_win_reserved_folder_name(value, auto_fix=False)
        assert "_" + value == check_win_reserved_folder_name(value, auto_fix=True)

    @pytest.mark.parametrize(
        "list_value",
        [
            ["ab", "cd", "ef"],
            ["ghi", "JKL", "MN"],
        ],
    )
    def test_unique_name_list(self, list_value: list):
        """
        Names that stay distinct when case is ignored pass the uniqueness check.
        """
        assert check_unique_on_case_insensitive(list_value)

    @pytest.mark.parametrize(
        "list_value",
        [
            ["Ab", "CD", "AB"],
            ["ghi", "gHi", "Mn"],
        ],
    )
    def test_duplicate_name_list(self, list_value: list):
        """
        Names that differ only by case are rejected with ValueError.
        """
        with pytest.raises(ValueError):
            check_unique_on_case_insensitive(list_value)