Skip to content

Commit

Permalink
Refactor the generation loop
Browse files Browse the repository at this point in the history
  • Loading branch information
modelflat committed Nov 26, 2024
1 parent 374b961 commit eebb0a1
Showing 1 changed file with 19 additions and 24 deletions.
43 changes: 19 additions & 24 deletions ai_diffusion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dataclasses import replace
from pathlib import Path
from enum import Enum
import time
from typing import Any, NamedTuple
from PyQt5.QtCore import QObject, QUuid, pyqtSignal, Qt
from PyQt5.QtGui import QImage, QPainter, QColor, QBrush
Expand Down Expand Up @@ -295,9 +296,10 @@ def estimate_cost(self, kind=JobKind.diffusion):
return 0

def generate_live(self):
eventloop.run(_report_errors(self, self._generate_live()))
input, job_params = self._prepare_live_workflow()
eventloop.run(_report_errors(self, self._generate_live(input, job_params)))

def _prepare_live_job_params(self):
def _prepare_live_workflow(self):
strength = self.live.strength
workflow_kind = WorkflowKind.generate if strength == 1.0 else WorkflowKind.refine
client = self._connection.client
Expand Down Expand Up @@ -349,13 +351,9 @@ def _prepare_live_job_params(self):
params = JobParams(bounds, conditioning.positive, regions=job_regions)
return input, params

async def _generate_live(self, last_input: WorkflowInput | None = None):
input, job_params = self._prepare_live_job_params()
if input != last_input:
self.clear_error()
await self.enqueue_jobs(input, JobKind.live_preview, job_params)
return input
return None
async def _generate_live(self, input: WorkflowInput, job_params: JobParams):
self.clear_error()
await self.enqueue_jobs(input, JobKind.live_preview, job_params)

async def _generate_custom(self, previous_input: WorkflowInput | None):
if self.workspace is not Workspace.custom or not self.document.is_active:
Expand Down Expand Up @@ -829,6 +827,7 @@ class LiveWorkspace(QObject, ObservableProperties):

_model: Model
_last_input: WorkflowInput | None = None
_last_change: float = 0
_result: Image | None = None
_result_composition: Image | None = None
_result_params: JobParams | None = None
Expand Down Expand Up @@ -875,21 +874,17 @@ def handle_job_finished(self, job: Job):
eventloop.run(_report_errors(self._model, self._continue_generating()))

async def _continue_generating(self):
just_got_here = True
while self.is_active and self._model.document.is_active:
new_input, _ = self._model._prepare_live_job_params()
if self._last_input != new_input:
if settings.live_redraw_grace_period > 0 and not just_got_here:
# only use grace period if this isn't our first frame of polling
# if it is, and there are changes in the input, it's likely that we have some changes we ignored
# previously due to the generation process running, and we need to update the preview asap
await asyncio.sleep(settings.live_redraw_grace_period)
new_input = await self._model._generate_live(self._last_input)
if new_input is not None:
self._last_input = new_input
return
# no changes in input data
just_got_here = False
while self.is_active:
if self._model.document.is_active:
new_input, job_params = self._model._prepare_live_workflow()
if self._last_input != new_input:
now = time.monotonic()
if self._last_change + settings.live_redraw_grace_period <= now:
await self._model._generate_live(new_input, job_params)
self._last_input = new_input
return
else:
self._last_change = time.monotonic()
await asyncio.sleep(self._poll_rate)

def apply_result(self, layer_only=False):
Expand Down

0 comments on commit eebb0a1

Please sign in to comment.