Skip to content

Commit

Permalink
fix launch bugs (#657)
Browse files Browse the repository at this point in the history
修复目前launch已知问题:

- [x] confirm输入n还是会继续执行
- [x] 拼接实验url时错误
- [x] 格式化日期时部分版本python报错
- [x] list的dashboard排版调整

此外,新增了一些测试函数,当launch相关接口响应3xx时,告诉用户需更新到最新版本
  • Loading branch information
SAKURA-CAT authored Jul 27, 2024
1 parent 304bc54 commit 4c36010
Show file tree
Hide file tree
Showing 13 changed files with 261 additions and 41 deletions.
3 changes: 2 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ torch
torchvision
python-dotenv
freezegun
build
build
requests-mock
8 changes: 8 additions & 0 deletions swanlab/api/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,14 @@ def get_http() -> HTTP:
return http


def reset_http():
"""
重置http对象
"""
global http
http = None


def sync_error_handler(func):
"""
在一些接口中我们不希望线程奔溃,而是返回一个错误对象
Expand Down
21 changes: 8 additions & 13 deletions swanlab/cli/commands/task/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
打包、上传、开启任务
"""
import click
from .utils import login_init_sid
from swanlab.api import get_http
from .utils import login_init_sid, UseTaskHttp
# noinspection PyPackageRequirements
from qcloud_cos import CosConfig, CosS3Client
from swanlab.log import swanlog
Expand Down Expand Up @@ -84,7 +83,9 @@ def launch(path: str, entry: str, python: str, name: str):
text = f"The target folder {FONT.yellow(path)} will be packaged and uploaded, "
text += f"and you have specified {FONT.yellow(entry)} as the task entry point. "
swanlog.info(text)
click.confirm(FONT.swanlab("Do you wish to proceed?"))
ok = click.confirm(FONT.swanlab("Do you wish to proceed?"), abort=False)
if not ok:
return
# 压缩文件夹
memory_file = zip_folder(path)
# 上传文件
Expand All @@ -95,14 +96,6 @@ def launch(path: str, entry: str, python: str, name: str):
swanlog.info(f"Task launched successfully. You can use {FONT.yellow('swanlab task list')} to view the task.")


def fmt_entry(entry: str) -> str:
"""
格式化入口文件路径
:param entry:
:return:
"""


def zip_folder(dirpath: str) -> io.BytesIO:
"""
压缩文件夹
Expand Down Expand Up @@ -204,7 +197,8 @@ def upload_memory_file(memory_file: io.BytesIO) -> str:
上传内存文件
:returns 上传成功后的文件路径
"""
sts = get_http().get("/user/codes/sts")
with UseTaskHttp() as http:
sts = http.get("/user/codes/sts")
cos = CosClientForTask(sts)
val = memory_file.getvalue()
progress = TaskProgressBar(len(val))
Expand Down Expand Up @@ -247,4 +241,5 @@ def create(self):
"""
创建任务
"""
get_http().post("/task", self.__dict__())
with UseTaskHttp() as http:
http.post("/task", self.__dict__())
14 changes: 6 additions & 8 deletions swanlab/cli/commands/task/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,12 @@
import time
import click
from typing import List
from .utils import login_init_sid
from rich.layout import Layout
from datetime import datetime
from rich.panel import Panel
from rich.table import Table
from rich.live import Live
from swanlab.api import get_http
from .utils import TaskModel
from .utils import TaskModel, UseTaskHttp, login_init_sid


@click.command()
Expand Down Expand Up @@ -45,13 +43,13 @@ def __init__(self, num: int, username: str):
"""
self.num = num
self.username = username
self.http = get_http()

def __dict__(self):
return {"num": self.num}

def list(self) -> List[TaskModel]:
tasks = self.http.get("/task", self.__dict__())
with UseTaskHttp() as http:
tasks = http.get("/task", self.__dict__())
return [TaskModel(self.username, task) for task in tasks]

def table(self):
Expand All @@ -65,7 +63,7 @@ def table(self):
st.add_column("Task ID", justify="right")
st.add_column("Task Name", justify="center")
st.add_column("Status", justify="center")
st.add_column("URL", justify="center")
st.add_column("URL", justify="center", no_wrap=True)
st.add_column("Started Time", justify="center")
st.add_column("Finished Time", justify="center")
for tlm in self.list():
Expand Down Expand Up @@ -117,8 +115,8 @@ def __init__(self, ltm: ListTasksModel):
Layout(name="main")
)
self.layout["main"].split_row(
Layout(name="task_table", ratio=5),
Layout(name="term_output", ratio=2, )
Layout(name="task_table", ratio=4),
Layout(name="term_output", ratio=1)
)
self.layout["header"].update(ListTaskHeader())
self.layout["task_table"].update(Panel(ltm.table(), border_style="magenta"))
Expand Down
18 changes: 10 additions & 8 deletions swanlab/cli/commands/task/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@
根据cuid获取任务详情
"""
import click
from swanlab.api import get_http
from .utils import TaskModel, login_init_sid
from rich.syntax import Syntax, Console
import json
from .utils import TaskModel, login_init_sid, UseTaskHttp
from rich.syntax import Console, Syntax


def validate_six_char_string(_, __, value):
Expand All @@ -31,14 +29,15 @@ def search(cuid):
Get task detail by cuid
"""
login_info = login_init_sid()
http = get_http()
data = http.get(f"/task/{cuid}")
with UseTaskHttp() as http:
data = http.get(f"/task/{cuid}")
tm = TaskModel(login_info.username, data)
"""
任务名称,python版本,入口文件,任务状态,URL,创建时间,执行时间,结束时间,错误信息
"""
console = Console()
console.print("\n[bold]Task Info[/bold]")
print("")
console.print("[bold]Task Info[/bold]")
console.print(f"[bold]Task Name:[/bold] [yellow]{tm.name}[/yellow]")
console.print(f"[bold]Python Version:[/bold] [white]{tm.python}[white]")
console.print(f"[bold]Entry File:[/bold] [white]{tm.index}[white]")
Expand All @@ -52,4 +51,7 @@ def search(cuid):
console.print(f"[bold]Created At:[/bold] {tm.created_at}")
tm.started_at is not None and console.print(f"[bold]Started At:[/bold] {tm.started_at}")
tm.finished_at is not None and console.print(f"[bold]Finished At:[/bold] {tm.finished_at}")
tm.status == 'CRASHED' and console.print(f"[bold][red]Task Error[/red]:[/bold] \n\n{tm.msg}\n")
if tm.status == 'CRASHED':
console.print(f"[bold][red]Task Error[/red]:[/bold]\n")
console.print(Syntax(tm.msg, 'python', background_color="default"))
print("") # 加一行空行,与开头一致
29 changes: 27 additions & 2 deletions swanlab/cli/commands/task/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
任务相关工具函数
"""
from swanlab.package import get_key, get_experiment_url
from swanlab.api import terminal_login, create_http, LoginInfo
from swanlab.error import KeyFileError
from swanlab.api import terminal_login, create_http, LoginInfo, get_http
from swanlab.error import KeyFileError, ApiError
from datetime import datetime
from typing import Optional
from swanlab.log import swanlog
import sys


def login_init_sid() -> LoginInfo:
Expand Down Expand Up @@ -68,4 +71,26 @@ def url(self):
def fmt_time(date: str = None):
if date is None:
return None
date = date.replace("Z", "+00:00")
return datetime.fromisoformat(date).strftime("%Y-%m-%d %H:%M:%S")


class UseTaskHttp:
"""
主要用于检测http响应是否为3xx字段,如果是则要求用户更新版本
使用此类之前需要先调用login_init_sid()函数完成全局http对象的初始化
"""

def __init__(self):
self.http = get_http()

def __enter__(self):
return self.http

def __exit__(self, exc_type, exc_val: Optional[ApiError], exc_tb):
if exc_type is ApiError:
# api已过期,需要更新swanlab版本
if exc_val.resp.status_code // 100 == 3:
swanlog.info("SwanLab in your environment is outdated. Upgrade: `pip install -U swanlab`")
sys.exit(3)
return False
13 changes: 9 additions & 4 deletions swanlab/data/callback_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,16 @@
from swanlab.log import swanlog
from swanlab.api import get_http
from swanlab.env import in_jupyter, SwanLabEnv
from swanlab.package import get_host_web, get_key
from swanlab.error import KeyFileError
from .callback_local import LocalRunCallback, get_run, SwanLabRunState
from swanlab.data.cloud import ThreadPool
from swanlab.package import get_package_version, get_package_latest_version
from swanlab.package import (
get_package_version,
get_package_latest_version,
get_experiment_url,
get_project_url,
get_key
)
from swankit.log import FONT
from swankit.env import create_time
import json
Expand Down Expand Up @@ -157,8 +162,8 @@ def _get_package_latest_version():
def _view_web_print(self):
self._watch_tip_print()
http = get_http()
project_url = get_host_web() + f"/@{http.groupname}/{http.projname}"
experiment_url = project_url + f"/runs/{http.exp_id}"
project_url = get_project_url(http.groupname, http.projname)
experiment_url = get_experiment_url(http.groupname, http.projname, http.exp_id)
swanlog.info("🏠 View project at " + FONT.blue(FONT.underline(project_url)))
swanlog.info("🚀 View run at " + FONT.blue(FONT.underline(experiment_url)))
return experiment_url
Expand Down
2 changes: 1 addition & 1 deletion swanlab/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "swanlab",
"version": "0.3.15-alpha.2",
"version": "0.3.15-alpha.3",
"description": "",
"python": "true"
}
4 changes: 2 additions & 2 deletions swanlab/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def get_project_url(username: str, projname: str) -> str:
:param projname: 项目名
:return: 项目的url
"""
return get_host_web() + "/" + username + "/" + projname
return get_host_web() + "/@" + username + "/" + projname


def get_experiment_url(username: str, projname: str, expid: str) -> str:
Expand All @@ -88,7 +88,7 @@ def get_experiment_url(username: str, projname: str, expid: str) -> str:
:param expid: 实验id
:return: 实验的url
"""
return get_project_url(username, projname) + "/" + expid
return get_project_url(username, projname) + "/runs/" + expid


# ---------------------------------- 登录相关 ----------------------------------
Expand Down
42 changes: 42 additions & 0 deletions test/unit/_/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
r"""
@DATE: 2024/7/27 15:07
@File: setup.py
@IDE: pycharm
@Description:
测试tutils/setup.py
"""
import pytest
import tutils.setup as SU


def test_mock_login_info():
login_info = SU.mock_login_info()
assert login_info.is_fail is False
login_info = SU.mock_login_info(error_reason="Unauthorized")
assert login_info.is_fail is True
login_info = SU.mock_login_info(error_reason="Authorization Required")
assert login_info.is_fail is True
login_info = SU.mock_login_info(error_reason="Forbidden")
assert login_info.is_fail is True
login_info = SU.mock_login_info(error_reason="OK")
assert login_info.is_fail is False


def test_use_setup_http():
from swanlab.api import get_http
with SU.UseSetupHttp() as http:
assert http is not None
assert get_http() is not None
with pytest.raises(ValueError):
get_http()


def test_use_mocker():
with SU.UseMocker() as m:
m.post("/tmp", text="mock")
import requests
from swanlab.package import get_host_api
resp = requests.post(get_host_api() + "/tmp")
assert resp.text == "mock"
31 changes: 31 additions & 0 deletions test/unit/cli/test_cli_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
r"""
@DATE: 2024/7/27 14:51
@File: test_cli_task.py
@IDE: pycharm
@Description:
测试cli/task
"""
import pytest
from swanlab.cli.commands.task.utils import UseTaskHttp
import tutils.setup as SU


def test_use_task_http_ok():
with SU.UseMocker() as m:
m.post("/test", text="mock")
with SU.UseSetupHttp():
with UseTaskHttp() as http:
text = http.post("/test")
assert text == "mock"


def test_use_task_http_abandon():
with pytest.raises(SystemExit) as p:
with SU.UseMocker() as m:
m.post("/test", status_code=301, reason="Abandon")
with SU.UseSetupHttp():
with UseTaskHttp() as http:
http.post("/test")
assert p.value.code == 3
4 changes: 2 additions & 2 deletions test/unit/test_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_get_project_url():
"""
username = nanoid.generate()
projname = nanoid.generate()
assert P.get_project_url(username, projname) == P.get_host_web() + "/" + username + "/" + projname
assert P.get_project_url(username, projname) == P.get_host_web() + "/@" + username + "/" + projname


def test_get_experiment_url():
Expand All @@ -69,7 +69,7 @@ def test_get_experiment_url():
assert P.get_experiment_url(
username, projname,
expid
) == P.get_host_web() + "/" + username + "/" + projname + "/" + expid
) == P.get_host_web() + "/@" + username + "/" + projname + "/runs/" + expid


# ---------------------------------- 登录部分 ----------------------------------
Expand Down
Loading

0 comments on commit 4c36010

Please sign in to comment.