Skip to content

Commit

Permalink
update compile compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
layumi committed Jan 24, 2024
1 parent e9d37e1 commit 7854bdc
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 3 deletions.
3 changes: 2 additions & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tqdm import tqdm
from model import ft_net, ft_net_dense, ft_net_hr, ft_net_swin, ft_net_swinv2, ft_net_efficient, ft_net_NAS, ft_net_convnext, PCB, PCB_test
from utils import fuse_all_conv_bn
version = torch.__version__
#fp16
try:
from apex.fp16_utils import *
Expand Down Expand Up @@ -280,7 +281,7 @@ def get_id(img_path):
#if opt.fp16:
# model_structure = network_to_half(model_structure)

if torch.cuda.get_device_capability()[0]>6 and len(opt.gpu_ids)==1: # should be >=7
if torch.cuda.get_device_capability()[0]>6 and len(opt.gpu_ids)==1 and int(version[0])>1: # should be >=7
print("Compiling model...")
# https://huggingface.co/docs/diffusers/main/en/optimization/torch2.0
torch.set_float32_matmul_precision('high')
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ def draw_curve(current_epoch):
#optimizer_ft = FP16_Optimizer(optimizer_ft, static_loss_scale = 128.0)
model, optimizer_ft = amp.initialize(model, optimizer_ft, opt_level = "O1")

if torch.cuda.get_device_capability()[0]>6 and len(opt.gpu_ids)==1: # should be >=7 and one gpu
if torch.cuda.get_device_capability()[0]>6 and len(opt.gpu_ids)==1 and int(version[0])>1: # should be >=7 and one gpu
torch.set_float32_matmul_precision('high')
print("Compiling model... The first epoch may be slow, which is expected!")
# https://huggingface.co/docs/diffusers/main/en/optimization/torch2.0
Expand Down
3 changes: 2 additions & 1 deletion utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch.nn as nn
import os
import torch
import torch.nn as nn
from torch.nn.utils import fuse_conv_bn_eval

class CrossEntropyLabelSmooth(nn.Module):
Expand Down

0 comments on commit 7854bdc

Please sign in to comment.