Skip to content

Commit

Permalink
using dict everywhere wip
Browse files Browse the repository at this point in the history
  • Loading branch information
floriankrb committed Oct 10, 2024
1 parent 63cd66d commit ebc214d
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 15 deletions.
52 changes: 45 additions & 7 deletions src/anemoi/models/data_indices/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@
# nor does it submit to any jurisdiction.
#

from collections import defaultdict
import torch


class BaseTensorIndex:
"""Indexing for variables in index as Tensor."""

def __init__(self, *, includes: list[str], excludes: list[str], name_to_index: dict[str, int]) -> None:
def __init__(
self, *, includes: list[str], excludes: list[str], name_to_index: dict[str, int]
) -> None:
"""Initialize indexing tensors from includes and excludes using name_to_index.
Parameters
Expand Down Expand Up @@ -85,30 +88,65 @@ def representer(dumper, data):
def _build_idx_from_excludes(self, excludes=None) -> "torch.Tensor[int]":
if excludes is None:
excludes = self.excludes
return torch.Tensor(sorted(i for name, i in self.name_to_index.items() if name not in excludes)).to(torch.int)
return self._build_idx_from_condition(lambda name: name not in excludes)

def _build_idx_from_includes(self, includes=None) -> "torch.Tensor[int]":
if includes is None:
includes = self.includes
return torch.Tensor(sorted(self.name_to_index[name] for name in includes)).to(torch.int)
return self._build_idx_from_condition(lambda name: name in includes)

def _build_idx_prognostic(self) -> "torch.Tensor[int]":
return self._build_idx_from_excludes(self.includes + self.excludes)

def _build_idx_from_condition(self, condition):
# refactor to use two different classes
typ = type((list(self.name_to_index.values()))[0])
print(self.name_to_index, typ)

func = {
int: self._build_idx_from_condition_todo_i,
tuple: self._build_idx_from_condition_todo_dict,
}[typ]
return func(condition)

def _build_idx_from_condition_todo_dict(self, condition):
idx = defaultdict(list)
for name, (i, j) in self.name_to_index.items():
assert isinstance(j, int), j
if condition(name):
idx[i].append(j)
return {k: torch.Tensor(sorted(v)).to(torch.int) for k, v in idx.items()}

def _build_idx_from_condition_todo_i(self, condition):
idx = []
for name, i in self.name_to_index.items():
assert isinstance(i, int), i
if condition(name):
idx.append(i)
return torch.Tensor(sorted(idx)).to(torch.int)


class InputTensorIndex(BaseTensorIndex):
    """Indexing for input variables.

    On the input side, the *included* variables are forcings and the
    *removed* variables are diagnostics.
    """

    def __init__(
        self, *, includes: list[str], excludes: list[str], name_to_index: dict[str, int]
    ) -> None:
        """Initialize input indexing.

        Parameters
        ----------
        includes : list[str]
            Variable names to include.
        excludes : list[str]
            Variable names to exclude.
        name_to_index : dict[str, int]
            Mapping from variable name to its index.
        """
        super().__init__(
            includes=includes, excludes=excludes, name_to_index=name_to_index
        )
        # Aliases onto the base-class index tensors: for inputs, the
        # "only" set is forcing and the "removed" set is diagnostic.
        self.forcing = self._only
        self.diagnostic = self._removed


class OutputTensorIndex(BaseTensorIndex):
    """Indexing for output variables.

    On the output side, the *removed* variables are forcings and the
    *included* variables are diagnostics — the mirror of InputTensorIndex.
    """

    def __init__(
        self, *, includes: list[str], excludes: list[str], name_to_index: dict[str, int]
    ) -> None:
        """Initialize output indexing.

        Parameters
        ----------
        includes : list[str]
            Variable names to include.
        excludes : list[str]
            Variable names to exclude.
        name_to_index : dict[str, int]
            Mapping from variable name to its index.
        """
        super().__init__(
            includes=includes, excludes=excludes, name_to_index=name_to_index
        )
        # Aliases onto the base-class index tensors: for outputs, the
        # "removed" set is forcing and the "only" set is diagnostic.
        self.forcing = self._removed
        self.diagnostic = self._only
19 changes: 11 additions & 8 deletions src/anemoi/models/preprocessing/normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(
name_to_index_training_input = self.data_indices.data.input.name_to_index
assert all(isinstance(_, str) for _ in name_to_index_training_input.keys()), name_to_index_training_input

print('✅',statistics)
print('✅,',statistics)
minimum = statistics["minimum"]
maximum = statistics["maximum"]
mean = statistics["mean"]
Expand Down Expand Up @@ -100,12 +100,14 @@ def __init__(
raise ValueError[f"Unknown normalisation method for {name}: {method}"]

# register buffer - this will ensure they get copied to the correct device(s)
_norm_mul = _norm_mul.flatten()
_norm_add = _norm_add.flatten()
self.register_buffer("_norm_mul", torch.from_numpy(_norm_mul), persistent=True)
self.register_buffer("_norm_add", torch.from_numpy(_norm_add), persistent=True)
self.register_buffer("_input_idx", data_indices.data.input.full, persistent=True)
self.register_buffer("_output_idx", self.data_indices.data.output.full, persistent=True)
_norm_mul.as_torch().register_buffer(name="_norm_mul", persistent=True, caller=self)
_norm_add.as_torch().register_buffer(name="_norm_add", persistent=True, caller=self)

# this should go in a class or in a method
for k, v in data_indices.data.input.full.items():
self.register_buffer(f"_input_idx__{k}", v, persistent=True)
for k, v in self.data_indices.data.output.full.items():
self.register_buffer(f"_output_idx__{k}", v, persistent=True)

def _validate_normalization_inputs(self, name_to_index_training_input: dict, minimum, maximum, mean, stdev):
assert len(self.methods) == sum(len(v) for v in self.method_config.values()), (
Expand All @@ -118,7 +120,8 @@ def _validate_normalization_inputs(self, name_to_index_training_input: dict, min
assert stdev.size == n, (stdev.size, n)

for name, (i,j) in name_to_index_training_input.items():
assert i < len(minimum.arrays), ((i,j), name, [v.size for v in minimum.arrays] ,'💬', name_to_index_training_input)
if isinstance(i, int):
assert i < len(minimum.arrays), ((i,j), name, [v.size for v in minimum.arrays] ,'💬', name_to_index_training_input)
assert j < minimum.arrays[i].size, ((i,j), name,minimum.arrays[i].size, '💬',name_to_index_training_input)

assert isinstance(self.methods, dict)
Expand Down

0 comments on commit ebc214d

Please sign in to comment.