Skip to content

Commit

Permalink
add dynamic
Browse files Browse the repository at this point in the history
  • Loading branch information
um1 committed Dec 28, 2023
1 parent c1da1fc commit 55ce1cd
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def get_id(img_path):
if torch.cuda.get_device_capability()[0]>6: # should be >=7
print("Compiling model...")
# https://huggingface.co/docs/diffusers/main/en/optimization/torch2.0
model_structure = torch.compile(model_structure, mode="reduce-overhead", fullgraph=True) # pytorch 2.0
model_structure = torch.compile(model_structure, mode="default", dynamic=True) # pytorch 2.0

model = load_network(model_structure)

Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ def save_network(network, epoch_label):
if torch.cuda.get_device_capability()[0]>6: # should be >=7
print("Compiling model... The first epoch may be slow, which is expected!")
# https://huggingface.co/docs/diffusers/main/en/optimization/torch2.0
model = torch.compile(model, mode="reduce-overhead", fullgraph=True) # pytorch 2.0
model = torch.compile(model, mode="reduce-overhead", dynamic = True) # pytorch 2.0

optim_name = optim.SGD #apex.optimizers.FusedSGD
if opt.FSGD: # apex is needed
Expand Down

0 comments on commit 55ce1cd

Please sign in to comment.