From 17095dddea68ec228f0ef5d113a41b5f5dd599dc Mon Sep 17 00:00:00 2001
From: leejet
Date: Sun, 20 Aug 2023 20:28:36 +0800
Subject: [PATCH] feat: add token weighting support (#13)

---
 README.md            |   2 +-
 stable-diffusion.cpp | 205 +++++++++++++++++++++++++++++++++++++++++--
 2 files changed, 200 insertions(+), 7 deletions(-)

diff --git a/README.md b/README.md
index d6fc3aad..32134b0d 100644
--- a/README.md
+++ b/README.md
@@ -16,6 +16,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
 - AVX, AVX2 and AVX512 support for x86 architectures
 - Original `txt2img` and `img2img` mode
 - Negative prompt
+- [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) style tokenizer (not all the features, only token weighting for now)
 - Sampling method
     - `Euler A`
 - Supported platforms
@@ -30,7 +31,6 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
 - [ ] Make inference faster
     - The current implementation of ggml_conv_2d is slow and has high memory usage
 - [ ] Continuing to reduce memory usage (quantizing the weights of ggml_conv_2d)
-- [ ] [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) style tokenizer (eg: token weighting, ...)
 - [ ] LoRA support
 - [ ] k-quants support
 - [ ] Cross-platform reproducibility (perhaps ensuring consistency with the original SD)
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 0c392a1c..114d5523 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -355,6 +355,113 @@ class CLIPTokenizer {
     }
 };
 
+// Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/prompt_parser.py#L345
+//
+// Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
+// Accepted tokens are:
+// (abc) - increases attention to abc by a multiplier of 1.1
+// (abc:3.12) - increases attention to abc by a multiplier of 3.12
+// [abc] - decreases attention to abc by a multiplier of 1.1
+// \( - literal character '('
+// \[ - literal character '['
+// \) - literal character ')'
+// \] - literal character ']'
+// \\ - literal character '\'
+// anything else - just text
+//
+// >>> parse_prompt_attention('normal text')
+// [['normal text', 1.0]]
+// >>> parse_prompt_attention('an (important) word')
+// [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
+// >>> parse_prompt_attention('(unbalanced')
+// [['unbalanced', 1.1]]
+// >>> parse_prompt_attention('\(literal\]')
+// [['(literal]', 1.0]]
+// >>> parse_prompt_attention('(unnecessary)(parens)')
+// [['unnecessaryparens', 1.1]]
+// >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
+// [['a ', 1.0],
+//  ['house', 1.5730000000000004],
+//  [' ', 1.1],
+//  ['on', 1.0],
+//  [' a ', 1.1],
+//  ['hill', 0.55],
+//  [', sun, ', 1.1],
+//  ['sky', 1.4641000000000006],
+//  ['.', 1.1]]
+std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::string& text) {
+    std::vector<std::pair<std::string, float>> res;
+    std::vector<int> round_brackets;
+    std::vector<int> square_brackets;
+
+    float round_bracket_multiplier = 1.1f;
+    float square_bracket_multiplier = 1 / 1.1f;
+
+    std::regex re_attention(R"(\\\(|\\\)|\\\[|\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|\)|]|[^\\()\[\]:]+|:)");
+    std::regex re_break(R"(\s*\bBREAK\b\s*)");
+
+    auto multiply_range = [&](int start_position, float multiplier) {
+        for (int p = start_position; p < res.size(); ++p) {
+            res[p].second *= multiplier;
+        }
+    };
+
+    std::smatch m;
+    std::string remaining_text = text;
+
+    while (std::regex_search(remaining_text, m, re_attention)) {
+        std::string text = m[0];
+        std::string weight = m[1];
+
+        if (text == "(") {
+            round_brackets.push_back(res.size());
+        } else if (text == "[") {
+            square_brackets.push_back(res.size());
+        } else if (!weight.empty()) {
+            if (!round_brackets.empty()) {
+                multiply_range(round_brackets.back(), std::stod(weight));
+                round_brackets.pop_back();
+            }
+        } else if (text == ")" && !round_brackets.empty()) {
+            multiply_range(round_brackets.back(), round_bracket_multiplier);
+            round_brackets.pop_back();
+        } else if (text == "]" && !square_brackets.empty()) {
+            multiply_range(square_brackets.back(), square_bracket_multiplier);
+            square_brackets.pop_back();
+        } else if (text == "\\(") {
+            res.push_back({text.substr(1), 1.0f});
+        } else {
+            res.push_back({text, 1.0f});
+        }
+
+        remaining_text = m.suffix();
+    }
+
+    for (int pos : round_brackets) {
+        multiply_range(pos, round_bracket_multiplier);
+    }
+
+    for (int pos : square_brackets) {
+        multiply_range(pos, square_bracket_multiplier);
+    }
+
+    if (res.empty()) {
+        res.push_back({"", 1.0f});
+    }
+
+    int i = 0;
+    while (i + 1 < res.size()) {
+        if (res[i].second == res[i + 1].second) {
+            res[i].first += res[i + 1].first;
+            res.erase(res.begin() + i + 1);
+        } else {
+            ++i;
+        }
+    }
+
+    return res;
+}
+
 /*================================================ FrozenCLIPEmbedder ================================================*/
 
 struct ResidualAttentionBlock {
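For reference, a minimal usage sketch of parse_prompt_attention (illustrative only, not part of the patch; it assumes the function from the hunk above is in scope, and the expected values follow the rules documented in its comment):

    #include <cstdio>
    #include <string>

    int main() {
        // '(red:1.3)' raises the weight of 'red' to 1.3; '[wooden]' lowers 'wooden' to 1/1.1.
        std::string prompt = "a (red:1.3) apple on a [wooden] table";
        for (const auto& piece : parse_prompt_attention(prompt)) {
            printf("'%s' -> %.4f\n", piece.first.c_str(), piece.second);
        }
        // Expected: 'a ' -> 1.0, 'red' -> 1.3, ' apple on a ' -> 1.0, 'wooden' -> 0.9091, ' table' -> 1.0
        return 0;
    }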
@@ -639,6 +746,61 @@ struct FrozenCLIPEmbedder {
     }
 };
 
+// Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283
+struct FrozenCLIPEmbedderWithCustomWords {
+    CLIPTokenizer tokenizer;
+    CLIPTextModel text_model;
+
+    std::pair<std::vector<int>, std::vector<float>> tokenize(std::string text,
+                                                             size_t max_length = 0,
+                                                             bool padding = false) {
+        auto parsed_attention = parse_prompt_attention(text);
+
+        {
+            std::stringstream ss;
+            ss << "[";
+            for (const auto& item : parsed_attention) {
+                ss << "['" << item.first << "', " << item.second << "], ";
+            }
+            ss << "]";
+            LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
+        }
+
+        std::vector<int> tokens;
+        std::vector<float> weights;
+        for (const auto& item : parsed_attention) {
+            const std::string& curr_text = item.first;
+            float curr_weight = item.second;
+            std::vector<int> curr_tokens = tokenizer.encode(curr_text);
+            tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
+            weights.insert(weights.end(), curr_tokens.size(), curr_weight);
+        }
+        tokens.insert(tokens.begin(), BOS_TOKEN_ID);
+        weights.insert(weights.begin(), 1.0);
+
+        if (max_length > 0) {
+            if (tokens.size() > max_length - 1) {
+                tokens.resize(max_length - 1);
+                weights.resize(max_length - 1);
+            } else {
+                if (padding) {
+                    tokens.insert(tokens.end(), max_length - 1 - tokens.size(), PAD_TOKEN_ID);
+                    weights.insert(weights.end(), max_length - 1 - weights.size(), 1.0);
+                }
+            }
+        }
+        tokens.push_back(EOS_TOKEN_ID);
+        weights.push_back(1.0);
+
+        // for (int i = 0; i < tokens.size(); i++) {
+        //     std::cout << tokens[i] << ":" << weights[i] << ", ";
+        // }
+        // std::cout << std::endl;
+
+        return {tokens, weights};
+    }
+};
+
 /*==================================================== UnetModel =====================================================*/
 
 struct ResBlock {
@@ -2489,7 +2651,7 @@ class StableDiffusionGGML {
     size_t max_params_mem_size = 0;
     size_t max_rt_mem_size = 0;
 
-    FrozenCLIPEmbedder cond_stage_model;
+    FrozenCLIPEmbedderWithCustomWords cond_stage_model;
     UNetModel diffusion_model;
     AutoEncoderKL first_stage_model;
@@ -2784,9 +2946,11 @@ class StableDiffusionGGML {
     }
 
     ggml_tensor* get_learned_condition(ggml_context* res_ctx, const std::string& text) {
-        std::vector<int> tokens = cond_stage_model.tokenizer.tokenize(text,
-                                                                      cond_stage_model.text_model.max_position_embeddings,
-                                                                      true);
+        auto tokens_and_weights = cond_stage_model.tokenize(text,
+                                                            cond_stage_model.text_model.max_position_embeddings,
+                                                            true);
+        std::vector<int>& tokens = tokens_and_weights.first;
+        std::vector<float>& weights = tokens_and_weights.second;
         size_t ctx_size = 1 * 1024 * 1024;  // 1MB
         // calculate the amount of memory required
         {
@@ -2848,10 +3012,39 @@ class StableDiffusionGGML {
         int64_t t1 = ggml_time_ms();
         LOG_DEBUG("computing condition graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
 
-        ggml_tensor* result = ggml_dup_tensor(res_ctx, hidden_states);
-        copy_ggml_tensor(result, hidden_states);
+        ggml_tensor* result = ggml_dup_tensor(res_ctx, hidden_states);  // [N, n_token, hidden_size]
+
+        {
+            int64_t nelements = ggml_nelements(hidden_states);
+            float original_mean = 0.f;
+            float new_mean = 0.f;
+            float* vec = (float*)hidden_states->data;
+            for (int i = 0; i < nelements; i++) {
+                original_mean += vec[i] / nelements * 1.0f;
+            }
+
+            for (int i2 = 0; i2 < hidden_states->ne[2]; i2++) {
+                for (int i1 = 0; i1 < hidden_states->ne[1]; i1++) {
+                    for (int i0 = 0; i0 < hidden_states->ne[0]; i0++) {
+                        float value = ggml_tensor_get_f32(hidden_states, i0, i1, i2);
+                        value *= weights[i1];
+                        ggml_tensor_set_f32(result, value, i0, i1, i2);
+                    }
+                }
+            }
+
+            vec = (float*)result->data;
+            for (int i = 0; i < nelements; i++) {
+                new_mean += vec[i] / nelements * 1.0f;
+            }
+
+            for (int i = 0; i < nelements; i++) {
+                vec[i] = vec[i] * (original_mean / new_mean);
+            }
+        }
         // print_ggml_tensor(result);
+
         size_t rt_mem_size = ctx_size + ggml_curr_max_dynamic_size();
         if (rt_mem_size > max_rt_mem_size) {
             max_rt_mem_size = rt_mem_size;
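The last hunk scales every channel of a token's hidden state by that token's weight, then rescales the result so its overall mean matches the unweighted output (the same mean-restoring step used by the stable-diffusion-webui code referenced above). A minimal sketch of that normalization on a flat buffer (illustrative only; the function name and layout here are assumptions, not part of the patch):

    #include <cstddef>
    #include <vector>

    // hidden holds n_token * hidden_size floats laid out token-major; weights holds one value per token.
    std::vector<float> apply_token_weights(const std::vector<float>& hidden,
                                           const std::vector<float>& weights,
                                           size_t hidden_size) {
        std::vector<float> out(hidden.size());
        double original_mean = 0.0;
        double new_mean = 0.0;
        for (size_t i = 0; i < hidden.size(); i++) {
            original_mean += hidden[i] / hidden.size();
            out[i] = hidden[i] * weights[i / hidden_size];  // same weight for every channel of a token
            new_mean += out[i] / hidden.size();
        }
        for (float& v : out) {
            v = static_cast<float>(v * (original_mean / new_mean));  // restore the original mean
        }
        return out;
    }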