diff --git a/bigraph_schema/data.py b/bigraph_schema/data.py
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/bigraph_schema/data.py
@@ -0,0 +1 @@
+
diff --git a/bigraph_schema/registry.py b/bigraph_schema/registry.py
index dbfc27a..b63ba6e 100644
--- a/bigraph_schema/registry.py
+++ b/bigraph_schema/registry.py
@@ -13,9 +13,12 @@
 import numpy as np

 from pprint import pformat as pf
+from typing import Any, Tuple, Union, Optional, Mapping
+from dataclasses import field, make_dataclass

 from bigraph_schema.parse import parse_expression
 from bigraph_schema.protocols import local_lookup_module, function_module
+import bigraph_schema.data


 NONE_SYMBOL = '!nil'
@@ -525,6 +528,34 @@ def fold_union(schema, state, method, values, core):
     return result


+def type_parameters_for(schema):
+    parameters = []
+    for key in schema['_type_parameters']:
+        subschema = schema.get(f'_{key}', 'any')
+        parameters.append(subschema)
+
+    return parameters
+
+
+def dataclass_union(schema, path, core):
+    parameters = type_parameters_for(schema)
+    subtypes = []
+    for parameter in parameters:
+        dataclass = core.dataclass(
+            parameter,
+            path)
+
+        if isinstance(dataclass, str):
+            subtypes.append(dataclass)
+        elif isinstance(dataclass, type):
+            subtypes.append(dataclass.__name__)
+        else:
+            subtypes.append(str(dataclass))
+
+    parameter_block = ', '.join(subtypes)
+    return eval(f'Union[{parameter_block}]')
+
+
 def divide_any(schema, state, values, core):
     divisions = values.get('divisions', 2)

@@ -551,6 +582,64 @@ def divide_any(schema, state, values, core):
         for _ in range(divisions)]


+def dataclass_any(schema, path, core):
+    parts = path
+    if not parts:
+        parts = ['top']
+    dataclass_name = '_'.join(parts)
+
+    if isinstance(schema, dict):
+        type_name = schema.get('_type', 'any')
+
+        branches = {}
+        for key, subschema in schema.items():
+            if not key.startswith('_'):
+                branch = core.dataclass(
+                    subschema,
+                    path + [key])
+
+                def default(subschema=subschema):
+                    return core.default(subschema)
+
+                branches[key] = (
+                    key,
+                    branch,
+                    field(default_factory=default))
+
+        dataclass = make_dataclass(
+            dataclass_name,
+            branches.values(),
+            namespace={
+                '__module__': 'bigraph_schema.data'})
+
+        setattr(
+            bigraph_schema.data,
+            dataclass_name,
+            dataclass)
+
+    else:
+        schema = core.access(schema)
+        dataclass = core.dataclass(schema, path)
+
+    return dataclass
+
+
+def dataclass_tuple(schema, path, core):
+    parameters = type_parameters_for(schema)
+    subtypes = []
+
+    for index, key in enumerate(schema['_type_parameters']):
+        subschema = schema.get(f'_{key}', 'any')
+        subtype = core.dataclass(
+            subschema,
+            path + [str(index)])
+
+        subtypes.append(subtype.__name__ if isinstance(subtype, type) else str(subtype))
+
+    parameter_block = ', '.join(subtypes)
+    return eval(f'tuple[{parameter_block}]')
+
+
 def divide_tuple(schema, state, values, core):
     divisions = values.get('divisions', 2)

@@ -954,6 +1043,7 @@ def deserialize_union(schema, encoded, core):
         '_check': check_any,
         '_serialize': serialize_any,
         '_deserialize': deserialize_any,
+        '_dataclass': dataclass_any,
         '_fold': fold_any,
         '_merge': merge_any,
         '_bind': bind_any,
@@ -967,6 +1057,7 @@
         '_slice': slice_tuple,
         '_serialize': serialize_tuple,
         '_deserialize': deserialize_tuple,
+        '_dataclass': dataclass_tuple,
         '_fold': fold_tuple,
         '_divide': divide_tuple,
         '_bind': bind_tuple,
@@ -980,6 +1071,7 @@
         '_slice': slice_union,
         '_serialize': serialize_union,
         '_deserialize': deserialize_union,
+        '_dataclass': dataclass_union,
         '_fold': fold_union,
         '_description': 'union of a set of possible types'}}
diff --git a/bigraph_schema/type_system.py b/bigraph_schema/type_system.py
index 87c1217..690bce0 100644
--- a/bigraph_schema/type_system.py
+++ b/bigraph_schema/type_system.py
@@ -9,12 +9,15 @@
 import pprint
 import pytest
 import random
+import typing
 import inspect
 import numbers
 import numpy as np

 from pint import Quantity
 from pprint import pformat as pf
+from typing import Any, Tuple, Union, Optional, Mapping, Callable, NewType, get_origin, get_args
+from dataclasses import asdict

 from bigraph_schema.units import units, render_units_type
 from bigraph_schema.react import react_divide_counts
@@ -28,6 +31,7 @@
     remove_omitted
 )

+import bigraph_schema.data as data


 TYPE_SCHEMAS = {
@@ -438,8 +442,10 @@ def react(self, schema, state, reaction, mode='random'):
         make_reaction = self.react_registry.access(
             reaction_key)
         react = make_reaction(
-            reaction.get(
-                reaction_key, {}))
+            schema,
+            state,
+            reaction.get(reaction_key, {}),
+            self)

         redex = react.get('redex', {})
         reactum = react.get('reactum', {})
@@ -451,6 +457,12 @@
             redex,
             mode=mode)

+        # for path in paths:
+        #     path_schema, path_state = self.slice(
+        #         schema,
+        #         state,
+        #         path)
+
         def merge_state(before):
             remaining = remove_omitted(
                 redex,
@@ -470,6 +482,21 @@ def merge_state(before):

         return state

+    # TODO: maybe all fields are optional?
+    def dataclass(self, schema, path=None):
+        path = path or []
+
+        dataclass_function = self.choose_method(
+            schema,
+            {},
+            'dataclass')
+
+        return dataclass_function(
+            schema,
+            path,
+            self)
+
+
     def check_state(self, schema, state):
         schema = self.access(schema)

@@ -954,33 +981,38 @@ def project(self, ports, wires, path, states):
             wires = [wires]

         if isinstance(wires, (list, tuple)):
-            destination = list(path) + list(wires)
+            destination = resolve_path(list(path) + list(wires))
             result = set_path(
                 result,
                 destination,
                 states)

         elif isinstance(wires, dict):
-            branches = []
-            for key in wires.keys():
-                subports, substates = self.slice(ports, states, key)
-                projection = self.project(
-                    subports,
-                    wires[key],
-                    path,
-                    substates)
+            if isinstance(states, list):
+                result = [
+                    self.project(ports, wires, path, state)
+                    for state in states]
+            else:
+                branches = []
+                for key in wires.keys():
+                    subports, substates = self.slice(ports, states, key)
+                    projection = self.project(
+                        subports,
+                        wires[key],
+                        path,
+                        substates)

-                if projection is not None:
-                    branches.append(projection)
+                    if projection is not None:
+                        branches.append(projection)

-            branches = [
-                branch
-                for branch in branches
-                if branch is not None] # and list(branch)[0][1] is not None]
+                branches = [
+                    branch
+                    for branch in branches
+                    if branch is not None] # and list(branch)[0][1] is not None]

-            result = {}
-            for branch in branches:
-                deep_merge(result, branch)
+                result = {}
+                for branch in branches:
+                    deep_merge(result, branch)

         else:
             raise Exception(
                 f'inverting state\n {states}\naccording to ports schema\n {ports}\nbut wires are not recognized\n {wires}')
@@ -991,8 +1023,8 @@
     def project_edge(self, schema, instance, edge_path, states, ports_key='outputs'):
         """ Given states from the perspective of an edge (through its ports), produce states aligned to the tree
-        the wires point to.
+        (inverse of view)
         """

         if schema is None:
@@ -1221,9 +1253,6 @@ def infer_wires(self, ports, state, wires, top_schema=None, path=None):
                 current,
                 port_schema)

-            if isinstance(destination, tuple):
-                import ipdb; ipdb.set_trace()
-
             destination[destination_key] = self.access(
                 port_schema)

@@ -1529,6 +1558,14 @@ def concatenate(schema, current, update, core=None):
     return current + update


+def dataclass_float(schema, path, core):
+    return float
+
+
+def dataclass_integer(schema, path, core):
+    return int
+
+
 def divide_float(schema, state, values, core):
     divisions = values.get('divisions', 2)
     portion = float(state) / divisions
@@ -1659,6 +1696,18 @@ def check_list(schema, state, core):
         return False


+def dataclass_list(schema, path, core):
+    element_type = core.find_parameter(
+        schema,
+        'element')
+
+    dataclass = core.dataclass(
+        element_type,
+        path + ['element'])
+
+    return list[dataclass]
+
+
 def slice_list(schema, state, path, core):
     element_type = core.find_parameter(
         schema,
@@ -1724,6 +1773,25 @@ def check_tree(schema, state, core):
     return core.check(leaf_type, state)


+def dataclass_tree(schema, path, core):
+    leaf_type = core.find_parameter(
+        schema,
+        'leaf')
+
+    leaf_dataclass = core.dataclass(
+        leaf_type,
+        path + ['leaf'])
+
+    dataclass_name = '_'.join(path)
+    # block = f"{dataclass_name} = NewType('{dataclass_name}', Union[{leaf_dataclass}, Mapping[str, '{dataclass_name}']])"
+    block = f"NewType('{dataclass_name}', Union[{leaf_dataclass}, Mapping[str, '{dataclass_name}']])"
+
+    dataclass = eval(block)
+    setattr(data, dataclass_name, dataclass)
+
+    return dataclass
+
+
 def slice_tree(schema, state, path, core):
     leaf_type = core.find_parameter(
         schema,
@@ -1817,6 +1885,18 @@ def apply_map(schema, current, update, core=None):
     return result


+def dataclass_map(schema, path, core):
+    value_type = core.find_parameter(
+        schema,
+        'value')
+
+    dataclass = core.dataclass(
+        value_type,
+        path + ['value'])
+
+    return Mapping[str, dataclass]
+
+
 def check_map(schema, state, core=None):
     value_type = core.find_parameter(
         schema,
@@ -1893,6 +1973,18 @@ def apply_maybe(schema, current, update, core):
         update)


+def dataclass_maybe(schema, path, core):
+    value_type = core.find_parameter(
+        schema,
+        'value')
+
+    dataclass = core.dataclass(
+        value_type,
+        path + ['value'])
+
+    return Optional[dataclass]
+
+
 def check_maybe(schema, state, core):
     if state is None:
         return True
@@ -1984,6 +2076,20 @@ def apply_edge(schema, current, update, core):
     return result


+def dataclass_edge(schema, path, core):
+    inputs = schema.get('_inputs', {})
+    inputs_dataclass = core.dataclass(
+        inputs,
+        path + ['inputs'])
+
+    outputs = schema.get('_outputs', {})
+    outputs_dataclass = core.dataclass(
+        outputs,
+        path + ['outputs'])
+
+    return Callable[[inputs_dataclass], outputs_dataclass]
+
+
 def check_ports(state, core, key):
     return key in state and core.check(
         'wires',
@@ -2021,6 +2127,10 @@ def check_array(schema, state, core):


+def dataclass_array(schema, path, core):
+    return np.ndarray
+
+
 def slice_array(schema, state, path, core):
     if len(path) > 0:
         head = path[0]
@@ -2058,16 +2168,16 @@ def serialize_array(schema, value, core):
     if isinstance(value, dict):
         return value
     else:
-        data = 'string'
+        array_data = 'string'
         dtype = value.dtype.name
         if dtype.startswith('int'):
-            data = 'integer'
+            array_data = 'integer'
         elif dtype.startswith('float'):
-            data = 'float'
+            array_data = 'float'

         return {
             'bytes': value.tobytes(),
-            'data': data,
+            'data': array_data,
             'shape': value.shape}


@@ -2339,6 +2449,14 @@ def register_units(core, units):


+def dataclass_boolean(schema, path, core):
+    return bool
+
+
+def dataclass_string(schema, path, core):
+    return str
+
+
 base_type_library = {
     'boolean': {
         '_type': 'boolean',
@@ -2347,7 +2465,7 @@ def register_units(core, units):
         '_apply': apply_boolean,
         '_serialize': serialize_boolean,
         '_deserialize': deserialize_boolean,
-    },
+        '_dataclass': dataclass_boolean},

     # abstract number type
     'number': {
@@ -2363,6 +2481,7 @@ def register_units(core, units):
         # inherit _apply and _serialize from number type
         '_check': check_integer,
         '_deserialize': deserialize_integer,
+        '_dataclass': dataclass_integer,
         '_description': '64-bit integer',
         '_inherit': 'number'},

@@ -2372,6 +2491,7 @@ def register_units(core, units):
         '_check': check_float,
         '_deserialize': deserialize_float,
         '_divide': divide_float,
+        '_dataclass': dataclass_float,
         '_description': '64-bit floating point precision number',
         '_inherit': 'number'},

@@ -2382,6 +2502,7 @@ def register_units(core, units):
         '_apply': replace,
         '_serialize': serialize_string,
         '_deserialize': deserialize_string,
+        '_dataclass': dataclass_string,
         '_description': '64-bit integer'},

     'list': {
@@ -2392,12 +2513,12 @@ def register_units(core, units):
         '_apply': apply_list,
         '_serialize': serialize_list,
         '_deserialize': deserialize_list,
+        '_dataclass': dataclass_list,
         '_fold': fold_list,
         '_divide': divide_list,
         '_type_parameters': ['element'],
         '_description': 'general list type (or sublists)'},

-    # TODO: tree should behave as if the leaf type is a valid tree
     'tree': {
         '_type': 'tree',
         '_default': {},
@@ -2406,6 +2527,7 @@ def register_units(core, units):
         '_apply': apply_tree,
         '_serialize': serialize_tree,
         '_deserialize': deserialize_tree,
+        '_dataclass': dataclass_tree,
         '_fold': fold_tree,
         '_divide': divide_tree,
         '_type_parameters': ['leaf'],
@@ -2417,6 +2539,7 @@ def register_units(core, units):
         '_apply': apply_map,
         '_serialize': serialize_map,
         '_deserialize': deserialize_map,
+        '_dataclass': dataclass_map,
         '_check': check_map,
         '_slice': slice_map,
         '_fold': fold_map,
@@ -2433,6 +2556,7 @@ def register_units(core, units):
         '_apply': apply_array,
         '_serialize': serialize_array,
         '_deserialize': deserialize_array,
+        '_dataclass': dataclass_array,
         '_type_parameters': [
             'shape',
             'data'],
@@ -2446,6 +2570,7 @@ def register_units(core, units):
         '_slice': slice_maybe,
         '_serialize': serialize_maybe,
         '_deserialize': deserialize_maybe,
+        '_dataclass': dataclass_maybe,
         '_fold': fold_maybe,
         '_type_parameters': ['value'],
         '_description': 'type to represent values that could be empty'},
@@ -2465,7 +2590,6 @@ def register_units(core, units):
         '_apply': apply_schema},

     'edge': {
-        # TODO: do we need to have defaults informed by type parameters?
         '_type': 'edge',
         '_default': {
             'inputs': {},
@@ -2473,6 +2597,7 @@ def register_units(core, units):
         '_apply': apply_edge,
         '_serialize': serialize_edge,
         '_deserialize': deserialize_edge,
+        '_dataclass': dataclass_edge,
         '_check': check_edge,
         '_type_parameters': ['inputs', 'outputs'],
         '_description': 'hyperedges in the bigraph, with inputs and outputs as type parameters',
         'inputs': 'wires',
         'outputs': 'wires'}}
@@ -2480,10 +2605,110 @@ def register_units(core, units):
+def add_reaction(schema, state, reaction, core):
+    path = reaction.get('path')
+
+    redex = {}
+    establish_path(
+        redex,
+        path)
+
+    reactum = {}
+    node = establish_path(
+        reactum,
+        path)
+
+    deep_merge(
+        node,
+        reaction.get('add', {}))
+
+    return {
+        'redex': redex,
+        'reactum': reactum}
+
+
+def remove_reaction(schema, state, reaction, core):
+    path = reaction.get('path', ())
+    redex = {}
+    node = establish_path(
+        redex,
+        path)
+
+    for remove in reaction.get('remove', []):
+        node[remove] = {}
+
+    reactum = {}
+    establish_path(
+        reactum,
+        path)
+
+    return {
+        'redex': redex,
+        'reactum': reactum}
+
+
+def replace_reaction(schema, state, reaction, core):
+    path = reaction.get('path', ())
+
+    redex = {}
+    node = establish_path(
+        redex,
+        path)
+
+    for before_key, before_state in reaction.get('before', {}).items():
+        node[before_key] = before_state
+
+    reactum = {}
+    node = establish_path(
+        reactum,
+        path)
+
+    for after_key, after_state in reaction.get('after', {}).items():
+        node[after_key] = after_state
+
+    return {
+        'redex': redex,
+        'reactum': reactum}
+
+
+def divide_reaction(schema, state, reaction, core):
+    mother = reaction['mother']
+    daughters = reaction.get(
+        'daughters',
+        [f'{mother}0', f'{mother}1'])
+
+    mother_schema, mother_state = core.slice(
+        schema,
+        state,
+        mother)
+
+    division = core.fold(
+        mother_schema,
+        mother_state,
+        'divide', {
+            'divisions': len(daughters)})
+
+    after = {
+        daughter: daughter_state
+        for daughter, daughter_state in zip(daughters, division)}
+
+    replace = {
+        'before': {
+            mother: {}},
+        'after': after}
+
+    return replace_reaction(
+        schema,
+        state,
+        replace,
+        core)
+
+
 def register_base_reactions(core):
-    core.register_reaction(
-        'divide_counts',
-        react_divide_counts)
+    core.register_reaction('add', add_reaction)
+    core.register_reaction('remove', remove_reaction)
+    core.register_reaction('replace', replace_reaction)
+    core.register_reaction('divide', divide_reaction)


 def register_cube(core):
@@ -3442,31 +3667,6 @@ def test_add_reaction(compartment_types):
                 'counts': {'A': 13},
                 'inner': {}}}}}

-    def add_reaction(config):
-        path = config.get('path')
-
-        redex = {}
-        establish_path(
-            redex,
-            path)
-
-        reactum = {}
-        node = establish_path(
-            reactum,
-            path)
-
-        deep_merge(
-            node,
-            config.get('add', {}))
-
-        return {
-            'redex': redex,
-            'reactum': reactum}
-
-    compartment_types.react_registry.register(
-        'add',
-        add_reaction)
-
     add_config = {
         'path': ['environment', 'inner'],
         'add': {
@@ -3508,30 +3708,6 @@ def test_remove_reaction(compartment_types):
                 'counts': {'A': 13},
                 'inner': {}}}}}

-    # TODO: register these for general access
-    def remove_reaction(config):
-        path = config.get('path', ())
-        redex = {}
-        node = establish_path(
-            redex,
-            path)
-
-        for remove in config.get('remove', []):
-            node[remove] = {}
-
-        reactum = {}
-        establish_path(
-            reactum,
-            path)
-
-        return {
-            'redex': redex,
-            'reactum': reactum}
-
-    compartment_types.react_registry.register(
-        'remove',
-        remove_reaction)
-
     remove_config = {
         'path': ['environment', 'inner'],
         'remove': ['0']}
@@ -3566,43 +3742,16 @@ def test_replace_reaction(compartment_types):
                 'counts': {'A': 13},
                 'inner': {}}}}}

-    def replace_reaction(config):
-        path = config.get('path', ())
-
-        redex = {}
-        node = establish_path(
-            redex,
-            path)
-
-        for before_key, before_state in config.get('before', {}).items():
-            node[before_key] = before_state
-
-        reactum = {}
-        node = establish_path(
-            reactum,
-            path)
-
-        for after_key, after_state in config.get('after', {}).items():
-            node[after_key] = after_state
-
-        return {
-            'redex': redex,
-            'reactum': reactum}
-
-    compartment_types.react_registry.register(
-        'replace',
-        replace_reaction)
-
-    replace_config = {
-        'path': ['environment', 'inner'],
-        'before': {'0': {'A': '?1'}},
-        'after': {
-            '2': {
-                'counts': {
-                    'A': {'function': 'divide', 'arguments': ['?1', 0.5], }}},
-            '3': {
-                'counts': {
-                    'A': '@1'}}}}
+    # replace_config = {
+    #     'path': ['environment', 'inner'],
+    #     'before': {'0': {'A': '?1'}},
+    #     'after': {
+    #         '2': {
+    #             'counts': {
+    #                 'A': {'function': 'divide', 'arguments': ['?1', 0.5], }}},
+    #         '3': {
+    #             'counts': {
+    #                 'A': '@1'}}}}

     replace_config = {
         'path': ['environment', 'inner'],
@@ -4380,6 +4529,130 @@ def test_set_slice(core):
             1])[1] == 33


+def from_state(dataclass, state):
+    if hasattr(dataclass, '__dataclass_fields__'):
+        fields = dataclass.__dataclass_fields__
+        state = state or {}
+
+        init = {}
+        for key, field in fields.items():
+            substate = from_state(
+                field.type,
+                state.get(key))
+            init[key] = substate
+        instance = dataclass(**init)
+    # elif get_origin(dataclass) in [typing.Union, typing.Mapping]:
+    #     instance = state
+    else:
+        instance = state
+        # instance = dataclass(state)
+
+    return instance
+
+
+def test_dataclass(core):
+    simple_schema = {
+        'a': 'float',
+        'b': 'integer',
+        'c': 'boolean',
+        'x': 'string'}
+
+    # TODO: accept just a string instead of only a path
+    simple_dataclass = core.dataclass(
+        simple_schema,
+        ['simple'])
+
+    simple_state = {
+        'a': 88.888,
+        'b': 11111,
+        'c': False,
+        'x': 'not a string'}
+
+    simple_new = simple_dataclass(
+        a=1.11,
+        b=33,
+        c=True,
+        x='what')
+
+    simple_from = from_state(
+        simple_dataclass,
+        simple_state)
+
+    nested_schema = {
+        'a': {
+            'a': {
+                'a': 'float',
+                'b': 'float'},
+            'x': 'float'}}
+
+    nested_dataclass = core.dataclass(
+        nested_schema,
+        ['nested'])
+
+    nested_state = {
+        'a': {
+            'a': {
+                'a': 13.4444,
+                'b': 888.88},
+            'x': 111.11111}}
+
+    nested_new = data.nested(
+        data.nested_a(
+            data.nested_a_a(
+                a=222.22,
+                b=3.3333),
+            5555.55))
+
+    nested_from = from_state(
+        nested_dataclass,
+        nested_state)
+
+    complex_schema = {
+        'a': 'tree[maybe[float]]',
+        'b': 'float~list[string]',
+        'c': {
+            'd': 'map[edge[GGG:float,OOO:float]]',
+            'e': 'array[(3|4|10),float]'}}
+
+    complex_dataclass = core.dataclass(
+        complex_schema,
+        ['complex'])
+
+    complex_state = {
+        'a': {
+            'x': {
+                'oooo': None,
+                'y': 1.1,
+                'z': 33.33},
+            'w': 44.444},
+        'b': ['1', '11', '111', '1111'],
+        'c': {
+            'd': {
+                'A': {
+                    'inputs': {
+                        'GGG': ['..', '..', 'a', 'w']},
+                    'outputs': {
+                        'OOO': ['..', '..', 'a', 'x', 'y']}},
+                'B': {
+                    'inputs': {
+                        'GGG': ['..', '..', 'a', 'x', 'y']},
+                    'outputs': {
+                        'OOO': ['..', '..', 'a', 'w']}}},
+            'e': np.zeros((3, 4, 10))}}
+
+    complex_from = from_state(
+        complex_dataclass,
+        complex_state)
+
+    complex_dict = asdict(complex_from)
+
+    # assert complex_dict == complex_state ?
+
+    assert complex_from.a['x']['oooo'] is None
+    assert len(complex_from.c.d['A']['inputs']['GGG'])
+    assert isinstance(complex_from.c.e, np.ndarray)
+
+
 if __name__ == '__main__':
     core = TypeSystem()

@@ -4422,4 +4695,4 @@ def test_set_slice(core):
     test_bind(core)
     test_slice(core)
     test_set_slice(core)
-
+    test_dataclass(core)
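
Usage sketch of the new dataclass API (a minimal illustration assuming this branch is installed; the schema, path, and values below are made up, mirroring test_dataclass above):

from dataclasses import asdict
from bigraph_schema.type_system import TypeSystem

core = TypeSystem()

# generate a dataclass from a schema; the path seeds the generated class name,
# and the class is also attached to the bigraph_schema.data module
cell_schema = {
    'mass': 'float',
    'divisions': 'integer',
    'motile': 'boolean'}

cell_dataclass = core.dataclass(
    cell_schema,
    ['cell'])

# construct an instance directly, one keyword per schema key...
cell = cell_dataclass(
    mass=1.1,
    divisions=3,
    motile=True)

# ...and flatten it back into a plain state dict
print(asdict(cell))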
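
The reaction helpers above now share a (schema, state, reaction, core) signature and are registered by name in register_base_reactions; a sketch of registering a custom reaction in the same style (the 'clear' reaction name and the establish_path import location are assumptions for illustration):

from bigraph_schema.registry import establish_path

def clear_reaction(schema, state, reaction, core):
    # match the listed keys under 'path' in the redex and leave them out of
    # the reactum, so applying the reaction removes those nodes
    path = reaction.get('path', ())

    redex = {}
    node = establish_path(
        redex,
        path)

    for key in reaction.get('clear', []):
        node[key] = {}

    reactum = {}
    establish_path(
        reactum,
        path)

    return {
        'redex': redex,
        'reactum': reactum}

core.register_reaction('clear', clear_reaction)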