Skip to content

Commit

Permalink
Use __init_subclass__ instead of metaclass
Browse files Browse the repository at this point in the history
it's a bit simpler and doesn't confuse type checkers as much
  • Loading branch information
Acly committed Dec 28, 2023
1 parent 7a35beb commit 15822cb
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 34 deletions.
4 changes: 2 additions & 2 deletions ai_diffusion/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .client import Client, ClientMessage, ClientEvent, DeviceInfo, MissingResource
from .network import NetworkError
from .settings import Settings, PerformancePreset, settings
from .properties import Property, PropertyMeta
from .properties import Property, ObservableProperties
from . import util, eventloop


Expand All @@ -18,7 +18,7 @@ class ConnectionState(Enum):
error = 3


class Connection(QObject, metaclass=PropertyMeta):
class Connection(QObject, ObservableProperties):
state = Property(ConnectionState.disconnected)
error = Property("")
missing_resource: MissingResource | None = None
Expand Down
4 changes: 2 additions & 2 deletions ai_diffusion/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from .settings import settings
from .resources import ControlMode
from .client import resolve_sd_version
from .properties import Property, PropertyMeta
from .properties import Property, ObservableProperties
from .image import Bounds
from .workflow import Control


class ControlLayer(QObject, metaclass=PropertyMeta):
class ControlLayer(QObject, ObservableProperties):
mode = Property(ControlMode.image, persist=True)
layer_id = Property(QUuid(), persist=True)
strength = Property(100, persist=True)
Expand Down
15 changes: 7 additions & 8 deletions ai_diffusion/model.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
from __future__ import annotations
import asyncio
import random
from enum import Enum
from typing import NamedTuple, cast
from typing import NamedTuple
from PyQt5.QtCore import QObject, pyqtSignal

from . import eventloop, workflow, util
from .settings import settings
from .network import NetworkError
from .image import Extent, Image, ImageCollection, Mask, Bounds
from .image import Extent, Image, Mask, Bounds
from .client import ClientMessage, ClientEvent, filter_supported_styles, resolve_sd_version
from .document import Document, LayerObserver
from .pose import Pose
from .style import Style, Styles, SDVersion
from .workflow import ControlMode, Conditioning, LiveParams
from .connection import Connection, ConnectionState
from .properties import Property, PropertyMeta
from .connection import Connection
from .properties import Property, ObservableProperties
from .jobs import Job, JobKind, JobQueue, JobState
from .control import ControlLayer, ControlLayerList
import krita
Expand All @@ -27,7 +26,7 @@ class Workspace(Enum):
live = 2


class Model(QObject, metaclass=PropertyMeta):
class Model(QObject, ObservableProperties):
"""Represents diffusion workflows for a specific Krita document. Stores all inputs related to
image generation. Launches generation jobs. Listens to server messages and keeps a
list of finished, currently running and enqueued jobs.
Expand Down Expand Up @@ -364,7 +363,7 @@ class UpscaleParams(NamedTuple):
target_extent: Extent


class UpscaleWorkspace(QObject, metaclass=PropertyMeta):
class UpscaleWorkspace(QObject, ObservableProperties):
upscaler = Property("", persist=True)
factor = Property(2.0, persist=True)
use_diffusion = Property(True, persist=True)
Expand Down Expand Up @@ -401,7 +400,7 @@ def params(self):
)


class LiveWorkspace(QObject, metaclass=PropertyMeta):
class LiveWorkspace(QObject, ObservableProperties):
is_active = Property(False, setter="toggle")
strength = Property(0.3, persist=True)
seed = Property(0, persist=True)
Expand Down
33 changes: 16 additions & 17 deletions ai_diffusion/properties.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,30 @@
from enum import Enum
from typing import Any, NamedTuple, Sequence, TypeVar, Generic

from PyQt5.QtCore import QObject, QMetaObject, QUuid, pyqtBoundSignal, pyqtProperty # type: ignore
from PyQt5.QtCore import QObject, QMetaObject, QUuid, pyqtBoundSignal
from PyQt5.QtWidgets import QComboBox


T = TypeVar("T")


class PropertyMeta(type(QObject)):
"""Provides default implementations for properties (get, set, signal)."""
class ObservableProperties:
"""Provides default implementations for properties (get, set, signal) to sub-classes."""

def __new__(cls, name, bases, attrs):
for key in list(attrs.keys()):
attr = attrs[key]
if not isinstance(attr, Property):
continue
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)

attrs[f"_{key}"] = attr.default_value
properties = {
name: attr for name, attr in cls.__dict__.items() if isinstance(attr, Property)
}
for name, property in properties.items():
setattr(cls, f"_{name}", property.default_value)
getter, setter = None, None
if attr.getter is not None:
getter = attrs[attr.getter]
if attr.setter is not None:
setter = attrs[attr.setter]
attrs[key] = PropertyImpl(key, getter, setter, attr.persist)

return super().__new__(cls, name, bases, attrs)
if property.getter is not None:
getter = getattr(cls, property.getter)
if property.setter is not None:
setter = getattr(cls, property.setter)
setattr(cls, name, PropertyImpl(name, getter, setter, property.persist))


class Property(Generic[T]):
Expand All @@ -48,7 +47,7 @@ def __delete__(self, instance): ...


class PropertyImpl(property):
"""Property implementation: gets, sets, and notifies of change."""
"""Property implementation: gets/sets a value, and emits a signal when it changes."""

name: str
persist: bool
Expand Down
17 changes: 13 additions & 4 deletions tests/test_properties.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from enum import Enum
import pytest
from PyQt5.QtCore import QObject, pyqtBoundSignal, pyqtSignal
from PyQt5.QtCore import QObject, pyqtSignal

from ai_diffusion.properties import Property, PropertyMeta, bind, serialize, deserialize
from ai_diffusion.properties import Property, ObservableProperties, bind, serialize, deserialize


class Piong(Enum):
a = 1
b = 2


class ObjectWithProperties(QObject, metaclass=PropertyMeta):
class ObjectWithProperties(QObject, ObservableProperties):
inty = Property(0)
stringy = Property("")
enumy = Property(Piong.a)
Expand Down Expand Up @@ -73,6 +73,15 @@ def callback(x):
assert called == [42, "hello", Piong.b, 5, 55]


def test_multiple():
a = ObjectWithProperties()
b = ObjectWithProperties()

a.inty = 5
b.inty = 99
assert a.inty != b.inty


def test_bind():
a = ObjectWithProperties()
b = ObjectWithProperties()
Expand Down Expand Up @@ -112,7 +121,7 @@ def test_bind_qt_to_property():
assert a.qtstyle() == 99


class PersistentObject(QObject, metaclass=PropertyMeta):
class PersistentObject(QObject, ObservableProperties):
inty = Property(0, persist=True)
stringy = Property("", persist=True)
enumy = Property(Piong.a, persist=True)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,8 @@ def test_create_control_image(qtapp, comfy, mode):
async def main():
result = await run_and_save(comfy, job, image_name)
reference = Image.load(reference_dir / image_name)
assert Image.compare(result, reference) < 0.002
threshold = 0.005 if mode is ControlMode.pose else 0.002
assert Image.compare(result, reference) < threshold

qtapp.run(main())

Expand Down

0 comments on commit 15822cb

Please sign in to comment.