From 2955b310af9c02dfe51d9f238597753378df7dfd Mon Sep 17 00:00:00 2001 From: um1 Date: Thu, 28 Dec 2023 09:21:25 +0000 Subject: [PATCH] update fp32 --- test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test.py b/test.py index e52d1a5..a6b1c47 100644 --- a/test.py +++ b/test.py @@ -25,6 +25,7 @@ from apex.fp16_utils import * except ImportError: # will be 3.x series print('This is not an error. If you want to use low precision, i.e., fp16, please install the apex with cuda support (https://github.com/NVIDIA/apex) and update pytorch to 1.0') + ###################################################################### # Options # -------- @@ -282,6 +283,7 @@ def get_id(img_path): if torch.cuda.get_device_capability()[0]>6: # should be >=7 print("Compiling model...") # https://huggingface.co/docs/diffusers/main/en/optimization/torch2.0 + torch.set_float32_matmul_precision('high') model_structure = torch.compile(model_structure, mode="default", dynamic=True) # pytorch 2.0 model = load_network(model_structure)