Skip to content

Commit

Permalink
fix cos (#598)
Browse files Browse the repository at this point in the history
* move test

* fix cos delay

* move test

* Update http.py

* update version
  • Loading branch information
SAKURA-CAT authored May 31, 2024
1 parent 8409951 commit e157f3d
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 103 deletions.
12 changes: 10 additions & 2 deletions swanlab/api/cos.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,17 @@
from qcloud_cos import CosS3Client
# noinspection PyPackageRequirements
from qcloud_cos.cos_threadpool import SimpleThreadPool
from datetime import datetime
from datetime import datetime, timedelta
from typing import List, Dict, Union


class CosClient:
REFRESH_TIME = 60 * 60 * 1.5 # 1.5小时

def __init__(self, data):
"""
初始化cos客户端
"""
self.__expired_time = datetime.fromtimestamp(data["expiredTime"])
self.__prefix = data["prefix"]
self.__bucket = data["bucket"]
Expand Down Expand Up @@ -66,4 +69,9 @@ def upload_files(self, keys: List[str], local_paths: List[str]) -> Dict[str, Uni

@property
def should_refresh(self):
return (self.__expired_time - datetime.utcnow()).seconds < self.REFRESH_TIME
# cos传递的是北京时间,需要添加8小时
now = datetime.utcnow() + timedelta(hours=8)
# 过期时间减去当前时间小于刷新时间,需要注意为负数的情况
if self.__expired_time < now:
return True
return (self.__expired_time - now).seconds < self.REFRESH_TIME
9 changes: 8 additions & 1 deletion swanlab/api/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ def username(self):
"""
return self.__login_info.username

@property
def cos(self):
return self.__cos

@property
def proj_id(self):
return self.__proj.cuid
Expand Down Expand Up @@ -175,6 +179,7 @@ def upload_files(self, keys: list, local_paths: list) -> Dict[str, Union[bool, L
:return: 返回上传结果, 包含success_all和detail两个字段,detail为每一个文件的上传结果(通过index索引对应)
"""
if self.__cos.should_refresh:
swanlog.debug("Refresh cos...")
self.__get_cos()
keys = [key[1:] if key.startswith("/") else key for key in keys]
return self.__cos.upload_files(keys, local_paths)
Expand Down Expand Up @@ -216,9 +221,11 @@ def _():
先创建实验,后生成cos凭证
:return:
"""

data = self.post(
f"/project/{self.groupname}/{self.__proj.name}/runs",
{"name": exp_name, "colors": list(colors), "description": description},
{"name": exp_name, "colors": list(colors), "description": description} if description else {
"name": exp_name, "colors": list(colors)}
)
self.__exp = ExperimentInfo(data)
# 获取cos信息
Expand Down
2 changes: 1 addition & 1 deletion swanlab/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "swanlab",
"version": "0.3.6",
"version": "0.3.7",
"description": "",
"python": "true",
"host": {
Expand Down
File renamed without changes.
55 changes: 55 additions & 0 deletions test/unit/api/pytest_http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
r"""
@DATE: 2024/5/31 12:47
@File: pytest_cos.py
@IDE: pycharm
@Description:
测试http的api
开发环境下存储凭证过期时间为3s
"""
import os
import time
import nanoid
from swanlab.api.http import create_http, HTTP, CosClient
from swanlab.api.auth.login import login_by_key
from tutils import KEY, TEMP_PATH

alphabet = "abcdefghijklmnopqrstuvwxyz"


class TestCosSuite:
http: HTTP = None
project_name = nanoid.generate(alphabet)
experiment_name = nanoid.generate(alphabet)
file_path = os.path.join(TEMP_PATH, nanoid.generate(alphabet))
now_refresh_time = 1
pre_refresh_time = CosClient.REFRESH_TIME

@classmethod
def setup_class(cls):
CosClient.REFRESH_TIME = cls.now_refresh_time
# 这里不测试保存token的功能
login_info = login_by_key(KEY, save=False)
cls.http = create_http(login_info)
cls.http.mount_project(cls.project_name)
cls.http.mount_exp(cls.experiment_name, ('#ffffff', '#ffffff'))
# temp路径写一个文件上传
with open(cls.file_path, "w") as f:
f.write("test")

@classmethod
def teardown_class(cls):
CosClient.REFRESH_TIME = cls.pre_refresh_time

def test_cos_ok(self):
assert self.http is not None
assert self.http.cos is not None

def test_cos_upload(self):
self.http.upload("/key", self.file_path)
# 开发版本设置的过期时间为3s,等待过期
time.sleep(3)
# 重新上传,测试刷新
assert self.http.cos.should_refresh is True
self.http.upload("/key", self.file_path)
90 changes: 90 additions & 0 deletions test/unit/pytest_package.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
from tutils.config import nanoid
import json
from swanlab.package import (
get_package_version,
get_host_web,
get_user_setting_path,
get_host_api,
get_project_url,
get_experiment_url
)
import os
from swanlab.package import get_package_latest_version

PACKAGE_PATH = os.environ["SWANLAB_PACKAGE_PATH"]

package_data = json.load(open(PACKAGE_PATH))


def test_package_latest_version():
"""
Expand All @@ -9,3 +24,78 @@ def test_package_latest_version():
assert isinstance(get_package_latest_version(), str)
# 超时情况
assert get_package_latest_version(timeout=1e-3) is None


def mock_get_package_version():
return get_package_version(PACKAGE_PATH)


def mock_get_host_web():
return get_host_web(PACKAGE_PATH)


def mock_get_host_api():
return get_host_api(PACKAGE_PATH)


def mock_get_user_setting_path():
return get_user_setting_path(PACKAGE_PATH)


def mock_get_project_url(username: str, projname: str):
return get_project_url(username, projname, PACKAGE_PATH)


def mock_get_experiment_url(username: str, projname: str, expid: str):
return get_experiment_url(username, projname, expid, PACKAGE_PATH)


# ---------------------------------- 简单测试一下 ----------------------------------


def test_get_package_version():
"""
测试获取版本号
"""
assert mock_get_package_version() == package_data["version"]


def test_get_host_web():
"""
测试获取web地址
"""
assert mock_get_host_web() == package_data["host"]["web"]


def test_get_host_api():
"""
测试获取api地址
"""
assert mock_get_host_api() == package_data["host"]["api"]


def test_get_user_setting_path():
"""
测试获取用户设置文件路径
"""
assert mock_get_user_setting_path() == mock_get_host_web() + "/settings"


def test_get_project_url():
"""
测试获取项目url
"""
username = nanoid.generate()
projname = nanoid.generate()
assert mock_get_project_url(username, projname) == mock_get_host_web() + "/" + username + "/" + projname


def test_get_experiment_url():
"""
测试获取实验url
"""
username = nanoid.generate()
projname = nanoid.generate()
expid = nanoid.generate()
assert (mock_get_experiment_url(username, projname, expid)
== mock_get_host_web() + "/" + username + "/" + projname + "/" + expid)
99 changes: 0 additions & 99 deletions test/unit/utils/pytest_package.py

This file was deleted.

0 comments on commit e157f3d

Please sign in to comment.