Skip to content

Commit

Permalink
Added context menu to history widget:
Browse files Browse the repository at this point in the history
- copy previous (negative) prompt
- copy previous seed
- fix seed sometimes not fitting into int32
- persist negative prompt in history
- new history section when negative prompt changes
- close popup when it is clicked (outside a child widget)
  • Loading branch information
Acly committed Dec 30, 2023
1 parent bdbf52d commit 1df9afa
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 63 deletions.
40 changes: 26 additions & 14 deletions ai_diffusion/jobs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
from collections import deque
from dataclasses import dataclass, fields
from datetime import datetime
from enum import Enum, Flag
from typing import Deque, NamedTuple
Expand All @@ -25,24 +26,35 @@ class JobKind(Enum):
live_preview = 3


@dataclass
class JobParams:
bounds: Bounds
prompt: str
negative_prompt: str = ""
seed: int = 0

@classmethod
def equal_ignore_seed(cls, a: JobParams | None, b: JobParams | None):
if a is None or b is None:
return a is b
field_names = (f.name for f in fields(cls) if not f.name == "seed")
return all(getattr(a, name) == getattr(b, name) for name in field_names)


class Job:
id: str | None
kind: JobKind
state = JobState.queued
prompt: str
bounds: Bounds
seed: int
params: JobParams
control: "control.ControlLayer | None" = None
timestamp: datetime
results: ImageCollection
_in_use: dict[int, bool]

def __init__(self, id: str | None, kind: JobKind, prompt: str, bounds: Bounds, seed=0):
def __init__(self, id: str | None, kind: JobKind, params: JobParams):
self.id = id
self.kind = kind
self.prompt = prompt
self.bounds = bounds
self.seed = seed
self.params = params
self.timestamp = datetime.now()
self.results = ImageCollection()
self._in_use = {}
Expand Down Expand Up @@ -72,20 +84,20 @@ def __init__(self):
super().__init__()
self._entries = deque()

def add(self, kind: JobKind, id: str, prompt: str, bounds: Bounds, seed: int):
return self._add(Job(id, kind, prompt, bounds, seed))
def add(self, kind: JobKind, id: str, prompt: str, negative: str, bounds: Bounds, seed: int):
return self.add_job(Job(id, kind, JobParams(bounds, prompt, negative, seed)))

def add_control(self, control: "control.ControlLayer", bounds: Bounds):
job = Job(None, JobKind.control_layer, f"[Control] {control.mode.text}", bounds)
job = Job(None, JobKind.control_layer, JobParams(bounds, f"[Control] {control.mode.text}"))
job.control = control
return self._add(job)
return self.add_job(job)

def add_upscale(self, bounds: Bounds, seed: int):
name = f"[Upscale] {bounds.width}x{bounds.height}"
job = Job(None, JobKind.upscaling, name, bounds, seed)
return self._add(job)
job = Job(None, JobKind.upscaling, JobParams(bounds, name, seed=seed))
return self.add_job(job)

def _add(self, job: Job):
def add_job(self, job: Job):
self._entries.append(job)
self.count_changed.emit()
return job
Expand Down
23 changes: 12 additions & 11 deletions ai_diffusion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,11 @@ async def _generate(
client, style, image, mask, conditioning, strength, seed, is_live
)

job_kind = JobKind.live_preview if is_live else JobKind.diffusion
pos, neg = conditioning.prompt, conditioning.negative_prompt
for i in range(count):
job_id = await client.enqueue(job)
job_kind = JobKind.live_preview if is_live else JobKind.diffusion
self.jobs.add(job_kind, job_id, conditioning.prompt, bounds, job.seed)
self.jobs.add(job_kind, job_id, pos, neg, bounds, job.seed)
job.seed = seed + (i + 1) * settings.batch_size

def upscale_image(self):
Expand Down Expand Up @@ -295,16 +296,16 @@ def update_preview(self):
def show_preview(self, job_id: str, index: int, name_prefix="Preview"):
job = self.jobs.find(job_id)
assert job is not None, "Cannot show preview, invalid job id"
name = f"[{name_prefix}] {job.prompt}"
name = f"[{name_prefix}] {job.params.prompt}"
if self._layer and self._layer.parentNode() is None:
self._layer = None
if self._layer is not None:
self._layer.setName(name)
self._doc.set_layer_content(self._layer, job.results[index], job.bounds)
self._doc.set_layer_content(self._layer, job.results[index], job.params.bounds)
self._doc.move_to_top(self._layer)
else:
self._layer = self._doc.insert_layer(
name, job.results[index], job.bounds, make_active=False
name, job.results[index], job.params.bounds, make_active=False
)
self._layer.setLocked(True)

Expand All @@ -326,10 +327,10 @@ def add_control_layer(self, job: Job, result: dict | None):
assert job.kind is JobKind.control_layer and job.control
if job.control.mode is ControlMode.pose and result is not None:
pose = Pose.from_open_pose_json(result)
pose.scale(job.bounds.extent)
return self._doc.insert_vector_layer(job.prompt, pose.to_svg())
pose.scale(job.params.bounds.extent)
return self._doc.insert_vector_layer(job.params.prompt, pose.to_svg())
elif len(job.results) > 0:
return self._doc.insert_layer(job.prompt, job.results[0], job.bounds)
return self._doc.insert_layer(job.params.prompt, job.results[0], job.params.bounds)
return self.document.active_layer # Execution was cached and no image was produced

def add_upscale_layer(self, job: Job):
Expand All @@ -338,9 +339,9 @@ def add_upscale_layer(self, job: Job):
if self._layer:
self._layer.remove()
self._layer = None
self._doc.resize(job.bounds.extent)
self._doc.resize(job.params.bounds.extent)
self.upscale.target_extent_changed.emit(self.upscale.target_extent)
self._doc.insert_layer(job.prompt, job.results[0], job.bounds)
self._doc.insert_layer(job.params.prompt, job.results[0], job.params.bounds)

def set_workspace(self, workspace: Workspace):
if self.workspace is Workspace.live:
Expand Down Expand Up @@ -455,7 +456,7 @@ def toggle(self, active: bool):
def handle_job_finished(self, job: Job):
if job.kind is JobKind.live_preview:
if len(job.results) > 0:
self.set_result(job.results[0], job.bounds)
self.set_result(job.results[0], job.params.bounds)
self.is_active = self._is_active and self._model.document.is_active
if self.is_active:
self._model.generate_live()
Expand Down
17 changes: 6 additions & 11 deletions ai_diffusion/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .image import Bounds, Image, ImageCollection, ImageFileFormat
from .model import Model
from .control import ControlLayer
from .jobs import Job, JobKind
from .jobs import Job, JobKind, JobParams
from .style import Style, Styles
from .properties import serialize, deserialize
from .settings import settings
Expand All @@ -22,15 +22,14 @@
@dataclass
class _HistoryResult:
id: str
prompt: str
bounds: Bounds
slot: int # annotation slot where images are stored
offsets: list[int] # offsets in bytes for result images
seed: int = 0
params: JobParams

@staticmethod
def from_dict(data: dict[str, Any]):
data["bounds"] = Bounds(*data["bounds"])
data["params"]["bounds"] = Bounds(*data["params"]["bounds"])
data["params"] = JobParams(**data["params"])
return _HistoryResult(**data)


Expand Down Expand Up @@ -80,9 +79,7 @@ def _load(self, model: Model, state_bytes: bytes):
for result in state.get("history", []):
item = _HistoryResult.from_dict(result)
if images_bytes := model.document.find_annotation(f"result{item.slot}"):
job = model.jobs.add(
JobKind.diffusion, item.id, item.prompt, item.bounds, item.seed
)
job = model.jobs.add_job(Job(item.id, JobKind.diffusion, item.params))
results = _deserialize_images(images_bytes, item.offsets, item.slot)
model.jobs.set_results(job, results)
model.jobs.notify_finished(job)
Expand Down Expand Up @@ -110,9 +107,7 @@ def _save_results(self, job: Job):
self._slot_index += 1
image_data, image_offsets = _serialize_images(job.results)
self._model.document.annotate(f"result{slot}", image_data)
self._history.append(
_HistoryResult(job.id or "", job.prompt, job.bounds, slot, image_offsets, job.seed)
)
self._history.append(_HistoryResult(job.id or "", slot, image_offsets, job.params))
self._memory_used[slot] = image_data.size()
self._prune()
self._save()
Expand Down
77 changes: 53 additions & 24 deletions ai_diffusion/ui/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
QListWidgetItem,
QListView,
QSizePolicy,
QMenu,
)

from ..properties import Binding, bind, Bind
from ..image import Bounds, Extent, Image
from ..jobs import Job, JobQueue, JobState, JobKind
from ..jobs import Job, JobQueue, JobState, JobKind, JobParams
from ..model import Model
from ..root import root
from ..settings import settings
Expand All @@ -34,10 +35,9 @@


class HistoryWidget(QListWidget):
_jobs: JobQueue
_model: Model
_connections: list[QMetaObject.Connection]
_last_prompt: str | None = None
_last_bounds: Bounds | None = None
_last_job_params: JobParams | None = None

item_activated = pyqtSignal(QListWidgetItem)

Expand All @@ -60,7 +60,7 @@ class HistoryWidget(QListWidget):

def __init__(self, parent):
super().__init__(parent)
self._jobs = JobQueue()
self._model = root.active_model
self._connections = []

self.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding)
Expand All @@ -81,14 +81,18 @@ def __init__(self, parent):
if scrollbar := self.verticalScrollBar():
scrollbar.valueChanged.connect(self.update_apply_button)

self.setContextMenuPolicy(Qt.ContextMenuPolicy.CustomContextMenu)
self.customContextMenuRequested.connect(self._show_context_menu)

@property
def jobs(self):
return self._jobs
def model_(self):
return self._model

@jobs.setter
def jobs(self, jobs: JobQueue):
@model_.setter
def model_(self, model: Model):
Binding.disconnect_all(self._connections)
self._jobs = jobs
self._model = model
jobs = model.jobs
self._connections = [
jobs.selection_changed.connect(self.update_selection),
self.itemSelectionChanged.connect(self.select_item),
Expand All @@ -108,15 +112,14 @@ def add(self, job: Job):
scrollbar and scrollbar.isVisible() and scrollbar.value() >= scrollbar.maximum() - 4
)

if self._last_prompt != job.prompt or self._last_bounds != job.bounds:
self._last_prompt = job.prompt
self._last_bounds = job.bounds
prompt = job.prompt if job.prompt != "" else "<no prompt>"
if not JobParams.equal_ignore_seed(self._last_job_params, job.params):
self._last_job_params = job.params
prompt = job.params.prompt if job.params.prompt != "" else "<no prompt>"

header = QListWidgetItem(f"{job.timestamp:%H:%M} - {prompt}")
header.setFlags(Qt.ItemFlag.NoItemFlags)
header.setData(Qt.ItemDataRole.UserRole, job.id)
header.setData(Qt.ItemDataRole.ToolTipRole, job.prompt)
header.setData(Qt.ItemDataRole.ToolTipRole, job.params.prompt)
header.setSizeHint(QSize(9999, self.fontMetrics().lineSpacing() + 4))
header.setTextAlignment(Qt.AlignmentFlag.AlignLeft)
self.addItem(header)
Expand All @@ -127,7 +130,7 @@ def add(self, job: Job):
item.setData(Qt.ItemDataRole.UserRole + 1, i)
item.setData(
Qt.ItemDataRole.ToolTipRole,
f"{job.prompt}\nClick to toggle preview, double-click to apply.",
f"{job.params.prompt}\nClick to toggle preview, double-click to apply.",
)
self.addItem(item)

Expand All @@ -146,12 +149,12 @@ def remove(self, job: Job):
item = self.item(i)
break
if item_was_selected:
self._jobs.selection = None
self._model.jobs.selection = None
else:
self.update_apply_button() # selection may have moved

def update_selection(self):
selection = self._jobs.selection
selection = self._model.jobs.selection
if selection is None:
self.clearSelection()
elif selection:
Expand All @@ -178,15 +181,15 @@ def update_apply_button(self):

def update_image_thumbnail(self, id: JobQueue.Item):
if item := self._find(id):
job = ensure(self._jobs.find(id.job))
job = ensure(self._model.jobs.find(id.job))
item.setIcon(self._image_thumbnail(job, id.image))

def select_item(self):
items = self.selectedItems()
if len(items) > 0:
self._jobs.selection = self._item_data(items[0])
self._model.jobs.selection = self._item_data(items[0])
else:
self._jobs.selection = None
self._model.jobs.selection = None

def _activate_selection(self):
items = self.selectedItems()
Expand All @@ -198,12 +201,20 @@ def is_finished(self, job: Job):

def rebuild(self):
self.clear()
for job in filter(self.is_finished, self._jobs):
for job in filter(self.is_finished, self._model.jobs):
self.add(job)

def item_info(self, item: QListWidgetItem) -> tuple[str, int]: # job id, image index
return item.data(Qt.ItemDataRole.UserRole), item.data(Qt.ItemDataRole.UserRole + 1)

@property
def selected_job(self) -> Job | None:
items = self.selectedItems()
if len(items) > 0:
job_id, _ = self.item_info(items[0])
return self._model.jobs.find(job_id)
return None

def handle_preview_click(self, item: QListWidgetItem):
if item.text() != "" and item.text() != "<no prompt>":
if clipboard := QGuiApplication.clipboard():
Expand All @@ -212,7 +223,7 @@ def handle_preview_click(self, item: QListWidgetItem):

def mousePressEvent(self, e: QMouseEvent | None) -> None:
# make single click deselect current item (usually requires Ctrl+click)
if e is not None:
if e is not None and e.button() == Qt.MouseButton.LeftButton:
mods = e.modifiers()
mods |= Qt.KeyboardModifier.ControlModifier
e = QMouseEvent(
Expand Down Expand Up @@ -250,6 +261,24 @@ def _image_thumbnail(self, job: Job, index: int):
thumb.draw_image(self._applied_icon, offset=(-28, 4))
return thumb.to_icon()

def _show_context_menu(self, pos: QPoint):
item = self.itemAt(pos)
if item is not None:
menu = QMenu(self)
menu.addAction("Copy Prompt", self._copy_prompt)
menu.addAction("Copy Seed", self._copy_seed)
menu.exec(self.mapToGlobal(pos))

def _copy_prompt(self):
if job := self.selected_job:
self._model.prompt = job.params.prompt
self._model.negative_prompt = job.params.negative_prompt

def _copy_seed(self):
if job := self.selected_job:
self._model.fixed_seed = True
self._model.seed = job.params.seed


class GenerationWidget(QWidget):
_model: Model
Expand Down Expand Up @@ -347,7 +376,7 @@ def model(self, model: Model):
]
self.control_list.model = model
self.queue_button.model = model
self.history.jobs = model.jobs
self.history.model_ = model

def update_progress(self):
self.progress_bar.setValue(int(self.model.progress * 100))
Expand Down
Loading

0 comments on commit 1df9afa

Please sign in to comment.