Skip to content

Commit

Permalink
update fp32
Browse files Browse the repository at this point in the history
  • Loading branch information
um1 committed Dec 28, 2023
1 parent 55ce1cd commit 2955b31
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from apex.fp16_utils import *
except ImportError: # will be 3.x series
print('This is not an error. If you want to use low precision, i.e., fp16, please install the apex with cuda support (https://github.com/NVIDIA/apex) and update pytorch to 1.0')

######################################################################
# Options
# --------
Expand Down Expand Up @@ -282,6 +283,7 @@ def get_id(img_path):
if torch.cuda.get_device_capability()[0]>6: # should be >=7
print("Compiling model...")
# https://huggingface.co/docs/diffusers/main/en/optimization/torch2.0
torch.set_float32_matmul_precision('high')
model_structure = torch.compile(model_structure, mode="default", dynamic=True) # pytorch 2.0

model = load_network(model_structure)
Expand Down

0 comments on commit 2955b31

Please sign in to comment.