Skip to content

Commit

Permalink
Make ahmanalysis ert script
Browse files Browse the repository at this point in the history
  • Loading branch information
yngve-sk committed Feb 11, 2025
1 parent 6ca4cd4 commit 1dc4558
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 70 deletions.
121 changes: 67 additions & 54 deletions src/semeio/workflows/ahm_analysis/ahmanalysis.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
import collections
import itertools
import logging
import os
import tempfile
import warnings
from collections.abc import Iterable
from copy import deepcopy
from dataclasses import asdict
from pathlib import Path
from typing import Any

import ert
import numpy as np
import pandas as pd
import polars
import polars as pl
from ert import LibresFacade
from ert import ErtScript, LibresFacade
from ert.analysis import ErtAnalysisError, SmootherSnapshot, smoother_update
from ert.config import ESSettings, Field, GenKwConfig, UpdateSettings
from ert.storage import open_storage
from ert.storage import Ensemble, Storage, open_storage
from scipy.stats import ks_2samp

from semeio._exceptions.exceptions import ValidationError
from semeio.communication import SemeioScript

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -102,29 +106,28 @@
"""


class AhmAnalysisJob(SemeioScript):
class AhmAnalysisJob(ErtScript):
    """Define ERT workflow to evaluate change of parameters for each
observation during history matching
"""

def run(
self,
alpha: str | None = None,
target_name="analysis_case",
prior_name=None,
group_by="data_key",
output_dir=None,
):
storage: Storage,
es_settings: ESSettings,
observation_settings: UpdateSettings,
random_seed: int,
reports_dir: str,
ensemble: Ensemble,
alpha: int | None = None,
target_name: str = "analysis_case",
prior_name: str | None = None,
group_by: str = "data_key",
) -> Any:
"""Perform analysis of parameters change per obs group
prior to posterior of ahm"""

if isinstance(alpha, str):
alpha = float(alpha)

if output_dir is not None:
self._reports_dir = output_dir

experiment = self.ensemble.experiment
experiment = ensemble.experiment

observations_and_responses_mapping = (
pl.concat(
Expand All @@ -149,14 +152,14 @@ def _replace(s: str) -> str:
}

prior_name, target_name = check_names(
self.ensemble.name,
ensemble.name,
prior_name,
target_name,
)

prior_ensemble = None
# Get the prior scalar parameter distributions
for experiment in self.storage.experiments:
for experiment in storage.experiments:
try:
prior_ensemble = experiment.get_ensemble_by_name(prior_name)
break
Expand All @@ -180,6 +183,9 @@ def _replace(s: str) -> str:
except KeyError as err:
raise ValidationError(f"Empty prior ensemble: {err}") from err

ahmanalysis_reports_dir = Path(reports_dir) / "AhmAnalysisJob"
os.makedirs(ahmanalysis_reports_dir, exist_ok=True)

# create dataframe with observations vectors (1 by 1 obs and also all_obs)
combinations = make_obs_groups(key_map)

Expand Down Expand Up @@ -235,11 +241,14 @@ def _replace(s: str) -> str:
ensemble_size=prior_ensemble.ensemble_size,
)
update_log = _run_ministep(
prior_ensemble,
target_ensemble,
obs_group,
field_parameters + scalar_parameters,
alpha,
prior_storage=prior_ensemble,
target_storage=target_ensemble,
obs_group=obs_group,
data_parameters=field_parameters + scalar_parameters,
observation_settings=observation_settings,
es_settings=es_settings,
random_seed=random_seed,
alpha=alpha,
)
# Get the active vs total observation info
df_update_log = make_update_log_df(update_log)
Expand All @@ -248,8 +257,8 @@ def _replace(s: str) -> str:
del updated_combinations[group_name]
continue
# Get the updated scalar parameter distributions
self.reporter.publish_csv(
group_name, target_ensemble.load_all_gen_kw_data()
target_ensemble.load_all_gen_kw_data().to_csv(
ahmanalysis_reports_dir / f"{group_name}.csv"
)

active_obs.at["ratio", group_name] = (
Expand Down Expand Up @@ -278,35 +287,11 @@ def _replace(s: str) -> str:
kolmogorov_smirnov_data.set_index("Parameters", inplace=True)

# save/export the Ks matrix, active_obs, misfitval and prior data
self.reporter.publish_csv("ks", kolmogorov_smirnov_data)
self.reporter.publish_csv("active_obs_info", active_obs)
self.reporter.publish_csv("misfit_obs_info", misfitval)
self.reporter.publish_csv("prior", prior_data)


def _run_ministep(
    prior_storage,
    target_storage,
    obs_group,
    data_parameters,
    alpha: float | None = None,
):
    """Perform one smoother update ("ministep") from the prior to the target.

    ``alpha``, when provided, is used as the outlier-detection threshold for
    the update settings; otherwise the ``UpdateSettings`` defaults apply.
    A fresh, unseeded RNG is created on every call, so repeated runs are
    not reproducible.
    """
    if alpha is None:
        update_settings = UpdateSettings()
    else:
        update_settings = UpdateSettings(alpha=alpha)

    return smoother_update(
        prior_storage=prior_storage,
        posterior_storage=target_storage,
        observations=obs_group,
        parameters=data_parameters,
        update_settings=update_settings,
        es_settings=ESSettings(),
        rng=np.random.default_rng(),
    )
kolmogorov_smirnov_data.to_csv(ahmanalysis_reports_dir / "ks.csv")
active_obs.to_csv(ahmanalysis_reports_dir / "active_obs_info.csv")
misfitval.to_csv(ahmanalysis_reports_dir / "misfit_obs_info.csv")
prior_data.to_csv(ahmanalysis_reports_dir / "prior.csv")


def make_update_log_df(update_log: SmootherSnapshot) -> pd.DataFrame:
Expand Down Expand Up @@ -345,6 +330,34 @@ def make_update_log_df(update_log: SmootherSnapshot) -> pd.DataFrame:
return updatelog


def _run_ministep(
    prior_storage: Ensemble,
    target_storage: Ensemble,
    obs_group: Iterable[str],
    data_parameters: Iterable[str],
    observation_settings: UpdateSettings,
    es_settings: ESSettings,
    random_seed: int,
    alpha: float | None = None,
) -> SmootherSnapshot:
    """Run a single smoother update ("ministep") from prior to target ensemble.

    Args:
        prior_storage: Ensemble holding the prior realizations.
        target_storage: Ensemble that receives the posterior realizations.
        obs_group: Observation keys to assimilate in this step.
        data_parameters: Parameter keys to update.
        observation_settings: Baseline observation/outlier settings.
        es_settings: Ensemble-smoother settings, forwarded unchanged.
        random_seed: Seed for the RNG so the update is reproducible.
        alpha: Optional outlier-detection threshold (e.g. ``0.1``); when
            given it overrides the ``alpha`` carried by
            ``observation_settings``. Annotated as ``float`` because callers
            pass fractional thresholds.

    Returns:
        The ``SmootherSnapshot`` produced by ``smoother_update``.
    """
    # Seeding makes repeated runs of the same ministep deterministic.
    rng = np.random.default_rng(random_seed)

    if alpha is not None:
        # Rebuild the settings with only the alpha field overridden,
        # keeping every other configured value intact.
        observation_settings = UpdateSettings(
            **{**asdict(observation_settings), "alpha": alpha}
        )

    return smoother_update(
        prior_storage=prior_storage,
        posterior_storage=target_storage,
        observations=obs_group,
        parameters=data_parameters,
        update_settings=observation_settings,
        es_settings=es_settings,
        rng=rng,
    )


def make_obs_groups(key_map):
"""Create a mapping of observation groups, the names will be:
data_key -> [obs_keys] and All_obs-{missing_obs} -> [obs_keys]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
,Parameters,FOPR,All_obs,All_obs-WOPR_OP1,All_obs-SNAKE_OIL_WPR_DIFF
0,SNAKE_OIL_PARAM:BPR_138_PERSISTENCE,0.3,0.3,0.4,0.3
1,SNAKE_OIL_PARAM:BPR_555_PERSISTENCE,0.2,0.3,0.2,0.3
2,SNAKE_OIL_PARAM:OP1_DIVERGENCE_SCALE,0.4,0.3,0.3,0.4
3,SNAKE_OIL_PARAM:OP1_OCTAVES,0.3,0.2,0.2,0.3
4,SNAKE_OIL_PARAM:OP1_OFFSET,0.2,0.3,0.2,0.2
5,SNAKE_OIL_PARAM:OP1_PERSISTENCE,0.2,0.2,0.1,0.2
6,SNAKE_OIL_PARAM:OP2_DIVERGENCE_SCALE,0.3,0.2,0.3,0.3
7,SNAKE_OIL_PARAM:OP2_OCTAVES,0.2,0.2,0.2,0.3
8,SNAKE_OIL_PARAM:OP2_OFFSET,0.4,0.5,0.4,0.4
9,SNAKE_OIL_PARAM:OP2_PERSISTENCE,0.3,0.3,0.2,0.1
0,SNAKE_OIL_PARAM:BPR_138_PERSISTENCE,0.3,0.3,0.3,0.3
1,SNAKE_OIL_PARAM:BPR_555_PERSISTENCE,0.2,0.2,0.2,0.2
2,SNAKE_OIL_PARAM:OP1_DIVERGENCE_SCALE,0.4,0.4,0.4,0.4
3,SNAKE_OIL_PARAM:OP1_OCTAVES,0.2,0.2,0.2,0.2
4,SNAKE_OIL_PARAM:OP1_OFFSET,0.3,0.3,0.3,0.3
5,SNAKE_OIL_PARAM:OP1_PERSISTENCE,0.2,0.2,0.2,0.2
6,SNAKE_OIL_PARAM:OP2_DIVERGENCE_SCALE,0.2,0.2,0.2,0.2
7,SNAKE_OIL_PARAM:OP2_OCTAVES,0.2,0.2,0.2,0.2
8,SNAKE_OIL_PARAM:OP2_OFFSET,0.5,0.5,0.5,0.5
9,SNAKE_OIL_PARAM:OP2_PERSISTENCE,0.2,0.2,0.2,0.2
11 changes: 7 additions & 4 deletions tests/workflows/ahm_analysis/test_ahm_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,13 @@ def test_make_update_log_df(snake_oil_facade, snapshot):
prior_ensemble=prior_ens,
)
log = _run_ministep(
prior_ens,
posterior_ens,
sorted(prior_ens.experiment.observation_keys),
sorted(prior_ens.experiment.parameter_configuration.keys()),
prior_storage=prior_ens,
target_storage=posterior_ens,
obs_group=sorted(prior_ens.experiment.observation_keys),
data_parameters=sorted(prior_ens.experiment.parameter_configuration.keys()),
observation_settings=snake_oil_facade.config.analysis_config.observation_settings,
es_settings=snake_oil_facade.config.analysis_config.es_module,
random_seed=snake_oil_facade.config.random_seed,
)
snapshot.assert_match(
ahmanalysis.make_update_log_df(log).round(4).to_csv(),
Expand Down
3 changes: 1 addition & 2 deletions tests/workflows/ahm_analysis/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def test_ahmanalysis_run_deactivated_obs(snake_oil_facade, snapshot, caplog):
We simulate a case where some of the observation groups are completely
disabled by outlier detection
"""
# Note: Unattainable run unless we can pass alpha to the workflow somehow
with (
open_storage(snake_oil_facade.enspath, "w") as storage,
caplog.at_level(logging.WARNING),
Expand All @@ -78,7 +77,7 @@ def test_ahmanalysis_run_deactivated_obs(snake_oil_facade, snapshot, caplog):
ahmanalysis.AhmAnalysisJob,
storage,
experiment.get_ensemble_by_name("default"),
0.1,
alpha=0.1,
)
assert "Analysis failed for" in caplog.text

Expand Down

0 comments on commit 1dc4558

Please sign in to comment.