Skip to content

Commit

Permalink
Support both Pydantic v1 and v2 (#65)
Browse files Browse the repository at this point in the history
* unpin to allow pydantic v2

* 🎨 set default value

* ⬆️ upgrade fastapi-websocket-rpc

* version bump

* ➕ add `packaging` as dependency

* 🎨 add util methods for supporting both pydantic 1 and 2

* use `get_model_serializer` to support pydantic v1 and 2

* 🎨 black formatting

* 🎨 imports

* 🎨 add helper method for printing model dict in pydantic v1 and 2

* 🎨 support kwargs in helper methods

* 🎨 helper method to print model dict

* revert version bump

* 🎨 update pydantic helper methods to return result in one call

* 🎨 update to use new pydantic helper methods
  • Loading branch information
ff137 authored Sep 20, 2023
1 parent aba5407 commit 9d07fc5
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 12 deletions.
19 changes: 12 additions & 7 deletions fastapi_websocket_pubsub/event_broadcaster.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import asyncio
from typing import Any, Union
from pydantic.main import BaseModel
from .event_notifier import EventNotifier, Subscription, TopicList, ALL_TOPICS
from broadcaster import Broadcast
from typing import Any

from .logger import get_logger
from broadcaster import Broadcast
from fastapi_websocket_rpc.utils import gen_uid
from pydantic.main import BaseModel

from .event_notifier import ALL_TOPICS, EventNotifier, Subscription, TopicList
from .logger import get_logger
from .util import pydantic_serialize

logger = get_logger("EventBroadcaster")

Expand All @@ -18,7 +19,7 @@
class BroadcastNotification(BaseModel):
notifier_id: NotifierId
topics: TopicList
data: Any
data: Any = None


class EventBroadcasterException(Exception):
Expand Down Expand Up @@ -180,7 +181,9 @@ async def __broadcast_notifications__(self, subscription: Subscription, data):
async with self._broadcast_type(
self._broadcast_url
) as sharing_broadcast_channel:
await sharing_broadcast_channel.publish(self._channel, note.json())
await sharing_broadcast_channel.publish(
self._channel, pydantic_serialize(note)
)

async def _subscribe_to_all_topics(self):
return await self._notifier.subscribe(
Expand Down Expand Up @@ -277,8 +280,10 @@ async def __read_notifications__(self):
)

self._tasks.add(task)

def cleanup(task):
self._tasks.remove(task)

task.add_done_callback(cleanup)
except:
logger.exception("Failed handling incoming broadcast")
Expand Down
3 changes: 2 additions & 1 deletion fastapi_websocket_pubsub/event_notifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pydantic import BaseModel # pylint: disable=no-name-in-module

from .logger import get_logger
from .util import pydantic_to_dict

logger = get_logger("EventNotifier")

Expand Down Expand Up @@ -129,7 +130,7 @@ async def subscribe(
)
subscriptions.append(new_subscription)
new_subscriptions.append(new_subscription)
logger.debug(f"New subscription {new_subscription.dict()}")
logger.debug(f"New subscription {pydantic_to_dict(new_subscription)}")
await EventNotifier.trigger_events(
self._on_subscribe_events, subscriber_id, topics
)
Expand Down
7 changes: 6 additions & 1 deletion fastapi_websocket_pubsub/rpc_event_methods.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import asyncio

from fastapi_websocket_rpc import RpcMethodsBase

from .event_notifier import EventNotifier, Subscription, TopicList
from .logger import get_logger
from .util import pydantic_to_dict


class RpcEventServerMethods(RpcMethodsBase):
Expand All @@ -22,7 +25,9 @@ async def subscribe(self, topics: TopicList = []) -> bool:
async def callback(subscription: Subscription, data):
# remove the actual function
sub = subscription.copy(exclude={"callback"})
self.logger.info(f"Notifying other side: subscription={subscription.dict(exclude={'callback'})}, data={data}, channel_id={self.channel.id}")
self.logger.info(
f"Notifying other side: subscription={pydantic_to_dict(subscription, exclude={'callback'})}, data={data}, channel_id={self.channel.id}"
)
await self.channel.other.notify(subscription=sub, data=data)

if self._rpc_channel_get_remote_id:
Expand Down
21 changes: 21 additions & 0 deletions fastapi_websocket_pubsub/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import pydantic
from packaging import version


# Helper methods for supporting Pydantic v1 and v2
def is_pydantic_pre_v2():
return version.parse(pydantic.VERSION) < version.parse("2.0.0")


def pydantic_serialize(model, **kwargs):
if is_pydantic_pre_v2():
return model.json(**kwargs)
else:
return model.model_dump_json(**kwargs)


def pydantic_to_dict(model, **kwargs):
if is_pydantic_pre_v2():
return model.dict(**kwargs)
else:
return model.model_dump(**kwargs)
7 changes: 4 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
fastapi-websocket-rpc>=0.1.24,<1
fastapi-websocket-rpc>=0.1.25,<1
packaging>=20.4
permit-broadcaster[redis,postgres,kafka]>=0.2.5,<3
pydantic>=1.9.1,<2
websockets>=10.3,<11
pydantic>=1.9.1
websockets>=10.3,<11

0 comments on commit 9d07fc5

Please sign in to comment.