Skip to content

Commit

Permalink
migrate dms3 to pydantic v2 (forgotten)
Browse files Browse the repository at this point in the history
  • Loading branch information
eudoxos committed Jul 1, 2024
1 parent 4af97fb commit 4e0eeec
Showing 1 changed file with 41 additions and 41 deletions.
82 changes: 41 additions & 41 deletions mupifDB/api/edm/dms3.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pydantic
from typing import Literal,Optional,Union,Tuple,List,Set,Dict,Any
from typing import Literal,Optional,Union,Tuple,List,Set,Dict,Any,Self
import json
import typing
import re
Expand All @@ -20,6 +20,9 @@
import attrdict
import io
import logging
from pydantic import Field, ConfigDict
from typing_extensions import Annotated

log=logging.getLogger(__name__)
# logging.basicConfig(format='{levelname:7}: {message}', style='{', level=logging.INFO)

Expand Down Expand Up @@ -58,14 +61,14 @@ class ItemSchema(pydantic.BaseModel):
'One data item in database schema'
dtype: Literal['f','i','?','str','bytes','object']='f'
unit: Optional[str]=None
shape: pydantic.conlist(item_type=int,min_items=0,max_items=5)=[]
shape: Annotated[List[int], Field(min_length=0,max_length=5)]=[]
link: Optional[str]=None
implicit: Dict[str,Union[str,int]]=pydantic.Field(default_factory=dict)

def is_a_quantity(self):
return self.dtype in ('f','i','?')

@pydantic.validator("unit")
@pydantic.field_validator("unit")
def unit_valid(cls,v):
if v is None: return
try: au.Unit(v)
Expand All @@ -74,30 +77,28 @@ def unit_valid(cls,v):
raise
return v

@pydantic.root_validator
def dtype_shape(cls,attrs):
dtype,shape=attrs['dtype'],attrs['shape']
if dtype=='bytes' and shape!=[]: raise ValueError('bytes must have no shape specified')
if dtype=='objects' and shape!=[]: raise ValueError('objects must have no shape specified')
return attrs

@pydantic.root_validator
def link_shape(cls,attrs):
if attrs['link'] is not None:
if len(attrs['shape'])>1: raise ValueError('links must be either scalar (shape=[]) or 1d array (shape=[num]).')
if attrs['unit'] is not None: raise ValueError('unit not permitted with links')
if 'implicit' in attrs and len(attrs['implicit'])>0: raise ValueError('implicit values not permitted with links')
return attrs


class SchemaSchema(pydantic.BaseModel):
'Schema of the schema itself; read via parse_obj'
__root__: typing.Dict[str,typing.Dict[str,ItemSchema]]

@pydantic.root_validator(pre=False)
def _links_valid(cls, values):
root=values['__root__']
for T,fields in root.items():
@pydantic.model_validator(mode='after')
def dtype_shape(self) -> Self:
if self.dtype=='bytes' and self.shape!=[]: raise ValueError('bytes must have no shape specified')
if self.dtype=='objects' and self.shape!=[]: raise ValueError('objects must have no shape specified')
return self

@pydantic.model_validator(mode='after')
def link_shape(self) -> Self:
if self.link is not None:
if len(self.shape)>1: raise ValueError('links must be either scalar (shape=[]) or 1d array (shape=[num]).')
if self.unit is not None: raise ValueError('unit not permitted with links')
if self.implicit and len(self.implicit)>0: raise ValueError('implicit values not permitted with links')
return self


class SchemaSchema(pydantic.RootModel):
'Schema of the schema itself; read via model_validate'
root: typing.Dict[str,typing.Dict[str,ItemSchema]]

@pydantic.model_validator(mode='after')
def _links_valid(self):
for T,fields in self.root.items():
# must handle repeated validation
if 'meta' in fields:
if fields['meta']!=ItemSchema(dtype='object'): raise ValueError(f'{T}: "meta" field may not be specified in schema (is added automatically).')
Expand All @@ -106,8 +107,8 @@ def _links_valid(cls, values):
fields['meta']=ItemSchema(dtype='object')
for f,i in fields.items():
if i.link is None: continue
if i.link not in root.keys(): raise ValueError(f'{T}.{f}: link to undefined collection {i.link}.')
return values
if i.link not in self.root.keys(): raise ValueError(f'{T}.{f}: link to undefined collection {i.link}.')
return self

class StrModel(pydantic.BaseModel):
value: Union[str,List[str],List[List[str]],List[List[List[str]]],List[List[List[List[str]]]]]
Expand Down Expand Up @@ -225,8 +226,7 @@ def _validated_quantity(itemName: str, item: ItemSchema, data):


class _PathEntry(pydantic.BaseModel):
class Config:
allow_mutation = False
model_config = ConfigDict(frozen=True)

attr: str
index: Optional[int]=None
Expand Down Expand Up @@ -320,7 +320,7 @@ def mk_entry(v):
def _unparse_path(path: [(str,Optional[int])]):
return '.'.join([ent.to_str() for ent in path])

@pydantic.validate_arguments(config=dict(arbitrary_types_allowed=True))
@pydantic.validate_call(config=dict(arbitrary_types_allowed=True))
def _quantity_to_dict(q: Union[np.ndarray,au.Quantity]) -> dict:
if isinstance(q,au.Quantity): return {'value':q.value.tolist(),'unit':str(q.unit)}
return {'value':q.tolist()}
Expand Down Expand Up @@ -465,7 +465,7 @@ def dms_api_schema_get(db: str, include_id:bool=False):

@router.get('/{db}/schema/graphviz')
def dms_api_schema_graphviz(db: str):
sch=GG.schema_get(db,include_id=False).__root__ # root convertes from SchemaSchema to dict
sch=GG.schema_get(db,include_id=False).root # root convertes from SchemaSchema to dict
graph='digraph g {\n graph [rankdir="LR"];\n'
links=[]
for klass in sch:
Expand Down Expand Up @@ -532,10 +532,10 @@ def _api_value_to_db_rec__attr(item,val,prefix):
if val[k]!=v: raise ValueError(f'{prefix}: implicit field {k} has different value in schema and data ({v} vs. {val[k]})')
del val[k]
if item.dtype=='str':
s=StrModel.parse_obj(val)
s=StrModel.model_validate(val)
s.schema_check(prefix,item)
return s.dict()
elif item.dtype=='bytes': return BytesModel.parse_obj(val).dict()
return s.model_dump()
elif item.dtype=='bytes': return BytesModel.model_validate(val).model_dump()
elif item.dtype=='object':
return json.loads(json.dumps(val))
elif item.is_a_quantity():
Expand Down Expand Up @@ -774,7 +774,7 @@ def _resolve(*,obj,index,i=item,key=key):

@router.get('/{db}')
def dms_api_type_list(db: str):
return list(GG.schema_get(db).dict().keys())
return list(GG.schema_get(db).model_dump().keys())

@router.get('/{db}/{type}')
def dms_api_object_list(db: str, type: str):
Expand Down Expand Up @@ -818,13 +818,13 @@ def schema_get(db:str,include_id:bool=False):
rawSchema=GG.db_get(db)['schema'].find_one()
if rawSchema is not None:
if not include_id and '_id' in rawSchema: del rawSchema['_id'] # this prevents breakage when reloading
GG._SCH[db]=SchemaSchema.parse_obj(rawSchema)
GG._SCH[db]=SchemaSchema.model_validate(rawSchema)
return GG._SCH[db]

@staticmethod
def schema_get_type(db:str,type:str):
# why do we need to access through __root__ here? unclear
return GG.schema_get(db).__root__[type]
# why do we need to access through root here? unclear
return GG.schema_get(db).root[type]

@staticmethod
def schema_invalidate_cache():
Expand All @@ -833,7 +833,7 @@ def schema_invalidate_cache():
#@pydantic.validate_arguments
@staticmethod
def schema_import(db:str, json_str:str, force=False):
schema=SchemaSchema.parse_obj(json.loads(json_str))
schema=SchemaSchema.model_validate(json.loads(json_str))
dms_api_schema_post(db,schema,force=force)
GG.schema_invalidate_cache()

Expand Down

0 comments on commit 4e0eeec

Please sign in to comment.