Skip to content

Commit

Permalink
Expand Lora triggers directly into the text box on auto-complete
Browse files Browse the repository at this point in the history
- this also means if you don't auto-complete but type it out the triggers will not be inserted
  • Loading branch information
Acly committed Oct 13, 2024
1 parent ae66764 commit b7ce179
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 11 deletions.
3 changes: 2 additions & 1 deletion ai_diffusion/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def replace(match: re.Match[str]):
lora_normalized = file.name.lower()
if input == lora_filename or input == lora_normalized:
lora_file = file
break

if not lora_file:
error = _("LoRA not found") + f": {input}"
Expand All @@ -64,7 +65,7 @@ def replace(match: re.Match[str]):
raise Exception(error)

loras.append(LoraInput(lora_file.id, lora_strength))
return lora_file.meta("lora_triggers", "")
return ""

prompt = _pattern_lora.sub(replace, prompt)
return prompt.strip(), loras
Expand Down
18 changes: 9 additions & 9 deletions ai_diffusion/ui/autocomplete.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,6 @@ def _blend_colors(self, from_, to, factor):


class PromptAutoComplete:
_completer: QCompleter
_completion_prefix: str
_completion_suffix: str

def __init__(self, widget: QLineEdit):
self._widget = widget
Expand Down Expand Up @@ -235,17 +232,20 @@ def check_completion(self):
self._completer.complete(rect)

def _insert_completion(self, completion):
if not self._current_text().startswith("<lora:"):
triggers = ""
if self._current_text().startswith("<lora:"):
if file := root.files.loras.find(f"{completion}.safetensors"):
triggers = " " + file.meta("lora_triggers", "")
else: # tag completion
# escape () in tags so they won't be interpreted as prompt weights
completion = completion.replace("(", "\\(").replace(")", "\\)")
text = self._widget.text()
pos = self._widget.cursorPosition()
prefix_len = len(self._completion_prefix)
text = text[: pos - prefix_len] + completion + self._completion_suffix + text[pos:]
start_pos = pos - len(self._completion_prefix)
fill = completion + self._completion_suffix + triggers
text = text[:start_pos] + fill + text[pos:]
self._widget.setText(text)
self._widget.setCursorPosition(
pos - prefix_len + len(completion) + len(self._completion_suffix)
)
self._widget.setCursorPosition(start_pos + len(fill))

@property
def is_active(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_extract_loras_meta():
loras.set_meta(lora, "lora_triggers", "zippity")

assert extract_loras("a ship <lora:zap> zap", loras) == (
"a ship zippity zap",
"a ship zap", # triggers are inserted on auto-complete, not at extraction
[LoraInput(lora.id, 0.5)],
)

Expand Down

0 comments on commit b7ce179

Please sign in to comment.