Commit

Merge pull request #580 from kohya-ss/dev
fix clip skip not working in weighted caption training and sample gen
kohya-ss authored Jun 8, 2023
2 parents 363f1df + 8088c04 commit 6417f5d
Showing 3 changed files with 17 additions and 12 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -140,6 +140,11 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
 
 ## Change History
 
+### 8 Jun. 2023, 2023/06/08
+
+- Fixed a bug where clip skip did not work when training with weighted captions (`--weighted_captions` specified) and when generating sample images during training.
+- 重みづけキャプションでの学習時(`--weighted_captions`指定時)および学習中のサンプル画像生成時にclip skipが機能しない不具合を修正しました。
+
 ### 6 Jun. 2023, 2023/06/06
 
 - Fix `train_network.py` to probably work with older versions of LyCORIS.
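The changelog entry above concerns how `clip_skip` selects which text-encoder layer is used for conditioning. As a minimal illustration (a hedged sketch, not this repository's API: `select_hidden_state` and the layer list are made up for demonstration), `clip_skip` of `None` or `1` keeps the final hidden state, while larger values index further back, mirroring the `hidden_states[-clip_skip]` expression in the patch:

```python
def select_hidden_state(hidden_states, clip_skip=None):
    """Pick the text-encoder layer used for conditioning.

    clip_skip=None or 1 keeps the final hidden state; larger values
    step back through earlier layers via negative indexing.
    """
    if clip_skip is None:
        clip_skip = 1
    return hidden_states[-clip_skip]

# Hypothetical stand-in for a list of per-layer hidden states.
layers = ["layer_10", "layer_11", "final"]
print(select_hidden_state(layers))               # final layer
print(select_hidden_state(layers, clip_skip=2))  # one layer back
```

Note that in the real code, when an earlier layer is selected the encoder's `final_layer_norm` must still be applied to it, which is exactly what both patched files do.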
12 changes: 6 additions & 6 deletions library/custom_train_functions.py
@@ -265,11 +265,6 @@ def get_unweighted_text_embeddings(
                 text_embedding = enc_out["hidden_states"][-clip_skip]
                 text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
 
-            # cover the head and the tail by the starting and the ending tokens
-            text_input_chunk[:, 0] = text_input[0, 0]
-            text_input_chunk[:, -1] = text_input[0, -1]
-            text_embedding = text_encoder(text_input_chunk, attention_mask=None)[0]
-
             if no_boseos_middle:
                 if i == 0:
                     # discard the ending token
@@ -284,7 +279,12 @@ def get_unweighted_text_embeddings(
             text_embeddings.append(text_embedding)
         text_embeddings = torch.concat(text_embeddings, axis=1)
     else:
-        text_embeddings = text_encoder(text_input)[0]
+        if clip_skip is None or clip_skip == 1:
+            text_embeddings = text_encoder(text_input)[0]
+        else:
+            enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
+            text_embeddings = enc_out["hidden_states"][-clip_skip]
+            text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
     return text_embeddings
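The patched else-branch (the single-chunk path that previously ignored `clip_skip`) can be exercised with a stub in place of the real CLIP text encoder. This is a hedged sketch under stated assumptions: `StubEncoder` and `encode` are invented names, and strings stand in for hidden-state tensors; only the branching logic matches the patch.

```python
class StubEncoder:
    """Toy stand-in for a CLIP text encoder; strings replace tensors."""

    def __init__(self, num_layers=4):
        self.num_layers = num_layers

    def __call__(self, tokens, output_hidden_states=False, return_dict=False):
        # One "embedding" per layer, plus the input embedding at index 0.
        hidden = [f"h{i}:{tokens}" for i in range(self.num_layers + 1)]
        if output_hidden_states and return_dict:
            return {"hidden_states": hidden}
        return [hidden[-1]]  # index [0] of the output is the last hidden state

    def final_layer_norm(self, x):
        return f"norm({x})"

def encode(encoder, tokens, clip_skip=None):
    # Mirrors the patched branch: clip_skip of None or 1 keeps the old
    # path; deeper values pull an earlier layer and re-apply the norm.
    if clip_skip is None or clip_skip == 1:
        return encoder(tokens)[0]
    out = encoder(tokens, output_hidden_states=True, return_dict=True)
    emb = out["hidden_states"][-clip_skip]
    return encoder.final_layer_norm(emb)

print(encode(StubEncoder(), "tok"))               # last layer, no norm re-applied
print(encode(StubEncoder(), "tok", clip_skip=2))  # earlier layer, norm re-applied
```

The design point the fix restores: skipping a layer is not just negative indexing — the final layer norm must be re-applied to the earlier hidden state, or the embeddings land in the wrong scale.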
12 changes: 6 additions & 6 deletions library/lpw_stable_diffusion.py
@@ -245,11 +245,6 @@ def get_unweighted_text_embeddings(
                 text_embedding = enc_out["hidden_states"][-clip_skip]
                 text_embedding = pipe.text_encoder.text_model.final_layer_norm(text_embedding)
 
-            # cover the head and the tail by the starting and the ending tokens
-            text_input_chunk[:, 0] = text_input[0, 0]
-            text_input_chunk[:, -1] = text_input[0, -1]
-            text_embedding = pipe.text_encoder(text_input_chunk, attention_mask=None)[0]
-
             if no_boseos_middle:
                 if i == 0:
                     # discard the ending token
@@ -264,7 +259,12 @@ def get_unweighted_text_embeddings(
             text_embeddings.append(text_embedding)
         text_embeddings = torch.concat(text_embeddings, axis=1)
     else:
-        text_embeddings = pipe.text_encoder(text_input)[0]
+        if clip_skip is None or clip_skip == 1:
+            text_embeddings = pipe.text_encoder(text_input)[0]
+        else:
+            enc_out = pipe.text_encoder(text_input, output_hidden_states=True, return_dict=True)
+            text_embeddings = enc_out["hidden_states"][-clip_skip]
+            text_embeddings = pipe.text_encoder.text_model.final_layer_norm(text_embeddings)
     return text_embeddings
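The deleted duplicate lines in both files referred to "cover the head and the tail by the starting and the ending tokens": when a long prompt is split into fixed-size windows, each chunk is re-framed with the BOS and EOS token ids before encoding. A hedged sketch of that idea (`frame_chunks` and the token ids are made up; the real CLIP window is 77 tokens, not 5):

```python
BOS, EOS = 49406, 49407  # illustrative ids, matching CLIP's vocabulary

def frame_chunks(token_ids, max_len=5):
    """Split token_ids into windows, framing each with BOS/EOS."""
    body = max_len - 2  # room left in each window after BOS and EOS
    chunks = []
    for i in range(0, len(token_ids), body):
        chunks.append([BOS] + token_ids[i : i + body] + [EOS])
    return chunks

print(frame_chunks([1, 2, 3, 4, 5], max_len=5))
# [[49406, 1, 2, 3, 49407], [49406, 4, 5, 49407]]
```

In the patch this framing is not removed from the pipeline — the duplicated copy that re-encoded the chunk without clip skip is deleted, so the clip-skip-aware encoding just above it is no longer overwritten.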
