This is the official repository for the worst-case forget set in prompt-wise unlearning.
This repo is based on the ip2p (InstructPix2Pix) environment.
Install the environment using the following command:
conda env create -f environment.yaml
- Employ ESD with SignSGD to identify the worst-case forget set for prompt-wise unlearning. Use ckpt_path and data_path to specify the paths of the original model and the dataset.
python stable_diffusion/train-scripts/select-pair.py --train_method=xattn --SignESD --lr 1e-5 --w_lr 100.0 --devices 0,0 --output_dir results/select-pair --ckpt_path {path to the original model} --data_path {path to the dataset}
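To illustrate the kind of update SignSGD performs during selection, here is a minimal, generic sketch of one signed-gradient step on per-prompt selection weights. It is not the code in select-pair.py: the function names, the [0, 1] box constraint, the descent direction, and the top-k readout of the forget set are all assumptions made for illustration. The gradients would come from whatever selection objective the script optimizes, and the step size plays a role similar to what w_lr presumably controls (also an assumption).

```python
from typing import List, Sequence


def sign(x: float) -> float:
    """Return -1.0, 0.0, or 1.0 according to the sign of x."""
    return float((x > 0) - (x < 0))


def signsgd_step(weights: Sequence[float], grads: Sequence[float], lr: float,
                 lo: float = 0.0, hi: float = 1.0) -> List[float]:
    """One SignSGD descent step on hypothetical selection weights.

    Each weight moves by exactly lr against the sign of its gradient (the
    gradient magnitude is ignored), then is clipped back into [lo, hi].
    The box constraint is an assumption, not taken from the repository.
    """
    return [min(hi, max(lo, w - lr * sign(g))) for w, g in zip(weights, grads)]


def top_k_prompts(weights: Sequence[float], k: int) -> List[int]:
    """Indices of the k largest weights, read out as a hypothetical forget set."""
    return sorted(range(len(weights)), key=lambda i: weights[i], reverse=True)[:k]


if __name__ == "__main__":
    w = [0.5, 0.5, 0.5, 0.5]
    g = [0.3, -2.0, 0.0, -0.1]  # made-up gradients of a selection objective
    w = signsgd_step(w, g, lr=0.1)
    print(w, top_k_prompts(w, 2))
```

Using only the sign of the gradient makes every weight move by the same step size regardless of gradient scale, which is the defining property of SignSGD.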
- Employ ESD with AdamW to unlearn the worst-case forget set, then use the unlearned model to generate the associated images. Use w_path to specify the selection weights obtained in the previous step.
python stable_diffusion/train-scripts/evaluate_selection.py --train_method=xattn --lr 3e-7 --devices 0,0 --output_dir results/evaluation --ckpt_path {path to the original model} --data_path {path to the dataset} --w_path {path to the selection weights}
If random_choice is enabled, the original model instead unlearns a randomly chosen forget set.
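For readers unfamiliar with ESD, the sketch below shows the negatively guided training target from the ESD method (Erasing Concepts from Diffusion Models): the frozen original model's unconditional noise prediction is pushed away from its prompt-conditioned prediction, and the fine-tuned model's conditional prediction is regressed onto that target. This is a simplified illustration on plain lists, not the code in evaluate_selection.py; the names esd_target, esd_loss, and the guidance scale eta are assumptions. In the actual script this loss would be minimized with AdamW over the cross-attention parameters (train_method=xattn).

```python
from typing import List, Sequence


def esd_target(eps_uncond: Sequence[float], eps_cond: Sequence[float],
               eta: float = 1.0) -> List[float]:
    """ESD target from the frozen original model:
    eps(x_t, t) - eta * (eps(x_t, c, t) - eps(x_t, t))."""
    return [u - eta * (c - u) for u, c in zip(eps_uncond, eps_cond)]


def esd_loss(eps_student: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error between the fine-tuned model's conditional
    noise prediction and the ESD target."""
    return sum((s - t) ** 2 for s, t in zip(eps_student, target)) / len(target)


if __name__ == "__main__":
    # Toy noise predictions for a single latent with two elements.
    uncond = [0.0, 1.0]
    cond = [1.0, 1.0]
    target = esd_target(uncond, cond, eta=1.0)
    print(target, esd_loss(cond, target))
```

Because the target is computed from the frozen original model, only the fine-tuned copy receives gradient updates; a larger eta pushes its predictions further from the concept described by the prompt.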