diff --git a/README.md b/README.md index cb3803f03..8234a89e4 100644 --- a/README.md +++ b/README.md @@ -140,6 +140,11 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +### 8 Jun. 2023, 2023/06/08 + +- Fixed a bug where clip skip did not work when training with weighted captions (`--weighted_captions` specified) and when generating sample images during training. +- 重みづけキャプションでの学習時(`--weighted_captions`指定時)および学習中のサンプル画像生成時にclip skipが機能しない不具合を修正しました。 + ### 6 Jun. 2023, 2023/06/06 - Fix `train_network.py` to probably work with older versions of LyCORIS. diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 0cf0d1e23..8b44874b9 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -265,11 +265,6 @@ def get_unweighted_text_embeddings( text_embedding = enc_out["hidden_states"][-clip_skip] text_embedding = text_encoder.text_model.final_layer_norm(text_embedding) - # cover the head and the tail by the starting and the ending tokens - text_input_chunk[:, 0] = text_input[0, 0] - text_input_chunk[:, -1] = text_input[0, -1] - text_embedding = text_encoder(text_input_chunk, attention_mask=None)[0] - if no_boseos_middle: if i == 0: # discard the ending token @@ -284,7 +279,12 @@ def get_unweighted_text_embeddings( text_embeddings.append(text_embedding) text_embeddings = torch.concat(text_embeddings, axis=1) else: - text_embeddings = text_encoder(text_input)[0] + if clip_skip is None or clip_skip == 1: + text_embeddings = text_encoder(text_input)[0] + else: + enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True) + text_embeddings = enc_out["hidden_states"][-clip_skip] + text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings) return text_embeddings diff --git a/library/lpw_stable_diffusion.py b/library/lpw_stable_diffusion.py index 3e04b8876..58b1171e1 100644 --- a/library/lpw_stable_diffusion.py +++ b/library/lpw_stable_diffusion.py @@ -245,11 +245,6 @@ def get_unweighted_text_embeddings( text_embedding = enc_out["hidden_states"][-clip_skip] text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding) - # cover the head and the tail by the starting and the ending tokens - text_input_chunk[:, 0] = text_input[0, 0] - text_input_chunk[:, -1] = text_input[0, -1] - text_embedding = pipe.text_encoder(text_input_chunk, attention_mask=None)[0] - if no_boseos_middle: if i == 0: # discard the ending token @@ -264,7 +259,12 @@ def get_unweighted_text_embeddings( text_embeddings.append(text_embedding) text_embeddings = torch.concat(text_embeddings, axis=1) else: - text_embeddings = pipe.text_encoder(text_input)[0] + if clip_skip is None or clip_skip == 1: + text_embeddings = pipe.text_encoder(text_input)[0] + else: + enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True) + text_embeddings = enc_out["hidden_states"][-clip_skip] + text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings) return text_embeddings