From f09824fd311c147f933afb2b09a77ec3c9b776b2 Mon Sep 17 00:00:00 2001 From: gesen2egee Date: Mon, 7 Oct 2024 18:01:00 +0800 Subject: [PATCH] Update train_network.py --- train_network.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/train_network.py b/train_network.py index 044ec3aa8..bce1109f9 100644 --- a/train_network.py +++ b/train_network.py @@ -205,6 +205,9 @@ def train(self, args): ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + if args.no_token_padding: + train_dataset_group.disable_token_padding() + if args.debug_dataset: train_util.debug_dataset(train_dataset_group) return @@ -1162,6 +1165,11 @@ def setup_parser() -> argparse.ArgumentParser: nargs="*", help="additional arguments for network (key=value) / ネットワークへの追加の引数", ) + parser.add_argument( + "--no_token_padding", + action="store_true", + help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)", + ) parser.add_argument( "--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する" )