Skip to content

Commit

Permalink
Move the basic data type from workgraph to there (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 authored Nov 28, 2024
1 parent 2b49a8c commit 3cb836d
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 13 deletions.
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ Source = "https://github.com/aiidateam/aiida-pythonjob"
[project.entry-points."aiida.data"]
"pythonjob.pickled_data" = "aiida_pythonjob.data.pickled_data:PickledData"
"pythonjob.pickled_function" = "aiida_pythonjob.data.pickled_function:PickledFunction"
"pythonjob.ase.atoms.Atoms" = "aiida_pythonjob.data.atoms:AtomsData"
"pythonjob.builtins.int" = "aiida.orm.nodes.data.int:Int"
"pythonjob.builtins.float" = "aiida.orm.nodes.data.float:Float"
"pythonjob.builtins.str" = "aiida.orm.nodes.data.str:Str"
"pythonjob.builtins.bool" = "aiida.orm.nodes.data.bool:Bool"
"pythonjob.builtins.list"="aiida_pythonjob.data.data_with_value:List"
"pythonjob.builtins.dict"="aiida_pythonjob.data.data_with_value:Dict"


[project.entry-points."aiida.calculations"]
"pythonjob.pythonjob" = "aiida_pythonjob.calculations.pythonjob:PythonJob"
Expand Down
3 changes: 2 additions & 1 deletion src/aiida_pythonjob/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .pickled_data import PickledData
from .pickled_function import PickledFunction
from .serializer import general_serializer, serialize_to_aiida_nodes

__all__ = ("PickledData", "PickledFunction")
__all__ = ("PickledData", "PickledFunction", "serialize_to_aiida_nodes", "general_serializer")
53 changes: 53 additions & 0 deletions src/aiida_pythonjob/data/atoms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import numpy as np
from aiida.orm import Data
from ase import Atoms
from ase.db.row import atoms2dict

__all__ = ("AtomsData",)


class AtomsData(Data):
"""Data to represent a ASE Atoms."""

_cached_atoms = None

def __init__(self, value=None, **kwargs):
"""Initialise a `AtomsData` node instance.
:param value: ASE Atoms instance to initialise the `AtomsData` node from
"""
atoms = value or Atoms()
super().__init__(**kwargs)
data, keys = self.atoms2dict(atoms)
self.base.attributes.set_many(data)
self.base.attributes.set("keys", keys)

@classmethod
def atoms2dict(cls, atoms):
data = atoms2dict(atoms)
data.pop("unique_id")
keys = list(data.keys())
formula = atoms.get_chemical_formula()
data = cls._convert_numpy_to_native(data)
data["formula"] = formula
data["symbols"] = atoms.get_chemical_symbols()
return data, keys

@classmethod
def _convert_numpy_to_native(cls, data):
"""Convert numpy types to Python native types for JSON compatibility."""
for key, value in data.items():
if isinstance(value, np.bool_):
data[key] = bool(value)
elif isinstance(value, np.ndarray):
data[key] = value.tolist()
elif isinstance(value, np.generic):
data[key] = value.item()
return data

@property
def value(self):
keys = self.base.attributes.get("keys")
data = self.base.attributes.get_many(keys)
data = dict(zip(keys, data))
return Atoms(**data)
13 changes: 13 additions & 0 deletions src/aiida_pythonjob/data/data_with_value.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from aiida import orm


class Dict(orm.Dict):
@property
def value(self):
return self.get_dict()


class List(orm.List):
@property
def value(self):
return self.get_list()
12 changes: 0 additions & 12 deletions src/aiida_pythonjob/data/pickled_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,6 @@
from aiida import orm


class Dict(orm.Dict):
@property
def value(self):
return self.get_dict()


class List(orm.List):
@property
def value(self):
return self.get_list()


class PickledData(orm.Data):
"""Data to represent a pickled value using cloudpickle."""

Expand Down
13 changes: 13 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import aiida
from aiida_pythonjob import PickledFunction


Expand All @@ -21,3 +22,15 @@ def generate_structures(
"builtins": {"list", "float"},
"numpy": {"array"},
}


def test_python_job():
"""Test a simple python node."""
from aiida_pythonjob.data.pickled_data import PickledData
from aiida_pythonjob.data.serializer import serialize_to_aiida_nodes

inputs = {"a": 1, "b": 2.0, "c": set()}
new_inputs = serialize_to_aiida_nodes(inputs)
assert isinstance(new_inputs["a"], aiida.orm.Int)
assert isinstance(new_inputs["b"], aiida.orm.Float)
assert isinstance(new_inputs["c"], PickledData)

0 comments on commit 3cb836d

Please sign in to comment.