diff --git a/training/benchmarks/driver/dist_pytorch.py b/training/benchmarks/driver/dist_pytorch.py
index 6c824c422..2704dcfd5 100755
--- a/training/benchmarks/driver/dist_pytorch.py
+++ b/training/benchmarks/driver/dist_pytorch.py
@@ -149,6 +149,8 @@ def barrier(vendor="nvidia"):
     if torch.distributed.is_available() and torch.distributed.is_initialized():
         if vendor == "kunlunxin":
             torch.distributed.barrier()
+        elif vendor == "mthreads":
+            torch.distributed.barrier()
         else:
             torch.distributed.all_reduce(torch.cuda.FloatTensor(1))
             torch.cuda.synchronize()
@@ -172,6 +174,23 @@ def init_dist_training_env(config):
                                              rank=rank,
                                              world_size=world_size)
         config.n_device = torch.distributed.get_world_size()
+    elif config.vendor == "mthreads":
+        import torch_musa
+        if int(os.environ.get("WORLD_SIZE", 1)) <= 1:
+            config.device = torch.device("musa")
+            config.n_device = 1
+        else:
+            torch.musa.set_device(config.local_rank)
+            host_addr_full = 'tcp://' + os.environ[
+                "MASTER_ADDR"] + ':' + os.environ["MASTER_PORT"]
+            rank = int(os.environ["RANK"])
+            world_size = int(os.environ["WORLD_SIZE"])
+            torch.distributed.init_process_group(backend=config.dist_backend,
+                                                 init_method=host_addr_full,
+                                                 rank=rank,
+                                                 world_size=world_size)
+            config.device = torch.device("musa", config.local_rank)
+            config.n_device = torch.distributed.get_world_size()
     else:  # nvidia
         if int(os.environ.get("WORLD_SIZE", 1)) <= 1:
             config.device = torch.device("cuda")
diff --git a/training/benchmarks/driver/helper.py b/training/benchmarks/driver/helper.py
index c8f406615..de513901e 100644
--- a/training/benchmarks/driver/helper.py
+++ b/training/benchmarks/driver/helper.py
@@ -74,6 +74,12 @@ def set_seed(self, seed: int, vendor: str = None):
         elif lower_vendor == "ascend":
             import mindspore
             mindspore.set_seed(seed)
+        elif lower_vendor == "mthreads":
+            import torch
+            import torch_musa
+            torch.manual_seed(seed)
+            torch.musa.manual_seed(seed)
+            torch.musa.manual_seed_all(seed)
         else:
             # TODO: extend here to set the seed for other vendors
             pass
diff --git a/training/benchmarks/resnet50/pytorch/train/trainer.py b/training/benchmarks/resnet50/pytorch/train/trainer.py
index 52e7d6ae7..b07d90c68 100755
--- a/training/benchmarks/resnet50/pytorch/train/trainer.py
+++ b/training/benchmarks/resnet50/pytorch/train/trainer.py
@@ -82,22 +82,7 @@ def train_one_epoch(self, train_dataloader, eval_dataloader):
             pure_start_time = time.time()
             optimizer.zero_grad()
 
-            images, target = batch
-            if scaler is not None:
-                with torch.cuda.amp.autocast(enabled=True):
-                    output = model(images)
-                    loss = criterion(output, target)
-
-                scaler.scale(loss).backward()
-                scaler.step(optimizer)
-                scaler.update()
-            else:
-                output = model(images)
-
-                criterion = torch.nn.CrossEntropyLoss()
-                loss = criterion(output, target)
-                loss.backward()
-                optimizer.step()
+            loss = self.adapter.train_step(model, batch, optimizer, scaler)
 
             if step % self.config.log_freq == 0:
                 print("Train Step " + str(step) + "/" + str(len(data_loader)) +
diff --git a/training/benchmarks/resnet50/pytorch/train/trainer_adapter.py b/training/benchmarks/resnet50/pytorch/train/trainer_adapter.py
index ba8eaa585..d4b7b4708 100755
--- a/training/benchmarks/resnet50/pytorch/train/trainer_adapter.py
+++ b/training/benchmarks/resnet50/pytorch/train/trainer_adapter.py
@@ -41,3 +41,23 @@ def create_grad_scaler():
     """create_grad_scaler for mixed precision training"""
     scaler = torch.cuda.amp.GradScaler() if config.amp else None
     return scaler
+
+
+def train_step(model, batch, optimizer, scaler=None):
+    """train one step"""
+    images, target = batch
+    criterion = torch.nn.CrossEntropyLoss()
+    if scaler:
+        with torch.cuda.amp.autocast(enabled=True):
+            output = model(images)
+            loss = criterion(output, target)
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)
+        scaler.update()
+    else:
+        output = model(images)
+        loss = criterion(output, target)
+        loss.backward()
+        optimizer.step()
+
+    return loss
diff --git a/training/mthreads/README.md b/training/mthreads/README.md
new file mode 100644
index 000000000..194b9e73f
--- /dev/null
+++ b/training/mthreads/README.md
@@ -0,0 +1,70 @@
+
+# Vendor Information
+
+Website: https://www.mthreads.com/
+
+Moore Threads Technology Co., Ltd. (Moore Threads) is an integrated-circuit design company focused on GPU chip design. It develops full-function GPU chips and related products that provide compute acceleration for its technology ecosystem partners. The company develops new-generation GPUs for "meta-computing" applications, building a unified platform that combines visual computing, 3D graphics, scientific computing, and AI computing, and an ecosystem based on cloud-native GPU computing to help drive the digital economy.
+
+Moore Threads MTT S-series full-function GPUs support diverse workloads. Backed by the complete MUSA software stack covering deep learning, graphics rendering, video processing, and scientific computing, they provide general-purpose compute for AI training, AI inference, large models, AIGC, cloud gaming, cloud rendering, video cloud, digital twins, and other scenarios, providing a solid compute foundation for data centers, AI computing centers, and meta-computing centers.
+
+The MUSA software stack is compatible with the CUDA ecosystem through the musify code-migration tool, compute/communication acceleration libraries, the mcc compiler, and the MUSA runtime and driver, helping users migrate code and applications quickly. Through the torch_musa plugin, MTT S-series GPUs plug into native PyTorch, so users can run AI models on Moore Threads full-function GPUs without code changes.
+
+# FlagPerf validation environment
+## Reference environment configuration
+  - Hardware
+    - Machine model: MCCX D800
+    - Accelerator model: MTT S4000 48GB
+    - CPU model: Intel(R) Xeon(R) Gold 6430 CPU @ 2.00GHz
+    - Multi-node network type and bandwidth: InfiniBand, 2*200Gbps
+  - Software
+    - OS version: Ubuntu 20.04 LTS
+    - OS kernel version: 5.4.0-42-generic
+    - Accelerator driver version: 2.2.0
+    - Docker version: 20.10.24
+
+## Container image information
+- Image build information
+  - Dockerfile path: training/mthreads/docker_image/pytorch_2.0/Dockerfile
+  - Post-build software install script: training/mthreads/docker_image/pytorch_2.0/pytorch_2.0_install.sh
+
+- Core software information
+
+  - AI framework & version
+    - PyTorch: v2.0.0
+
+  - Other software versions
+    - torch_musa: 2.0.0+git8614ba1
+    - musa toolkits: 1.5.0+git3d8791d
+    - mcc: 1.5.2+git3730bdd
+    - mublas: 1.2.0+gitd9867b5
+
+
+## Accelerator monitoring
+- Accelerator usage collection command
+
+  ```bash
+  mthreads-gmi -q | grep -E 'GPU Current Temp|Power Draw|Used|Total|Gpu' | \
+  awk -F ': *' '/GPU Current Temp|Power Draw|Used|Total|Gpu/ \
+  { values[(NR-1)%5+1] = $2; } NR % 5 == 0 { print values[4], values[5], values[2], values[1], values[3]; }'
+  ```
+- Sample monitoring output:
+  ```bash
+  45C 109.51W 1MiB 32768MiB 0%
+  44C 108.95W 1MiB 32768MiB 0%
+  46C 110.87W 1MiB 32768MiB 0%
+  43C 104.33W 1MiB 32768MiB 0%
+  44C 107.55W 8MiB 32768MiB 0%
+  46C 110.51W 8MiB 32768MiB 0%
+  44C 106.59W 8MiB 32768MiB 0%
+  44C 104.58W 8MiB 32768MiB 0%
+  ```
+- Collected metrics
+
+| Metric | Log file | Format |
+|---|---|---|
+| Temperature | mthreads_monitor.log | xxx C |
+| Power draw | mthreads_monitor.log | xxx W |
+| Memory used | mthreads_monitor.log | xxx MiB |
+| Total memory | mthreads_monitor.log | xxx MiB |
+| Memory utilization | mthreads_monitor.log | xxx % |
+
diff --git a/training/mthreads/docker_image/pytorch_2.0/Dockerfile b/training/mthreads/docker_image/pytorch_2.0/Dockerfile
new file mode 100644
index 000000000..2982c1af5
--- /dev/null
+++ b/training/mthreads/docker_image/pytorch_2.0/Dockerfile
@@ -0,0 +1,3 @@
+FROM moore-threads/pytorch:flagperf-py38
+ENV PATH /opt/conda/envs/py38/bin:$PATH
+ENV LD_LIBRARY_PATH=/usr/local/musa/lib/:$LD_LIBRARY_PATH
diff --git a/training/mthreads/docker_image/pytorch_2.0/pytorch_install.sh b/training/mthreads/docker_image/pytorch_2.0/pytorch_install.sh
new file mode 100644
index 000000000..cc1f786e8
--- /dev/null
+++ b/training/mthreads/docker_image/pytorch_2.0/pytorch_install.sh
@@ -0,0 +1 @@
+#!/bin/bash
\ No newline at end of file
diff --git a/training/mthreads/mthreads_monitor.py b/training/mthreads/mthreads_monitor.py
new file mode 100644
index 000000000..092b832df
--- /dev/null
+++ b/training/mthreads/mthreads_monitor.py
@@ -0,0 +1,290 @@
+# !/usr/bin/env python3
+# encoding: utf-8
+'''
+Usage: python3 sys-monitor.py -o operation -l [log_path]
+    -o, --operation start|stop|restart|status
+    -l, --log log path , ./logs/ default
+'''
+
+import os
+import sys
+import time
+import signal
+import atexit
+import argparse
+import datetime
+from multiprocessing import Process
+import subprocess
+import schedule
+
+
+class Daemon:
+    '''
+    daemon subprocess class.
+    usage: subclass this daemon and override the run() method.
+    sys-monitor.pid: in the /tmp/, auto del when unexpected exit.
+    verbose: debug mode, disabled default.
+    '''
+
+    def __init__(self,
+                 pid_file,
+                 log_file,
+                 err_file,
+                 gpu_log,
+                 log_path,
+                 rate=5,
+                 stdin=os.devnull,
+                 stdout=os.devnull,
+                 stderr=os.devnull,
+                 home_dir='.',
+                 umask=0o22,
+                 verbose=0):
+        self.stdin = stdin
+        self.stdout = stdout
+        self.stderr = stderr
+        self.home_dir = home_dir
+        self.verbose = verbose
+        self.pidfile = pid_file
+        self.logfile = log_file
+        self.errfile = err_file
+        self.gpufile = gpu_log
+        self.logpath = log_path
+        self.rate = rate
+        self.umask = umask
+        self.verbose = verbose
+        self.daemon_alive = True
+
+    def get_pid(self):
+        try:
+            with open(self.pidfile, 'r') as pf:
+                pid = int(pf.read().strip())
+        except IOError:
+            pid = None
+        except SystemExit:
+            pid = None
+        return pid
+
+    def del_pid(self):
+        if os.path.exists(self.pidfile):
+            os.remove(self.pidfile)
+
+    def run(self):
+        '''
+        NOTE: override the method in subclass
+        '''
+
+        def gpu_mon(file):
+            TIMESTAMP = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
+            # TODO more elegant way?
+            cmd = "mthreads-gmi -q | grep -E 'GPU Current Temp|Power Draw|Used|Total|Gpu' | "
+            cmd += "awk -F ': *' '/GPU Current Temp|Power Draw|Used|Total|Gpu/ { values[(NR-1)%5+1] = $2; } NR % 5 == 0 { print values[4], values[5], values[2], values[1], values[3]; }'"
+            process = subprocess.Popen(cmd,
+                                       shell=True,
+                                       stdout=subprocess.PIPE,
+                                       stderr=subprocess.STDOUT,
+                                       encoding='utf-8')
+            try:
+                out = process.communicate(timeout=10)
+            except subprocess.TimeoutExpired:
+                process.kill()
+                out = process.communicate()
+
+            if process.returncode != 0:
+                result = TIMESTAMP + "\n" + "error" + "\n"
+            else:
+                result = TIMESTAMP + "\n" + out[0] + "\n"
+            with open(file, 'a') as f:
+                f.write(result)
+
+        def timer_gpu_mon():
+            gpu_process = Process(target=gpu_mon, args=(self.gpufile, ))
+            gpu_process.start()
+
+        schedule.every(self.rate).seconds.do(timer_gpu_mon)
+        while True:
+            schedule.run_pending()
+            time.sleep(5)
+
+    def daemonize(self):
+        if self.verbose >= 1:
+            print('daemon process starting ...')
+        try:
+            pid = os.fork()
+            if pid > 0:
+                sys.exit(0)
+        except OSError as e:
+            sys.stderr.write('fork #1 failed: %d (%s)\n' %
+                             (e.errno, e.strerror))
+            sys.exit(1)
+        os.chdir(self.home_dir)
+        os.setsid()
+        os.umask(self.umask)
+        try:
+            pid = os.fork()
+            if pid > 0:
+                sys.exit(0)
+        except OSError as e:
+            sys.stderr.write('fork #2 failed: %d (%s)\n' %
+                             (e.errno, e.strerror))
+            sys.exit(1)
+        sys.stdout.flush()
+        sys.stderr.flush()
+        si = open(self.stdin, 'r')
+        so = open(self.stdout, 'a+')
+        if self.stderr:
+            se = open(self.stderr, 'a+')
+        else:
+            se = so
+        os.dup2(si.fileno(), sys.stdin.fileno())
+        os.dup2(so.fileno(), sys.stdout.fileno())
+        os.dup2(se.fileno(), sys.stderr.fileno())
+        atexit.register(self.del_pid)
+        pid = str(os.getpid())
+        with open(self.pidfile, 'w+') as f:
+            f.write('%s\n' % pid)
+
+    def start(self):
+        if not os.path.exists(self.logpath):
+            os.makedirs(self.logpath)
+        elif os.path.exists(self.gpufile):
+            os.remove(self.gpufile)
+        if self.verbose >= 1:
+            print('ready to start ......')
+        # check for a pid file to see if the daemon already runs
+        pid = self.get_pid()
+        if pid:
+            msg = 'pid file %s already exists, is it already running?\n'
+            sys.stderr.write(msg % self.pidfile)
+            sys.exit(1)
+        # start the daemon
+        self.daemonize()
+        self.run()
+
+    def stop(self):
+        if self.verbose >= 1:
+            print('stopping ...')
+        pid = self.get_pid()
+        if not pid:
+            msg = 'pid file [%s] does not exist. Not running?\n' % self.pidfile
+            sys.stderr.write(msg)
+            if os.path.exists(self.pidfile):
+                os.remove(self.pidfile)
+            return
+        # try to kill the daemon process
+        try:
+            i = 0
+            while 1:
+                os.kill(pid, signal.SIGTERM)
+                time.sleep(1)
+                i = i + 1
+                if i % 10 == 0:
+                    os.kill(pid, signal.SIGHUP)
+        except OSError as err:
+            err = str(err)
+            if err.find('No such process') > 0:
+                if os.path.exists(self.pidfile):
+                    os.remove(self.pidfile)
+            else:
+                print(str(err))
+                sys.exit(1)
+        if self.verbose >= 1:
+            print('Stopped!')
+
+    def restart(self):
+        self.stop()
+        self.start()
+
+    def status(self):
+        pid = self.get_pid()
+        if pid:
+            if os.path.exists('/proc/%d' % pid):
+                return pid
+        return False
+
+
+def parse_args():
+    ''' Check script input parameter. '''
+    parse = argparse.ArgumentParser(description='Sys monitor script')
+    parse.add_argument('-o',
+                       type=str,
+                       metavar='[operation]',
+                       required=True,
+                       help='start|stop|restart|status')
+    parse.add_argument('-l',
+                       type=str,
+                       metavar='[log_path]',
+                       required=False,
+                       default='./logs/',
+                       help='log path')
+    args = parse.parse_args()
+    return args
+
+
+def get_system_info():
+    cmd = r"echo OS version:;"
+    cmd = cmd + r"cat /etc/issue | head -n1 | awk '{print $1, $2, $3}';"
+    cmd = cmd + r"echo ;"
+
+    cmd = cmd + r"echo OS Kernel version:;"
+    cmd = cmd + r"uname -r;"
+    cmd = cmd + r"echo ;"
+
+    cmd = cmd + r"echo Hardware Model:;"
+    cmd = cmd + r"sudo dmidecode | grep -A9 'System Information' | tail -n +2 | sed 's/^[ \t]*//';"
+    cmd = cmd + r"echo ;"
+
+    cmd = cmd + r"echo Accelerator Model:;"
+    cmd = cmd + r"mthreads-gmi -L;"
+    cmd = cmd + r"echo ;"
+
+    cmd = cmd + r"echo Accelerator Driver version:;"
+    cmd = cmd + r"mthreads-gmi | grep 'Driver Version' | awk '{print $3}';"
+    cmd = cmd + r"echo ;"
+
+    cmd = cmd + r"echo Docker version:;"
+    cmd = cmd + r"docker -v"
+
+    return cmd
+
+
+def main():
+    sample_rate1 = 5
+    args = parse_args()
+    operation = args.o
+    log_path = args.l
+    pid_fn = str('/tmp/gpu_monitor.pid')
+    log_fn = str(log_path + '/mthreads_monitor.log')
+    err_fn = str(log_path + '/mthreads_monitor.err')
+    # result for gpu
+    gpu_fn = str(log_path + '/mthreads_monitor.log')
+    sys_fn = str(log_path + '/sys_info.log')
+    cmd = get_system_info()
+    with open(sys_fn, "w") as f:
+        p = subprocess.Popen(cmd, shell=True, stdout=f, stderr=subprocess.STDOUT)
+        p.wait()
+
+    subdaemon = Daemon(pid_fn,
+                       log_fn,
+                       err_fn,
+                       gpu_fn,
+                       log_path,
+                       verbose=1,
+                       rate=sample_rate1)
+    if operation == 'start':
+        subdaemon.start()
+    elif operation == 'stop':
+        subdaemon.stop()
+    elif operation == 'restart':
+        subdaemon.restart()
+    elif operation == 'status':
+        pid = subdaemon.status()
+        if pid:
+            print('process [%s] is running ......' % pid)
+        else:
+            print('daemon process [%s] stopped' % pid)
+    else:
+        print("invalid argument!")
+        sys.exit(1)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/training/mthreads/resnet50-pytorch/README.md b/training/mthreads/resnet50-pytorch/README.md
new file mode 100644
index 000000000..3b5048860
--- /dev/null
+++ b/training/mthreads/resnet50-pytorch/README.md
@@ -0,0 +1,52 @@
+### 1. Dataset preparation
+[Download ImageNet2012](../../benchmarks/resnet50)
+
+### 2. Configuration and run information for Moore Threads MTT S-series GPUs
+#### Environment configuration
+- ##### Hardware environment
+    - Hardware
+      - Machine model: MCCX D800
+      - Accelerator model: MTT S4000 48GB
+      - CPU model: Intel(R) Xeon(R) Gold 6430 CPU @ 2.00GHz
+      - Multi-node network type and bandwidth: InfiniBand, 2*200Gbps
+
+- ##### Software environment
+   - OS version: Ubuntu 20.04 LTS
+   - OS kernel version: 5.4.0-42-generic
+   - Accelerator driver version: 2.2.0
+   - Docker version: 20.10.24
+   - Training framework version: pytorch-2.0.0+torch_musa-git8614ba1
+   - Dependency versions:
+     - musa toolkits: 1.5.0+git3d8791d
+     - mcc: 1.5.2+git3730bdd
+     - mublas: 1.2.0+gitd9867b5
+
+### Results
+
+* Common metrics
+
+| Metric | Value | Notes |
+| -------------- | ----------------------- | ------------------------------------- |
+| Task type | image classification | |
+| Model | resnet50 | |
+| Dataset | ImageNet2012 | |
+| Data precision | precision, see "Performance metrics" | fp32 available |
+| Hyperparameter changes | fix_hp, see "Performance metrics" | special hyperparameters needed to saturate device throughput |
+| Hardware device | MTT S4000 | |
+| Device memory usage | mem, see "Performance metrics" | commonly called "VRAM", in GiB |
+| End-to-end time | e2e_time, see "Performance metrics" | total time plus Perf initialization, etc. |
+| Overall throughput | p_whole, see "Performance metrics" | images trained divided by total time (performance_whole) |
+| Training throughput | p_train, see "Performance metrics" | excludes evaluation time at the end of each epoch |
+| **Compute throughput** | **p_core, see "Performance metrics"** | excludes data I/O time (p3>p2>p1) |
+| Training result | acc, see "Performance metrics" | top-1 classification accuracy (acc1) |
+| Extra modifications | none | |
+
+* Performance metrics
+
+| Config | precision | fix_hp | e2e_time | p_whole | p_train | p_core | acc | mem |
+| ------------------ | --------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
+| 1 node, 1 card (1x1) | fp32 | / | | | | | / | / |
+| 1 node, 8 cards (1x8) | fp32 | bs=256,lr=0.8 | | | | | / | 25.0/48.0 |
+| 1 node, 8 cards (1x8) | amp | bs=512,lr=0.2 | | | | | 73.08 | 26.2/48.0 |
+| 1 node, 8 cards (1x8) | bf16 | bs=512,lr=0.2 | | | | | / | 25.7/48.0 |
+| 2 nodes, 8 cards (2x8) | fp32 | / | | | | | / | / |
diff --git a/training/mthreads/resnet50-pytorch/config/config_S4000x1x1.py b/training/mthreads/resnet50-pytorch/config/config_S4000x1x1.py
new file mode 100644
index 000000000..e3437bec1
--- /dev/null
+++ b/training/mthreads/resnet50-pytorch/config/config_S4000x1x1.py
@@ -0,0 +1,8 @@
+lr = 0.1
+train_batch_size = 256
+eval_batch_size = train_batch_size
+
+dist_backend = "mccl"
+amp = False
+fp16 = False
+
diff --git a/training/mthreads/resnet50-pytorch/config/config_S4000x1x8.py b/training/mthreads/resnet50-pytorch/config/config_S4000x1x8.py
new file mode 100644
index 000000000..7b9b4be72
--- /dev/null
+++ b/training/mthreads/resnet50-pytorch/config/config_S4000x1x8.py
@@ -0,0 +1,8 @@
+lr = 0.8
+train_batch_size = 256
+eval_batch_size = train_batch_size
+
+dist_backend = "mccl"
+amp = False
+fp16 = False
+
diff --git a/training/mthreads/resnet50-pytorch/extern/trainer_adapter.py b/training/mthreads/resnet50-pytorch/extern/trainer_adapter.py
new file mode 100644
index 000000000..cc955f536
--- /dev/null
+++ b/training/mthreads/resnet50-pytorch/extern/trainer_adapter.py
@@ -0,0 +1,35 @@
+import torch
+import torch_musa
+import config
+from driver import dist_pytorch
+
+
+def convert_model(model):
+    model.to(memory_format=torch.channels_last)
+    return model
+
+
+def create_grad_scaler():
+    """create_grad_scaler for mixed precision training"""
+    scaler = torch_musa.amp.GradScaler() if config.amp else None
+    return scaler
+
+
+def train_step(model, batch, optimizer, scaler=None):
+    """train one step"""
+    images, target = batch
+    criterion = torch.nn.CrossEntropyLoss()
+    if scaler:
+        with torch.musa.amp.autocast(enabled=True):
+            output = model(images)
+            loss = criterion(output, target)
+        scaler.scale(loss).backward()
+        scaler.step(optimizer)
+        scaler.update()
+    else:
+        output = model(images)
+        loss = criterion(output, target)
+        loss.backward()
+        optimizer.step()
+
+    return loss
diff --git a/training/run_benchmarks/config/test_conf.py b/training/run_benchmarks/config/test_conf.py
index 680349c94..578c98bf4 100644
--- a/training/run_benchmarks/config/test_conf.py
+++ b/training/run_benchmarks/config/test_conf.py
@@ -1,7 +1,7 @@
 '''Test Configs, including'''
 # -*-coding:utf-8 -*-
 
-# Set accelerator's vendor name, e.g. iluvatar, cambricon, kunlunxin and ascend.
+# Set accelerator's vendor name, e.g. iluvatar, cambricon, kunlunxin, ascend and mthreads.
 # We will run benchmarks in training/
 VENDOR = "nvidia"
 
@@ -19,6 +19,8 @@
 #   "--device=/dev/davinciX --device=/dev/davinci_manager + \
 #    --device=/dev/devmm_svm --device=/dev/hisi_hdc + \
 #    -v /usr/local/Ascend/driver -v /usr/local/dcmi -v /usr/local/bin/npu-smi"
+# mthreads:
+#   " --env MTHREADS_VISIBLE_DEVICES=all"
 ACCE_CONTAINER_OPT = " --gpus all"
 # XXX_VISIBLE_DEVICE item name in env
 # possible value of ACCE_VISIBLE_DEVICE_ENV_NAME are:
@@ -26,6 +28,7 @@
 #   MLU_VISIBLE_DEVICES for cambricon
 #   XPU_VISIBLE_DEVICES for kunlunxin
 #   ASCEND_VISIBLE_DEVICES for ascend
+#   MUSA_VISIBLE_DEVICES for mthreads
 ACCE_VISIBLE_DEVICE_ENV_NAME = "CUDA_VISIBLE_DEVICES"
 
 # Set pip source, which will be used in preparing envs in container
@@ -116,6 +119,12 @@
     # "longformer:pytorch:R300:1:8:1": "/raid/dataset/longformer_train",
     # "distilbert:pytorch:R300:1:8:1": "/raid/dataset/distilbert/",
     # "swin_transformer:pytorch:R300:1:8:1": "/raid/dataset/ImageNet_1k_2012/",
-    # "tacotron2:pytorch:R300:1:8:1": "/raid/dataset/tacotron2/LJSpeech/"
+    # "tacotron2:pytorch:R300:1:8:1": "/raid/dataset/tacotron2/LJSpeech/",
+
+    # mthreads cases
+    # "resnet50:pytorch_2.0:S4000:1:8:1": "/data/flagperf/ImageNet",
+    # "retinanet:pytorch_2.0:S4000:1:8:1": "/data/flagperf/coco2017",
+    # "bert_hf:pytorch_2.0:S4000:1:8:1": "/data/flagperf/bert_hf",
+    # "llama2_7b:deepspeed:S4000:1:8:1": "/data/flagperf/llama/openwebtext",
 }
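
To tie the pieces together: enabling a Moore Threads run only requires editing `training/run_benchmarks/config/test_conf.py` with the values documented in the comments added by this patch. The excerpt below is an illustrative sketch assembled from those comments, not part of the patch itself; the dataset path is the example path shown above, and the `CASES` dictionary name follows the existing layout of FlagPerf's test_conf.py.

```python
# Illustrative test_conf.py settings for an MTT S4000 run (sketch, assumptions noted above).
VENDOR = "mthreads"

# Expose all Moore Threads GPUs to the benchmark container.
ACCE_CONTAINER_OPT = " --env MTHREADS_VISIBLE_DEVICES=all"

# Per the patch's comments, torch_musa uses MUSA_VISIBLE_DEVICES as its
# counterpart of CUDA_VISIBLE_DEVICES.
ACCE_VISIBLE_DEVICE_ENV_NAME = "MUSA_VISIBLE_DEVICES"

CASES = {
    # "model:framework:hardware:#nodes:#devices per node:#repeats": dataset path
    "resnet50:pytorch_2.0:S4000:1:8:1": "/data/flagperf/ImageNet",
}
```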