Skip to content

Commit

Permalink
work on processors
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Feb 7, 2025
1 parent 184de69 commit 0277476
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 2 deletions.
47 changes: 45 additions & 2 deletions src/anemoi/transform/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ def __getattr__(self, name):
return getattr(self._field, name)

def __repr__(self) -> str:
return repr(self._field)
return f"{self.__class__.__name__ }({repr(self._field)})"

def clone(self, **kwargs):
return NewClonedField(self, **kwargs)


class NewDataField(WrappedField):
Expand Down Expand Up @@ -159,6 +162,20 @@ def __init__(self, field, **kwargs):

def metadata(self, *args, **kwargs):

this = self

if len(args) == 0 and len(kwargs) == 0:

class MD:

def get(self, key, default=None):
if key in this._metadata:
return this._metadata[key]

return this._field.metadata().get(key, default)

return MD()

if kwargs.get("namespace"):
assert len(args) == 0, (args, kwargs)
mars = self._field.metadata(**kwargs).copy()
Expand All @@ -168,10 +185,16 @@ def metadata(self, *args, **kwargs):
return mars

if len(args) == 1 and args[0] in self._metadata:
return self._metadata[args[0]]
value = self._metadata[args[0]]
if callable(value):
return value(self, args[0], self._field.metadata())
return value

return self._field.metadata(*args, **kwargs)

def __repr__(self) -> str:
return f"{self.__class__.__name__ }({repr(self._field)},{self._metadata})"


class NewValidDateTimeField(NewMetadataField):
"""Change the valid_datetime of a field."""
Expand All @@ -186,6 +209,26 @@ def __init__(self, field, valid_datetime):
super().__init__(field, date=date, time=time, step=0, valid_datetime=valid_datetime.isoformat())


class NewClonedField(WrappedField):
"""Wrapper around a field object that clones the field."""

def __init__(self, field, **metadata):
super().__init__(field)
self._metadata = metadata

def metadata(self, *args, **kwargs):
if len(args) == 1:
if args[0] in self._metadata:
if callable(self._metadata[args[0]]):
proc = self._metadata[args[0]]
self._metadata[args[0]] = proc(self._field, args[0], self._field.metadata())

if args[0] in self._metadata:
return self._metadata[args[0]]

return self._field.metadata(*args, **kwargs)


def new_field_from_numpy(array, *, template=None, **metadata):
return NewMetadataField(NewDataField(template, array), **metadata)

Expand Down
15 changes: 15 additions & 0 deletions src/anemoi/transform/filters/cos_sin_mean_wave_direction.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,20 @@ def backward_processor(self, state):

return state

def patch_data_request(self, data_request):
"""We have a chance to modify the data request here."""

param = data_request.get("param")
if param is None:
return data_request

if self.cos_mean_wave_direction in param or self.sin_mean_wave_direction in param:
data_request["param"] = [
p for p in param if p not in (self.cos_mean_wave_direction, self.sin_mean_wave_direction)
]
data_request["param"].append(self.mean_wave_direction)

return data_request


filter_registry.register("mean_wave_direction", CosSinWaveDirection.reversed)
6 changes: 6 additions & 0 deletions src/anemoi/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def forward_processor(self, state):
def backward_processor(self, state):
raise NotImplementedError("Not implemented")

def patch_data_request(self, data_request):
return data_request


class ReversedTransform(Transform):
"""Swap the forward and backward methods of a filter."""
Expand All @@ -67,3 +70,6 @@ def forward_processor(self, state):

def backward_processor(self, state):
return self.filter.forward_processor(state)

def patch_data_request(self, data_request):
return self.filter.patch_data_request(data_request)

0 comments on commit 0277476

Please sign in to comment.