diff --git a/aiida_flexpart/data/nc_data.py b/aiida_flexpart/data/nc_data.py
index 298b815..7ffa3ae 100644
--- a/aiida_flexpart/data/nc_data.py
+++ b/aiida_flexpart/data/nc_data.py
@@ -1,10 +1,12 @@
 import os
 from aiida.orm import RemoteData
-from netCDF4 import Dataset
 
 
-class NetCDFData(RemoteData):
-    def __init__(self, filepath=None, remote_path=None, **kwargs):
+class NetCdfData(RemoteData):
+
+    def __init__(
+        self, filepath=None, remote_path=None, g_att=None, nc_dimensions=None, **kwargs
+    ):
         """
         Data plugin for Netcdf files.
         """
@@ -13,32 +15,20 @@ def __init__(self, filepath=None, remote_path=None, **kwargs):
             filename = os.path.basename(filepath)
             self.set_remote_path(remote_path)
             self.set_filename(filename)
-
-            # open and read as NetCDF
-            nc_file = Dataset(filepath, mode="r")
-            self.set_global_attributes(nc_file)
+            self.set_global_attributes(g_att, nc_dimensions)
 
     def set_filename(self, val):
         self.base.attributes.set("filename", val)
 
-    def set_global_attributes(self, nc_file):
-
-        g_att = {}
-        for a in nc_file.ncattrs():
-            g_att[a] = repr(nc_file.getncattr(a))
+    def set_global_attributes(self, g_att, nc_dimensions):
         self.base.attributes.set("global_attributes", g_att)
-
-        nc_dimensions = {i: len(nc_file.dimensions[i]) for i in nc_file.dimensions}
         self.base.attributes.set("dimensions", nc_dimensions)
 
     def ncdump(self):
-        """
-        Small python version of ncdump.
-        """
+        """Small python version of ncdump."""
         print("dimensions:")
         for k, v in self.base.attributes.get("dimensions").items():
-            print("\t%s =" % k, v)
-
+            print(f"\t {k} = {v}")
         print("// global attributes:")
         for k, v in self.base.attributes.get("global_attributes").items():
-            print("\t:%s =" % k, v)
\ No newline at end of file
+            print(f"\t :{k} = {v}")
diff --git a/aiida_flexpart/workflows/inspect.py b/aiida_flexpart/workflows/inspect.py
new file mode 100644
index 0000000..c627dc7
--- /dev/null
+++ b/aiida_flexpart/workflows/inspect.py
@@ -0,0 +1,90 @@
+from aiida.engine import WorkChain, calcfunction
+from aiida.plugins import DataFactory
+from aiida import orm
+from pathlib import Path
+import tempfile
+from netCDF4 import Dataset
+
+NetCDF = DataFactory("netcdf.data")
+
+
+def check(nc_file, version):
+    """
+    Checks if there is a netcdf file stored with the same name,
+    if so, it checks the created date, if that is a match then returns
+    False.
+    """
+    qb = orm.QueryBuilder()
+    qb.append(
+        NetCDF,
+        project=[f"attributes.global_attributes.{version}"],
+        filters={"attributes.filename": nc_file.attributes["filename"]},
+    )
+    if qb.all():
+        for i in qb.all():
+            if i[0] == nc_file.attributes["global_attributes"][version]:
+                return False
+    return True
+
+
+def validate_history(nc_file):
+    return True if "history" in nc_file.attributes["global_attributes"].keys() else None
+
+
+@calcfunction
+def store(remote_dir, file):
+    with tempfile.TemporaryDirectory() as td:
+        remote_path = Path(remote_dir.get_remote_path()) / file.value
+        temp_path = Path(td) / file.value
+        remote_dir.getfile(remote_path, temp_path)
+
+        # fill global attributes and dimensions
+        with Dataset(str(temp_path), mode="r") as nc_file:
+            nc_dimensions = {
+                i: len(nc_file.dimensions[i]) for i in nc_file.dimensions
+            }
+            global_att = {a: repr(nc_file.getncattr(a)) for a in nc_file.ncattrs()}
+
+        node = NetCDF(
+            str(temp_path),
+            remote_path=str(remote_path),
+            computer=remote_dir.computer,
+            g_att=global_att,
+            nc_dimensions=nc_dimensions,
+        )
+
+    if validate_history(node) is None:
+        return
+    elif check(node, "history"):
+        return node
+
+
+class InspectWorkflow(WorkChain):
+    @classmethod
+    def define(cls, spec):
+        super().define(spec)
+        spec.input_namespace("remotes", valid_type=orm.RemoteData, required=False)
+        spec.input_namespace(
+            "remotes_cs", valid_type=orm.RemoteStashFolderData, required=False
+        )
+        spec.outputs.dynamic = True
+        spec.outline(
+            cls.fill_remote_data,
+            cls.inspect,
+        )
+
+    def fill_remote_data(self):
+        self.ctx.dict_remote_data = {}
+        if "remotes" in self.inputs:
+            self.ctx.dict_remote_data = self.inputs.remotes
+        else:
+            for k, v in self.inputs.remotes_cs.items():
+                self.ctx.dict_remote_data[k] = orm.RemoteData(
+                    remote_path=v.target_basepath, computer=v.computer
+                )
+
+    def inspect(self):
+        for _, i in self.ctx.dict_remote_data.items():
+            for file in i.listdir():
+                if file.endswith(".nc"):
+                    store(i, file)
diff --git a/pyproject.toml b/pyproject.toml
index 96a4f19..f755957 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -64,9 +64,10 @@ docs = [
 ]
 
 [project.entry-points."aiida.workflows"]
 "flexpart.multi_dates" = "aiida_flexpart.workflows.multi_dates_workflow:FlexpartMultipleDatesWorkflow"
+"inspect.workflow" = "aiida_flexpart.workflows.inspect:InspectWorkflow"
 
 [project.entry-points."aiida.data"]
-"netcdf.data" = "aiida_flexpart.data.nc_data:NetCDFData"
+"netcdf.data" = "aiida_flexpart.data.nc_data:NetCdfData"
 
 [tool.pylint.format]