Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move the basic data type from workgraph to there #2

Merged
merged 1 commit into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ Source = "https://github.com/aiidateam/aiida-pythonjob"
[project.entry-points."aiida.data"]
"pythonjob.pickled_data" = "aiida_pythonjob.data.pickled_data:PickledData"
"pythonjob.pickled_function" = "aiida_pythonjob.data.pickled_function:PickledFunction"
"pythonjob.ase.atoms.Atoms" = "aiida_pythonjob.data.atoms:AtomsData"
"pythonjob.builtins.int" = "aiida.orm.nodes.data.int:Int"
"pythonjob.builtins.float" = "aiida.orm.nodes.data.float:Float"
"pythonjob.builtins.str" = "aiida.orm.nodes.data.str:Str"
"pythonjob.builtins.bool" = "aiida.orm.nodes.data.bool:Bool"
"pythonjob.builtins.list"="aiida_pythonjob.data.data_with_value:List"
"pythonjob.builtins.dict"="aiida_pythonjob.data.data_with_value:Dict"


[project.entry-points."aiida.calculations"]
"pythonjob.pythonjob" = "aiida_pythonjob.calculations.pythonjob:PythonJob"
Expand Down
3 changes: 2 additions & 1 deletion src/aiida_pythonjob/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .pickled_data import PickledData
from .pickled_function import PickledFunction
from .serializer import general_serializer, serialize_to_aiida_nodes

__all__ = ("PickledData", "PickledFunction")
__all__ = ("PickledData", "PickledFunction", "serialize_to_aiida_nodes", "general_serializer")
53 changes: 53 additions & 0 deletions src/aiida_pythonjob/data/atoms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import numpy as np
from aiida.orm import Data
from ase import Atoms
from ase.db.row import atoms2dict

__all__ = ("AtomsData",)


class AtomsData(Data):
"""Data to represent a ASE Atoms."""

_cached_atoms = None

def __init__(self, value=None, **kwargs):
"""Initialise a `AtomsData` node instance.

:param value: ASE Atoms instance to initialise the `AtomsData` node from
"""
atoms = value or Atoms()
super().__init__(**kwargs)
data, keys = self.atoms2dict(atoms)
self.base.attributes.set_many(data)
self.base.attributes.set("keys", keys)

@classmethod
def atoms2dict(cls, atoms):
data = atoms2dict(atoms)
data.pop("unique_id")
keys = list(data.keys())
formula = atoms.get_chemical_formula()
data = cls._convert_numpy_to_native(data)
data["formula"] = formula
data["symbols"] = atoms.get_chemical_symbols()
return data, keys

@classmethod
def _convert_numpy_to_native(cls, data):
"""Convert numpy types to Python native types for JSON compatibility."""
for key, value in data.items():
if isinstance(value, np.bool_):
data[key] = bool(value)
elif isinstance(value, np.ndarray):
data[key] = value.tolist()
elif isinstance(value, np.generic):
data[key] = value.item()
return data

@property
def value(self):
keys = self.base.attributes.get("keys")
data = self.base.attributes.get_many(keys)
data = dict(zip(keys, data))
return Atoms(**data)
13 changes: 13 additions & 0 deletions src/aiida_pythonjob/data/data_with_value.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from aiida import orm


class Dict(orm.Dict):
@property
def value(self):
return self.get_dict()


class List(orm.List):
@property
def value(self):
return self.get_list()
12 changes: 0 additions & 12 deletions src/aiida_pythonjob/data/pickled_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,6 @@
from aiida import orm


class Dict(orm.Dict):
@property
def value(self):
return self.get_dict()


class List(orm.List):
@property
def value(self):
return self.get_list()


class PickledData(orm.Data):
"""Data to represent a pickled value using cloudpickle."""

Expand Down
13 changes: 13 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import aiida
from aiida_pythonjob import PickledFunction


Expand All @@ -21,3 +22,15 @@ def generate_structures(
"builtins": {"list", "float"},
"numpy": {"array"},
}


def test_python_job():
"""Test a simple python node."""
from aiida_pythonjob.data.pickled_data import PickledData
from aiida_pythonjob.data.serializer import serialize_to_aiida_nodes

inputs = {"a": 1, "b": 2.0, "c": set()}
new_inputs = serialize_to_aiida_nodes(inputs)
assert isinstance(new_inputs["a"], aiida.orm.Int)
assert isinstance(new_inputs["b"], aiida.orm.Float)
assert isinstance(new_inputs["c"], PickledData)
Loading