From db69912f14cd992da2cd46bfe9bb91d64ba156ad Mon Sep 17 00:00:00 2001
From: Tomasz Kulik <tomek.k@confio.gmbh>
Date: Fri, 22 Nov 2024 15:42:51 +0100
Subject: [PATCH] chore: Replace dataclasses and dataclasses_json with pydantic

---
 .../playground/playground.py                  | 112 ++++++------------
 1 file changed, 38 insertions(+), 74 deletions(-)

diff --git a/packages/cw-schema-codegen/playground/playground.py b/packages/cw-schema-codegen/playground/playground.py
index ac850f87d..adde5c356 100644
--- a/packages/cw-schema-codegen/playground/playground.py
+++ b/packages/cw-schema-codegen/playground/playground.py
@@ -1,84 +1,48 @@
-from dataclasses import dataclass, field
-from dataclasses_json import dataclass_json, config
-from typing import Optional, Iterable
 import sys
-import json
+from typing import Literal, Union, Tuple
+from pydantic import BaseModel, RootModel
 
 
-# TODO tkulik: try to get rid of the `dataclasses_json` dependency
+class Field1(RootModel):
+    root: Literal['Field1']
 
+class Field2(BaseModel):
+    Field2: Tuple[int, int]
 
-enum_field = lambda: field(default=None, metadata=config(exclude=lambda x: x is None))
+class Field3_Struct(BaseModel):
+    a: str
+    b: int
 
-@dataclass_json
-@dataclass
-class SomeEnum:
-    class VariantIndicator:
-        pass
+class Field3(BaseModel):
+    Field3: Field3_Struct
 
-    class Field3Type:
-        a: str
-        b: int
+class Field4(BaseModel):
+    Field4: 'SomeEnum'
 
-    class Field5Type:
-        a: Iterable['SomeEnum']
+class Field5_Struct(BaseModel):
+    a: 'SomeEnum'
 
-    Field1: Optional[VariantIndicator] = enum_field()
-    Field2: Optional[tuple[int, int]] = enum_field()
-    Field3: Optional[Field3Type] = enum_field()
-    Field4: Optional[Iterable['SomeEnum']] = enum_field()
-    Field5: Optional[Field5Type] = enum_field()
-    
-    def deserialize(json):
-        if not ":" in json:
-            if json == '"Field1"':
-                return SomeEnum(Field1=SomeEnum.VariantIndicator())
-            else:
-                raise Exception(f"Deserialization error, undefined variant: {json}")
-        else:
-            return SomeEnum.from_json(json)
-        
-    def serialize(self):
-        if self.Field1 is not None:
-            return '"Field1"'
-        else:
-            return SomeEnum.to_json(self)
-        
-@dataclass_json
-@dataclass
-class UnitStructure:
-    def deserialize(json):
-        if json == "null":
-            return UnitStructure()
-        else:
-            Exception(f"Deserialization error, undefined value: {json}")
-        
-    def serialize(self):
-        return 'null'
-
-@dataclass_json
-@dataclass
-class TupleStructure:
-    Tuple: tuple[int, str, int]
-
-    def deserialize(json):
-        return TupleStructure.from_json(f'{{ "Tuple": {json} }}')
-        
-    def serialize(self):
-        return json.dumps(self.Tuple)
-
-@dataclass_json
-@dataclass
-class NamedStructure:
+class Field5(BaseModel):
+    Field5: Field5_Struct
+
+class SomeEnum(RootModel):
+    root: Union[Field1, Field2, Field3, Field4, Field5]
+
+
+class UnitStructure(RootModel[None]):
+    pass
+
+
+class TupleStructure(RootModel):
+    root: tuple[int, str, int]
+
+
+class NamedStructure(BaseModel):
     a: str
     b: int
-    c: Iterable['SomeEnum']
+    c: SomeEnum
+
 
-    def deserialize(json):
-        return NamedStructure.from_json(json)
-        
-    def serialize(self):
-        return self.to_json()
 
 ###
 ### TESTS:
@@ -88,16 +52,16 @@ def serialize(self):
     input = input.rstrip()
     try:
         if index < 5:
-            deserialized = SomeEnum.deserialize(input)
+            deserialized = SomeEnum.model_validate_json(input)
         elif index == 5:
-            deserialized = UnitStructure.deserialize(input)
+            deserialized = UnitStructure.model_validate_json(input)
         elif index == 6:
-            deserialized = TupleStructure.deserialize(input)
+            deserialized = TupleStructure.model_validate_json(input)
         else:
-            deserialized = NamedStructure.deserialize(input)
+            deserialized = NamedStructure.model_validate_json(input)
     except:
         raise(Exception(f"This json can't be deserialized: {input}"))
-    serialized = deserialized.serialize()
+    serialized = deserialized.model_dump_json()
     print(serialized)