diff --git a/ai_diffusion/connection.py b/ai_diffusion/connection.py index 29dc9ac24..7a18fa734 100644 --- a/ai_diffusion/connection.py +++ b/ai_diffusion/connection.py @@ -7,7 +7,7 @@ from .client import Client, ClientMessage, ClientEvent, DeviceInfo, MissingResource from .network import NetworkError from .settings import Settings, PerformancePreset, settings -from .properties import Property, PropertyMeta +from .properties import Property, ObservableProperties from . import util, eventloop @@ -18,7 +18,7 @@ class ConnectionState(Enum): error = 3 -class Connection(QObject, metaclass=PropertyMeta): +class Connection(QObject, ObservableProperties): state = Property(ConnectionState.disconnected) error = Property("") missing_resource: MissingResource | None = None diff --git a/ai_diffusion/control.py b/ai_diffusion/control.py index 40492aa75..396c51b95 100644 --- a/ai_diffusion/control.py +++ b/ai_diffusion/control.py @@ -5,12 +5,12 @@ from .settings import settings from .resources import ControlMode from .client import resolve_sd_version -from .properties import Property, PropertyMeta +from .properties import Property, ObservableProperties from .image import Bounds from .workflow import Control -class ControlLayer(QObject, metaclass=PropertyMeta): +class ControlLayer(QObject, ObservableProperties): mode = Property(ControlMode.image, persist=True) layer_id = Property(QUuid(), persist=True) strength = Property(100, persist=True) diff --git a/ai_diffusion/model.py b/ai_diffusion/model.py index 5f10bba50..160bfc165 100644 --- a/ai_diffusion/model.py +++ b/ai_diffusion/model.py @@ -1,21 +1,20 @@ from __future__ import annotations -import asyncio import random from enum import Enum -from typing import NamedTuple, cast +from typing import NamedTuple from PyQt5.QtCore import QObject, pyqtSignal from . import eventloop, workflow, util from .settings import settings from .network import NetworkError -from .image import Extent, Image, ImageCollection, Mask, Bounds +from .image import Extent, Image, Mask, Bounds from .client import ClientMessage, ClientEvent, filter_supported_styles, resolve_sd_version from .document import Document, LayerObserver from .pose import Pose from .style import Style, Styles, SDVersion from .workflow import ControlMode, Conditioning, LiveParams -from .connection import Connection, ConnectionState -from .properties import Property, PropertyMeta +from .connection import Connection +from .properties import Property, ObservableProperties from .jobs import Job, JobKind, JobQueue, JobState from .control import ControlLayer, ControlLayerList import krita @@ -27,7 +26,7 @@ class Workspace(Enum): live = 2 -class Model(QObject, metaclass=PropertyMeta): +class Model(QObject, ObservableProperties): """Represents diffusion workflows for a specific Krita document. Stores all inputs related to image generation. Launches generation jobs. Listens to server messages and keeps a list of finished, currently running and enqueued jobs. @@ -364,7 +363,7 @@ class UpscaleParams(NamedTuple): target_extent: Extent -class UpscaleWorkspace(QObject, metaclass=PropertyMeta): +class UpscaleWorkspace(QObject, ObservableProperties): upscaler = Property("", persist=True) factor = Property(2.0, persist=True) use_diffusion = Property(True, persist=True) @@ -401,7 +400,7 @@ def params(self): ) -class LiveWorkspace(QObject, metaclass=PropertyMeta): +class LiveWorkspace(QObject, ObservableProperties): is_active = Property(False, setter="toggle") strength = Property(0.3, persist=True) seed = Property(0, persist=True) diff --git a/ai_diffusion/properties.py b/ai_diffusion/properties.py index 3686509ef..40a07a8d1 100644 --- a/ai_diffusion/properties.py +++ b/ai_diffusion/properties.py @@ -1,31 +1,30 @@ from enum import Enum from typing import Any, NamedTuple, Sequence, TypeVar, Generic -from PyQt5.QtCore import QObject, QMetaObject, QUuid, pyqtBoundSignal, pyqtProperty # type: ignore +from PyQt5.QtCore import QObject, QMetaObject, QUuid, pyqtBoundSignal from PyQt5.QtWidgets import QComboBox T = TypeVar("T") -class PropertyMeta(type(QObject)): - """Provides default implementations for properties (get, set, signal).""" +class ObservableProperties: + """Provides default implementations for properties (get, set, signal) to sub-classes.""" - def __new__(cls, name, bases, attrs): - for key in list(attrs.keys()): - attr = attrs[key] - if not isinstance(attr, Property): - continue + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) - attrs[f"_{key}"] = attr.default_value + properties = { + name: attr for name, attr in cls.__dict__.items() if isinstance(attr, Property) + } + for name, property in properties.items(): + setattr(cls, f"_{name}", property.default_value) getter, setter = None, None - if attr.getter is not None: - getter = attrs[attr.getter] - if attr.setter is not None: - setter = attrs[attr.setter] - attrs[key] = PropertyImpl(key, getter, setter, attr.persist) - - return super().__new__(cls, name, bases, attrs) + if property.getter is not None: + getter = getattr(cls, property.getter) + if property.setter is not None: + setter = getattr(cls, property.setter) + setattr(cls, name, PropertyImpl(name, getter, setter, property.persist)) class Property(Generic[T]): @@ -48,7 +47,7 @@ def __delete__(self, instance): ... class PropertyImpl(property): - """Property implementation: gets, sets, and notifies of change.""" + """Property implementation: gets/sets a value, and emits a signal when it changes.""" name: str persist: bool diff --git a/tests/test_properties.py b/tests/test_properties.py index e21aefc5e..78dbc9ef0 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -1,8 +1,8 @@ from enum import Enum import pytest -from PyQt5.QtCore import QObject, pyqtBoundSignal, pyqtSignal +from PyQt5.QtCore import QObject, pyqtSignal -from ai_diffusion.properties import Property, PropertyMeta, bind, serialize, deserialize +from ai_diffusion.properties import Property, ObservableProperties, bind, serialize, deserialize class Piong(Enum): @@ -10,7 +10,7 @@ class Piong(Enum): b = 2 -class ObjectWithProperties(QObject, metaclass=PropertyMeta): +class ObjectWithProperties(QObject, ObservableProperties): inty = Property(0) stringy = Property("") enumy = Property(Piong.a) @@ -73,6 +73,15 @@ def callback(x): assert called == [42, "hello", Piong.b, 5, 55] +def test_multiple(): + a = ObjectWithProperties() + b = ObjectWithProperties() + + a.inty = 5 + b.inty = 99 + assert a.inty != b.inty + + def test_bind(): a = ObjectWithProperties() b = ObjectWithProperties() @@ -112,7 +121,7 @@ def test_bind_qt_to_property(): assert a.qtstyle() == 99 -class PersistentObject(QObject, metaclass=PropertyMeta): +class PersistentObject(QObject, ObservableProperties): inty = Property(0, persist=True) stringy = Property("", persist=True) enumy = Property(Piong.a, persist=True) diff --git a/tests/test_workflow.py b/tests/test_workflow.py index fd232f0f2..ca0c86551 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -403,7 +403,8 @@ def test_create_control_image(qtapp, comfy, mode): async def main(): result = await run_and_save(comfy, job, image_name) reference = Image.load(reference_dir / image_name) - assert Image.compare(result, reference) < 0.002 + threshold = 0.005 if mode is ControlMode.pose else 0.002 + assert Image.compare(result, reference) < threshold qtapp.run(main())