Add Functionality to Apply Constraints to Predictions #92

Merged: 23 commits, Feb 7, 2025
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add support for multi-node training.
[\#103](https://github.com/mllam/neural-lam/pull/103) @simonkamuk @sadamov

- Add option to clamp output predictions using limits specified in the config file [\#92](https://github.com/mllam/neural-lam/pull/92) @SimonKamuk

### Fixed
- Only print on rank 0 to avoid duplicates of all print statements.
[\#103](https://github.com/mllam/neural-lam/pull/103) @simonkamuk @sadamov
23 changes: 17 additions & 6 deletions README.md
@@ -148,12 +148,23 @@ training:
weights:
u100m: 1.0
v100m: 1.0
```

For now the neural-lam config only defines two things: 1) the kind of data
store and the path to its config, and 2) the weighting of different features in
the loss function. If you don't define the state feature weighting it will default
to weighting all features equally.
t2m: 1.0
r2m: 1.0
output_clamping:
lower:
t2m: 0.0
r2m: 0
upper:
r2m: 1.0
```

For now the neural-lam config only defines a few things:

1. The kind of datastore and the path to its config
2. The weighting of different features in
the loss function. If you don't define the state feature weighting it will default to
weighting all features equally.
3. The valid numerical range for the output of each feature (see the sketch further down). The numerical range of all features defaults to $]-\infty, \infty[$.

(This example is taken from the `tests/datastore_examples/mdp` directory.)
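
For reference, the clamping added in `neural_lam/models/base_graph_model.py` (further down in this PR) is smooth rather than a hard cut-off. As a sketch, with the sharpness and center fixed at 1 and 0 as in the code, and writing $a$ and $b$ for the normalized lower and upper limits, the three clamping functions are

$$
f_{a,b}(x) = a + (b - a)\,\sigma(x), \qquad
f_{a}(x) = a + \mathrm{softplus}(x), \qquad
f^{b}(x) = b - \mathrm{softplus}(-x),
$$

so a feature with both limits stays strictly inside $]a, b[$, and a feature with a single limit stays on the bounded side of it.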

21 changes: 21 additions & 0 deletions neural_lam/config.py
@@ -68,6 +68,23 @@ class UniformFeatureWeighting:
pass


@dataclasses.dataclass
class OutputClamping:
"""
Configuration for clamping the output of the model.

Attributes
----------
lower : Dict[str, float]
The minimum value to clamp each output feature to.
upper : Dict[str, float]
The maximum value to clamp each output feature to.
"""

lower: Dict[str, float] = dataclasses.field(default_factory=dict)
upper: Dict[str, float] = dataclasses.field(default_factory=dict)


@dataclasses.dataclass
class TrainingConfig:
"""
@@ -86,6 +103,10 @@ class TrainingConfig:
ManualStateFeatureWeighting, UniformFeatureWeighting
] = dataclasses.field(default_factory=UniformFeatureWeighting)

output_clamping: OutputClamping = dataclasses.field(
default_factory=OutputClamping
)


@dataclasses.dataclass
class NeuralLAMConfig(dataclass_wizard.JSONWizard, dataclass_wizard.YAMLWizard):
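
As a minimal sketch of how the new `output_clamping` block maps onto this dataclass, the snippet below parses the YAML fragment from the README with PyYAML and constructs `OutputClamping` directly; the real neural-lam loading path goes through `dataclass_wizard`, so treat this as illustrative only.

```python
import dataclasses
from typing import Dict

import yaml  # PyYAML, used here only for illustration


@dataclasses.dataclass
class OutputClamping:
    """Local mirror of the dataclass added in neural_lam/config.py."""

    lower: Dict[str, float] = dataclasses.field(default_factory=dict)
    upper: Dict[str, float] = dataclasses.field(default_factory=dict)


config_yaml = """
training:
  output_clamping:
    lower:
      t2m: 0.0
      r2m: 0
    upper:
      r2m: 1.0
"""

raw = yaml.safe_load(config_yaml)
clamping = OutputClamping(**raw["training"]["output_clamping"])

print(clamping.lower)  # {'t2m': 0.0, 'r2m': 0}
print(clamping.upper)  # {'r2m': 1.0}
```

Features with only a lower limit (here `t2m`) get a one-sided softplus clamp in the model code below, while features with both limits (`r2m`) get the sigmoid clamp.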
192 changes: 190 additions & 2 deletions neural_lam/models/base_graph_model.py
@@ -79,6 +79,192 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore):
layer_norm=False,
) # No layer norm on this one

# Compute indices and define clamping functions
self.prepare_clamping_params(config, datastore)

def prepare_clamping_params(
self, config: NeuralLAMConfig, datastore: BaseDatastore
):
"""
Prepare parameters for clamping predicted values to valid range
"""

# Read configs
state_feature_names = datastore.get_vars_names(category="state")
lower_lims = config.training.output_clamping.lower
upper_lims = config.training.output_clamping.upper

# Check that limits in config are for valid features
unknown_features_lower = set(lower_lims.keys()) - set(
state_feature_names
)
unknown_features_upper = set(upper_lims.keys()) - set(
state_feature_names
)
if unknown_features_lower or unknown_features_upper:
raise ValueError(
"State feature limits were provided for unknown features: "
f"{unknown_features_lower.union(unknown_features_upper)}"
)

# Constant parameters for clamping
sigmoid_sharpness = 1
softplus_sharpness = 1
sigmoid_center = 0
softplus_center = 0

normalize_clamping_lim = (
lambda x, feature_idx: (x - self.state_mean[feature_idx])
/ self.state_std[feature_idx]
)

# Check which clamping functions to use for each feature
sigmoid_lower_upper_idx = []
sigmoid_lower_lims = []
sigmoid_upper_lims = []

softplus_lower_idx = []
softplus_lower_lims = []

softplus_upper_idx = []
softplus_upper_lims = []

for feature_idx, feature in enumerate(state_feature_names):
if feature in lower_lims and feature in upper_lims:
assert (
lower_lims[feature] < upper_lims[feature]
), f'Invalid clamping limits for feature "{feature}",\
lower: {lower_lims[feature]}, larger than\
upper: {upper_lims[feature]}'
sigmoid_lower_upper_idx.append(feature_idx)
sigmoid_lower_lims.append(
normalize_clamping_lim(lower_lims[feature], feature_idx)
)
sigmoid_upper_lims.append(
normalize_clamping_lim(upper_lims[feature], feature_idx)
)
elif feature in lower_lims and feature not in upper_lims:
softplus_lower_idx.append(feature_idx)
softplus_lower_lims.append(
normalize_clamping_lim(lower_lims[feature], feature_idx)
)
elif feature not in lower_lims and feature in upper_lims:
softplus_upper_idx.append(feature_idx)
softplus_upper_lims.append(
normalize_clamping_lim(upper_lims[feature], feature_idx)
)

self.register_buffer(
"sigmoid_lower_lims", torch.tensor(sigmoid_lower_lims)
)
self.register_buffer(
"sigmoid_upper_lims", torch.tensor(sigmoid_upper_lims)
)
self.register_buffer(
"softplus_lower_lims", torch.tensor(softplus_lower_lims)
)
self.register_buffer(
"softplus_upper_lims", torch.tensor(softplus_upper_lims)
)

self.register_buffer(
"clamp_lower_upper_idx", torch.tensor(sigmoid_lower_upper_idx)
)
self.register_buffer(
"clamp_lower_idx", torch.tensor(softplus_lower_idx)
)
self.register_buffer(
"clamp_upper_idx", torch.tensor(softplus_upper_idx)
)

# Define clamping functions
self.clamp_lower_upper = lambda x: (
self.sigmoid_lower_lims
+ (self.sigmoid_upper_lims - self.sigmoid_lower_lims)
* torch.sigmoid(sigmoid_sharpness * (x - sigmoid_center))
)
self.clamp_lower = lambda x: (
self.softplus_lower_lims
+ torch.nn.functional.softplus(
x - softplus_center, beta=softplus_sharpness
)
)
self.clamp_upper = lambda x: (
self.softplus_upper_lims
- torch.nn.functional.softplus(
softplus_center - x, beta=softplus_sharpness
)
)

self.inverse_clamp_lower_upper = lambda x: (
sigmoid_center
+ utils.inverse_sigmoid(
(x - self.sigmoid_lower_lims)
/ (self.sigmoid_upper_lims - self.sigmoid_lower_lims)
)
/ sigmoid_sharpness
)
self.inverse_clamp_lower = lambda x: (
utils.inverse_softplus(
x - self.softplus_lower_lims, beta=softplus_sharpness
)
+ softplus_center
)
self.inverse_clamp_upper = lambda x: (
-utils.inverse_softplus(
self.softplus_upper_lims - x, beta=softplus_sharpness
)
+ softplus_center
)

def get_clamped_new_state(self, state_delta, prev_state):
"""
Clamp prediction to valid range supplied in config
Returns the clamped new state after adding delta to original state

Instead of the new state being computed as
$X_{t+1} = X_t + \\delta = X_t + model(\\{X_t,X_{t-1},...\\}, forcing)$,
the clamped values will be
$f(f^{-1}(X_t) + model(\\{X_t, X_{t-1},...\\}, forcing))$,
which means the model will learn to output values in the range of the
inverse clamping function.

state_delta: (B, num_grid_nodes, feature_dim)
prev_state: (B, num_grid_nodes, feature_dim)
"""

# Assign new state, but overwrite clamped values of each type later
new_state = prev_state + state_delta

# Sigmoid/logistic clamps between ]a,b[
if self.clamp_lower_upper_idx.numel() > 0:
idx = self.clamp_lower_upper_idx

new_state[:, :, idx] = self.clamp_lower_upper(
self.inverse_clamp_lower_upper(prev_state[:, :, idx])
+ state_delta[:, :, idx]
)

# Softplus clamps between ]a,infty[
if self.clamp_lower_idx.numel() > 0:
idx = self.clamp_lower_idx

new_state[:, :, idx] = self.clamp_lower(
self.inverse_clamp_lower(prev_state[:, :, idx])
+ state_delta[:, :, idx]
)

# Softplus clamps between ]-infty,b[
if self.clamp_upper_idx.numel() > 0:
idx = self.clamp_upper_idx

new_state[:, :, idx] = self.clamp_upper(
self.inverse_clamp_upper(prev_state[:, :, idx])
+ state_delta[:, :, idx]
)

return new_state

def get_num_mesh(self):
"""
Compute number of mesh nodes from loaded features,
@@ -173,5 +359,7 @@ def predict_step(self, prev_state, prev_prev_state, forcing):
# Rescale with one-step difference statistics
rescaled_delta_mean = pred_delta_mean * self.diff_std + self.diff_mean

# Residual connection for full state
return prev_state + rescaled_delta_mean, pred_std
# Clamp values to valid range (also add the delta to the previous state)
new_state = self.get_clamped_new_state(rescaled_delta_mean, prev_state)

return new_state, pred_std
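
To make the clamping math above concrete, here is a small self-contained sketch of the two-sided (sigmoid) case, with hypothetical normalized limits a = 0 and b = 1 and the sharpness and center fixed at 1 and 0 as in the PR; it checks that a zero delta leaves the state (numerically) unchanged and that no delta can push the state outside $]a, b[$.

```python
import torch

a, b = 0.0, 1.0  # hypothetical normalized clamping limits


def clamp_lower_upper(x):
    # smoothly map any real value into ]a, b[
    return a + (b - a) * torch.sigmoid(x)


def inverse_clamp_lower_upper(x):
    # map a value in ]a, b[ back to the unconstrained space
    x = torch.clamp(x, min=a + 1e-6, max=b - 1e-6)
    p = (x - a) / (b - a)
    return torch.log(p / (1 - p))


prev_state = torch.tensor([0.2, 0.5, 0.9])
state_delta = torch.tensor([-5.0, 0.1, 5.0])  # even large deltas cannot escape ]a, b[

new_state = clamp_lower_upper(
    inverse_clamp_lower_upper(prev_state) + state_delta
)
assert torch.all((new_state > a) & (new_state < b))

# a zero delta round-trips to (numerically) the same state
roundtrip = clamp_lower_upper(inverse_clamp_lower_upper(prev_state))
assert torch.allclose(roundtrip, prev_state, atol=1e-5)
```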
33 changes: 33 additions & 0 deletions neural_lam/utils.py
@@ -307,3 +307,36 @@ def setup_training_logger(datastore, args, run_name):
)

return logger


def inverse_softplus(x, beta=1, threshold=20):
"""
Inverse of torch.nn.functional.softplus

For x*beta above threshold, returns linear function for numerical
stability.

Input is clamped to x > ln(1+1e-6)/beta, which approximately corresponds to
positive values of x.
Note that this torch.clamp_min will make gradients 0, but this is not a
problem, as values of x this close to 0 have gradients of 0 anyhow.
"""
non_linear_part = (
torch.log(torch.clamp_min(torch.expm1(x * beta), 1e-6)) / beta
)
x = torch.where(x * beta <= threshold, non_linear_part, x)

return x


def inverse_sigmoid(x):
"""
Inverse of torch.sigmoid

Sigmoid output takes values in [0,1]; this clamp makes sure the input lies
just within this interval.
Note that this torch.clamp will make gradients 0, but this is not a problem
as values of x that are this close to 0 or 1 have gradients of 0 anyhow.
"""
x_clamped = torch.clamp(x, min=1e-6, max=1 - 1e-6)
return torch.log(x_clamped / (1 - x_clamped))
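
A quick sanity check of the two inverses above (assuming neural-lam is installed with these additions so that they are importable from `neural_lam.utils`):

```python
import torch
import torch.nn.functional as F

from neural_lam.utils import inverse_sigmoid, inverse_softplus

x = torch.linspace(-3.0, 3.0, steps=7)

# softplus followed by its inverse recovers x (beta=1, below the linear threshold)
y = F.softplus(x, beta=1)
assert torch.allclose(inverse_softplus(y, beta=1), x, atol=1e-4)

# sigmoid followed by its inverse recovers x
z = torch.sigmoid(x)
assert torch.allclose(inverse_sigmoid(z), x, atol=1e-4)
```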
9 changes: 9 additions & 0 deletions tests/datastore_examples/mdp/danra_100m_winds/config.yaml
@@ -7,3 +7,12 @@ training:
weights:
u100m: 1.0
v100m: 1.0
t2m: 1.0
r2m: 1.0
output_clamping:
lower:
t2m: 0.0
r2m: 0
upper:
r2m: 1.0
u100m: 100.0
@@ -55,7 +55,7 @@ inputs:
dims: [x, y]
target_output_variable: state

danra_surface:
danra_surface_forcing:
path: https://object-store.os-api.cci1.ecmwf.int/mllam-testdata/danra_cropped/v0.2.0/single_levels.zarr
dims: [time, x, y]
variables:
@@ -73,6 +73,24 @@ inputs:
name_format: "{var_name}"
target_output_variable: forcing

danra_surface:
path: https://object-store.os-api.cci1.ecmwf.int/mllam-testdata/danra_cropped/v0.2.0/single_levels.zarr
dims: [time, x, y]
variables:
- r2m
- t2m
dim_mapping:
time:
method: rename
dim: time
grid_index:
method: stack
dims: [x, y]
state_feature:
method: stack_variables_by_var_name
name_format: "{var_name}"
target_output_variable: state

danra_static:
path: https://object-store.os-api.cci1.ecmwf.int/mllam-testdata/danra_cropped/v0.2.0/single_levels.zarr
dims: [x, y]