diff --git a/deepr/model/diffusion_trainer.py b/deepr/model/diffusion_trainer.py index 6edb5c7..047f7c1 100644 --- a/deepr/model/diffusion_trainer.py +++ b/deepr/model/diffusion_trainer.py @@ -7,6 +7,7 @@ import torch import torch.nn.functional as F from accelerate import Accelerator +from dotenv import find_dotenv, load_dotenv from huggingface_hub import Repository from tqdm import tqdm from transformers import get_cosine_schedule_with_warmup @@ -18,6 +19,8 @@ logger = get_logger(__name__) +load_dotenv(find_dotenv()) + def train_diffusion( config: TrainingConfig,