diff --git a/src/anemoi/transform/sources/mars.py b/src/anemoi/transform/sources/mars.py index 06fc48f..41ea5e0 100644 --- a/src/anemoi/transform/sources/mars.py +++ b/src/anemoi/transform/sources/mars.py @@ -36,6 +36,3 @@ def forward(self, data): return this.forward(self.data) return Input(data) - - -source_registry.register("mars", Mars) diff --git a/src/anemoi/transform/variables/__init__.py b/src/anemoi/transform/variables/__init__.py index 830d6ae..a0e98e1 100644 --- a/src/anemoi/transform/variables/__init__.py +++ b/src/anemoi/transform/variables/__init__.py @@ -83,3 +83,10 @@ def is_computed_forcing(self): @property def is_from_input(self): pass + + def similarity(self, other): + """Compute the similarity between two variables. This is used when + encoding a variables in GRIB and we do not have a template for it. + We can then try to find the most similar variable for which we have a template. + """ + return 0 diff --git a/src/anemoi/transform/variables/variables.py b/src/anemoi/transform/variables/variables.py index cb74ee9..aa64b08 100644 --- a/src/anemoi/transform/variables/variables.py +++ b/src/anemoi/transform/variables/variables.py @@ -51,6 +51,21 @@ def is_instantanous(self): def grib_keys(self): return self.data.get("mars", {}).copy() + def similarity(self, other): + if not isinstance(other, VariableFromMarsVocabulary): + return 0 + + def __similarity(a, b): + if isinstance(a, dict) and isinstance(b, dict): + return sum(__similarity(a[k], b[k]) for k in set(a.keys()) & set(b.keys())) + + if isinstance(a, list) and isinstance(b, list): + return sum(__similarity(a[i], b[i]) for i in range(min(len(a), len(b)))) + + return 1 if a == b else 0 + + return __similarity(self.data, other.data) + class VariableFromDict(VariableFromMarsVocabulary): """A variable that is defined by a user provided dictionary.""" diff --git a/src/anemoi/transform/workflows/pipeline.py b/src/anemoi/transform/workflows/pipeline.py index 73e29bb..86413f4 100644 --- a/src/anemoi/transform/workflows/pipeline.py +++ b/src/anemoi/transform/workflows/pipeline.py @@ -12,6 +12,7 @@ from . import workflow_registry +@workflow_registry.register("pipeline") class Pipeline(Workflow): """A simple pipeline of filters""" @@ -27,6 +28,3 @@ def backward(self, data): for filter in reversed(self.filters): data = filter.backward(data) return data - - -workflow_registry.register("pipeline", Pipeline)