feat: add token weighting support (#13)
leejet authored Aug 20, 2023
1 parent 7132027 commit 17095dd
Showing 2 changed files with 200 additions and 7 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -16,6 +16,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
- AVX, AVX2 and AVX512 support for x86 architectures
- Original `txt2img` and `img2img` mode
- Negative prompt
- [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) style tokenizer (not all the features, only token weighting for now)
- Sampling method
- `Euler A`
- Supported platforms
@@ -30,7 +31,6 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
- [ ] Make inference faster
- The current implementation of ggml_conv_2d is slow and has high memory usage
- [ ] Continuing to reduce memory usage (quantizing the weights of ggml_conv_2d)
- [ ] [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) style tokenizer (eg: token weighting, ...)
- [ ] LoRA support
- [ ] k-quants support
- [ ] Cross-platform reproducibility (perhaps ensuring consistency with the original SD)
205 changes: 199 additions & 6 deletions stable-diffusion.cpp
@@ -355,6 +355,113 @@ class CLIPTokenizer {
}
};

// Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/prompt_parser.py#L345
//
// Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
// Accepted tokens are:
// (abc) - increases attention to abc by a multiplier of 1.1
// (abc:3.12) - increases attention to abc by a multiplier of 3.12
// [abc] - decreases attention to abc by a multiplier of 1.1
// \( - literal character '('
// \[ - literal character '['
// \) - literal character ')'
// \] - literal character ']'
// \\ - literal character '\'
// anything else - just text
//
// >>> parse_prompt_attention('normal text')
// [['normal text', 1.0]]
// >>> parse_prompt_attention('an (important) word')
// [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
// >>> parse_prompt_attention('(unbalanced')
// [['unbalanced', 1.1]]
// >>> parse_prompt_attention('\(literal\]')
// [['(literal]', 1.0]]
// >>> parse_prompt_attention('(unnecessary)(parens)')
// [['unnecessaryparens', 1.1]]
// >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
// [['a ', 1.0],
// ['house', 1.5730000000000004],
// [' ', 1.1],
// ['on', 1.0],
// [' a ', 1.1],
// ['hill', 0.55],
// [', sun, ', 1.1],
// ['sky', 1.4641000000000006],
// ['.', 1.1]]
std::vector<std::pair<std::string, float>> parse_prompt_attention(const std::string& text) {
    std::vector<std::pair<std::string, float>> res;
    std::vector<int> round_brackets;
    std::vector<int> square_brackets;

    float round_bracket_multiplier = 1.1f;
    float square_bracket_multiplier = 1 / 1.1f;

    std::regex re_attention(R"(\\\(|\\\)|\\\[|\\]|\\\\|\\|\(|\[|:([+-]?[.\d]+)\)|\)|]|[^\\()\[\]:]+|:)");
    std::regex re_break(R"(\s*\bBREAK\b\s*)");  // currently unused: the BREAK keyword is not handled yet

    auto multiply_range = [&](int start_position, float multiplier) {
        for (int p = start_position; p < res.size(); ++p) {
            res[p].second *= multiplier;
        }
    };

    std::smatch m;
    std::string remaining_text = text;

    while (std::regex_search(remaining_text, m, re_attention)) {
        std::string text = m[0];
        std::string weight = m[1];

        if (text == "(") {
            round_brackets.push_back(res.size());
        } else if (text == "[") {
            square_brackets.push_back(res.size());
        } else if (!weight.empty()) {
            if (!round_brackets.empty()) {
                multiply_range(round_brackets.back(), std::stod(weight));
                round_brackets.pop_back();
            }
        } else if (text == ")" && !round_brackets.empty()) {
            multiply_range(round_brackets.back(), round_bracket_multiplier);
            round_brackets.pop_back();
        } else if (text == "]" && !square_brackets.empty()) {
            multiply_range(square_brackets.back(), square_bracket_multiplier);
            square_brackets.pop_back();
        } else if (!text.empty() && text[0] == '\\') {
            // escaped character ("\(", "\)", "\[", "\]", "\\"): drop the leading backslash
            res.push_back({text.substr(1), 1.0f});
        } else {
            res.push_back({text, 1.0f});
        }

        remaining_text = m.suffix();
    }

    for (int pos : round_brackets) {
        multiply_range(pos, round_bracket_multiplier);
    }

    for (int pos : square_brackets) {
        multiply_range(pos, square_bracket_multiplier);
    }

    if (res.empty()) {
        res.push_back({"", 1.0f});
    }

    int i = 0;
    while (i + 1 < res.size()) {
        if (res[i].second == res[i + 1].second) {
            res[i].first += res[i + 1].first;
            res.erase(res.begin() + i + 1);
        } else {
            ++i;
        }
    }

    return res;
}
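
// ---------------------------------------------------------------------------
// Editor's example (not part of this commit): a minimal sketch of how
// parse_prompt_attention could be exercised on its own. The helper is
// hypothetical; it assumes it is compiled after the function above and that
// <cstdio> is available in this translation unit.
static void debug_print_prompt_attention(const std::string& prompt) {
    for (const auto& piece : parse_prompt_attention(prompt)) {
        // piece.first is the text fragment, piece.second its weight multiplier
        printf("['%s', %g] ", piece.first.c_str(), piece.second);
    }
    printf("\n");
}
// debug_print_prompt_attention("a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).");
// prints: ['a ', 1] ['house', 1.573] [' ', 1.1] ['on', 1] [' a ', 1.1] ['hill', 0.55] ...
// ---------------------------------------------------------------------------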

/*================================================ FrozenCLIPEmbedder ================================================*/

struct ResidualAttentionBlock {
@@ -639,6 +746,61 @@ struct FrozenCLIPEmbedder {
}
};

// Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283
struct FrozenCLIPEmbedderWithCustomWords {
    CLIPTokenizer tokenizer;
    CLIPTextModel text_model;

    std::pair<std::vector<int>, std::vector<float>> tokenize(std::string text,
                                                             size_t max_length = 0,
                                                             bool padding = false) {
        auto parsed_attention = parse_prompt_attention(text);

        {
            std::stringstream ss;
            ss << "[";
            for (const auto& item : parsed_attention) {
                ss << "['" << item.first << "', " << item.second << "], ";
            }
            ss << "]";
            LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str());
        }

        std::vector<int> tokens;
        std::vector<float> weights;
        for (const auto& item : parsed_attention) {
            const std::string& curr_text = item.first;
            float curr_weight = item.second;
            std::vector<int> curr_tokens = tokenizer.encode(curr_text);
            tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end());
            weights.insert(weights.end(), curr_tokens.size(), curr_weight);
        }
        tokens.insert(tokens.begin(), BOS_TOKEN_ID);
        weights.insert(weights.begin(), 1.0);

        if (max_length > 0) {
            if (tokens.size() > max_length - 1) {
                tokens.resize(max_length - 1);
                weights.resize(max_length - 1);
            } else {
                if (padding) {
                    tokens.insert(tokens.end(), max_length - 1 - tokens.size(), PAD_TOKEN_ID);
                    weights.insert(weights.end(), max_length - 1 - weights.size(), 1.0);
                }
            }
        }
        tokens.push_back(EOS_TOKEN_ID);
        weights.push_back(1.0);

        // for (int i = 0; i < tokens.size(); i++) {
        //     std::cout << tokens[i] << ":" << weights[i] << ", ";
        // }
        // std::cout << std::endl;

        return {tokens, weights};
    }
};
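
// ---------------------------------------------------------------------------
// Editor's example (not part of this commit): tokenize() returns two parallel
// vectors: every token produced from a prompt fragment inherits that
// fragment's weight, while BOS/EOS and padding get weight 1.0. With
// padding = true and max_length = text_model.max_position_embeddings, both
// vectors come back with exactly max_length entries. The helper below is a
// hypothetical sketch; it assumes an already-initialized embedder (vocabulary
// loaded) and that <cstdio> is available.
static void debug_print_weighted_tokens(FrozenCLIPEmbedderWithCustomWords& clip,
                                        const std::string& prompt) {
    auto tokens_and_weights = clip.tokenize(prompt,
                                            clip.text_model.max_position_embeddings,
                                            true);
    const std::vector<int>& tokens = tokens_and_weights.first;
    const std::vector<float>& weights = tokens_and_weights.second;
    for (size_t i = 0; i < tokens.size(); i++) {
        printf("%d:%.2f ", tokens[i], weights[i]);
    }
    printf("\n");
}
// e.g. debug_print_weighted_tokens(clip, "a (red:1.3) car");
// -> every token coming from "red" is paired with 1.30, everything else with 1.00
// ---------------------------------------------------------------------------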

/*==================================================== UnetModel =====================================================*/

struct ResBlock {
@@ -2489,7 +2651,7 @@ class StableDiffusionGGML {
    size_t max_params_mem_size = 0;
    size_t max_rt_mem_size = 0;

    FrozenCLIPEmbedder cond_stage_model;
    FrozenCLIPEmbedderWithCustomWords cond_stage_model;
    UNetModel diffusion_model;
    AutoEncoderKL first_stage_model;

@@ -2784,9 +2946,11 @@ class StableDiffusionGGML {
    }

    ggml_tensor* get_learned_condition(ggml_context* res_ctx, const std::string& text) {
        std::vector<int32_t> tokens = cond_stage_model.tokenizer.tokenize(text,
                                                                          cond_stage_model.text_model.max_position_embeddings,
                                                                          true);
        auto tokens_and_weights = cond_stage_model.tokenize(text,
                                                            cond_stage_model.text_model.max_position_embeddings,
                                                            true);
        std::vector<int>& tokens = tokens_and_weights.first;
        std::vector<float>& weights = tokens_and_weights.second;
        size_t ctx_size = 1 * 1024 * 1024; // 1MB
        // calculate the amount of memory required
        {
@@ -2848,10 +3012,39 @@
        int64_t t1 = ggml_time_ms();
        LOG_DEBUG("computing condition graph completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);

        ggml_tensor* result = ggml_dup_tensor(res_ctx, hidden_states);
        copy_ggml_tensor(result, hidden_states);
        ggml_tensor* result = ggml_dup_tensor(res_ctx, hidden_states); // [N, n_token, hidden_size]

        {
            int64_t nelements = ggml_nelements(hidden_states);
            float original_mean = 0.f;
            float new_mean = 0.f;
            float* vec = (float*)hidden_states->data;
            for (int i = 0; i < nelements; i++) {
                original_mean += vec[i] / nelements * 1.0f;
            }

            for (int i2 = 0; i2 < hidden_states->ne[2]; i2++) {
                for (int i1 = 0; i1 < hidden_states->ne[1]; i1++) {
                    for (int i0 = 0; i0 < hidden_states->ne[0]; i0++) {
                        float value = ggml_tensor_get_f32(hidden_states, i0, i1, i2);
                        value *= weights[i1];
                        ggml_tensor_set_f32(result, value, i0, i1, i2);
                    }
                }
            }

            vec = (float*)result->data;
            for (int i = 0; i < nelements; i++) {
                new_mean += vec[i] / nelements * 1.0f;
            }

            for (int i = 0; i < nelements; i++) {
                vec[i] = vec[i] * (original_mean / new_mean);
            }
        }
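
        // The block above mirrors stable-diffusion-webui's behaviour: each token's
        // hidden state is scaled by its prompt weight, then the whole tensor is
        // rescaled by original_mean / new_mean so the overall magnitude of the
        // conditioning stays the same and only the relative emphasis between
        // tokens changes.
        //
        // Editor's sketch (not part of this commit) of the same idea on a flat
        // buffer; the helper is hypothetical and assumes a token-major layout of
        // n_token * hidden_size floats with one weight per token:
        //
        //     void apply_token_weights(std::vector<float>& emb,
        //                              const std::vector<float>& weights,
        //                              size_t hidden_size) {
        //         double original_sum = 0.0, new_sum = 0.0;
        //         for (float v : emb) original_sum += v;
        //         for (size_t i = 0; i < emb.size(); i++) {
        //             emb[i] *= weights[i / hidden_size];  // same weight for a token's whole vector
        //             new_sum += emb[i];
        //         }
        //         // the ratio of sums equals the ratio of means (same element count)
        //         for (float& v : emb) v *= (float)(original_sum / new_sum);
        //     }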

        // print_ggml_tensor(result);

        size_t rt_mem_size = ctx_size + ggml_curr_max_dynamic_size();
        if (rt_mem_size > max_rt_mem_size) {
            max_rt_mem_size = rt_mem_size;
