Skip to content

Commit

Permalink
MAINT: Refactor caching (#124, #148, #153)
Browse files Browse the repository at this point in the history
Improve the approach in reverted commits fe23f23, 0166cd2:

- introduce `@Cached.method()` based on `@functools.lru_cache()`
- track mutable dependencies for cache invalidation
- fix cache clearing helper
- add behavioural spec in `tests/test_core/test_cache.py`

Adjustments in existing classes in this commit:

- define `__cache_state__()` for all subclasses of `Cached`
  (might require updates when adding further methods to the cache)
- define `@Cached.method(attrs=(...))` where appropriate
- remove obsolete class-specific caching mechanisms
- factor out cached worker: `Network.nsi_betweenness()`
- define helper: `Network.find_link_attribute()`
- define helper: `ClimateNetwork._weighted_metric()`

Adding this behaviour to classes without caching should be straightforward.
Classes that remain to be adjusted (still have a `clear_cache()` method
or subclass `Cached` without conforming to its protocol):

- `ClimateNetwork` and its subclasses:
    The recursive interaction between `__init__()`, `_similarity_measure`
    and `_regenerate_network()` across the class hierarchy is very
    stateful and does not fit into the regular OO dependency pattern
    assumed by `Cached`. A redesign of this logic would be advisable,
    but is left for future work.
- `RecurrenceNetwork` and its subclasses, and `Surrogate`:
    Left as an exercise for the reader.
  • Loading branch information
ntfrgl committed Feb 21, 2024
1 parent 3148940 commit 93ab748
Show file tree
Hide file tree
Showing 26 changed files with 1,144 additions and 1,164 deletions.
148 changes: 38 additions & 110 deletions src/pyunicorn/climate/climate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,16 @@
Provides classes for generating and analyzing complex climate networks.
"""

#
# Import essential packages
#
from typing import Tuple, Hashable

# Import NumPy for the array object and fast numerics
import numpy as np
from numpy import random

from ..core import Data
from ..core.cache import Cached


#
# Define class ClimateData
#
class ClimateData(Data):
class ClimateData(Data, Cached):

"""
Encapsulates spatio-temporal climate data.
Expand Down Expand Up @@ -73,6 +68,9 @@ def __init__(self, observable, grid, time_cycle, anomalies=False,
:arg dict window: Spatio-temporal window to select a view on the data.
:arg int silence_level: The inverse level of verbosity of the object.
"""
self._mut_window = 0
"""mutation count"""

Data.__init__(self, observable=observable, grid=grid,
observable_name=observable_name,
observable_long_name=observable_long_name,
Expand All @@ -83,42 +81,22 @@ def __init__(self, observable, grid, time_cycle, anomalies=False,
"""(number (int)) - The annual cycle length of the data
(units of samples)."""

# Set flags
self._flag_phase_mean = False
self._phase_mean = None

self.data_source = ""

# If data are anomalies skip automatic calculation of anomalies
if anomalies:
self._flag_anomaly = True
self._anomaly = observable
else:
self._flag_anomaly = False
self.anomalies = anomalies

def __cache_state__(self) -> Tuple[Hashable, ...]:
# The following attributes are assumed immutable:
# (_full_observable)
return (self._mut_window,)

def __str__(self):
"""
Returns a string representation.
"""
return 'ClimateData:\n' + Data.__str__(self)

def clear_cache(self):
"""
Clean up cache.
Is reversible, since all cached information can be recalculated from
basic data.
"""
Data.clear_cache(self)

if self._flag_phase_mean:
del self._phase_mean
self._flag_phase_mean = False

if self._flag_anomaly:
del self._anomaly
self._flag_anomaly = False

#
# Define alternative constructors
#
Expand Down Expand Up @@ -304,7 +282,8 @@ def indices_selected_months(self, selected_months):
raise NotImplementedError("Currently only time cycles 12 and 360 \
are supported")

def _calculate_phase_mean(self):
@Cached.method(name="climatological mean values")
def phase_mean(self):
"""
Calculate mean values of observable for each phase of the annual cycle.
Expand All @@ -318,35 +297,6 @@ def _calculate_phase_mean(self):
:rtype: 2D Numpy array [cycle index, node index]
:return: the mean values of observable for each phase of the annual
cycle.
"""
if self.silence_level <= 1:
print("Calculating climatological mean values...")

# Get raw data
observable = self.observable()
# Get time cycle
time_cycle = self.time_cycle

# Get number of time series
N = observable.shape[1]

# Initialize
phase_mean = np.zeros((time_cycle, N))

# Calculate mean value for each day (month) on each node
for i in range(time_cycle):
phase_mean[i, :] = observable[i::time_cycle, :].mean(axis=0)

return phase_mean

def phase_mean(self):
"""
Return mean values of observable for each phase of the annual cycle.
For further comments, see :meth:`_calculate_phase_mean`.
.. note::
Only the currently selected spatio-temporal window is considered.
**Example:**
Expand All @@ -356,18 +306,19 @@ def phase_mean(self):
[ 0.6984, 0.1106, -0.6984, -0.1106, 0.6984, 0.1106],
[ 0.6984, -0.1106, -0.6984, 0.1106, 0.6984, -0.1106],
[ 0.63 , -0.321 , -0.63 , 0.321 , 0.63 , -0.321 ]])
:rtype: 2D Numpy array [cycle index, node index]
:return: the mean values of observable for each phase of the annual
cycle.
"""
if not self._flag_phase_mean:
self._phase_mean = self._calculate_phase_mean()
self._flag_phase_mean = True
observable = self.observable()
time_cycle = self.time_cycle
N = observable.shape[1]
phase_mean = np.zeros((time_cycle, N))

return self._phase_mean
# Calculate mean value for each day (month) on each node
for i in range(time_cycle):
phase_mean[i, :] = observable[i::time_cycle, :].mean(axis=0)
return phase_mean

def _calculate_anomaly(self):
@Cached.method(name="daily (monthly) anomaly values")
def anomaly(self):
"""
Calculate anomaly time series from observable.
Expand All @@ -380,53 +331,32 @@ def _calculate_anomaly(self):
:rtype: 2D Numpy array [time, node index]
:return: the anomalized time series.
**Example:**
>>> r(ClimateData.SmallTestData().anomaly()[:,0])
array([-0.5 , -0.321 , -0.1106, 0.1106, 0.321 ,
0.5 , 0.321 , 0.1106, -0.1106, -0.321 ])
"""
if self.silence_level <= 1:
print("Calculating daily (monthly) anomaly values...")
# If data are anomalies skip automatic calculation of anomalies
if self.anomalies:
return self._full_observable

# Get raw data
observable = self.observable()
# Get time cycle
time_cycle = self.time_cycle
# Initialize array
anomaly = np.zeros(observable.shape)

# Thanks to Jakob Runge
for i in range(time_cycle):
sample = observable[i::time_cycle, :]
anomaly[i::time_cycle, :] = sample - sample.mean(axis=0)

return anomaly

def anomaly(self):
"""
Return anomaly time series from observable.
For further comments, see :meth:`_calculate_anomaly`.
.. note::
Only the currently selected spatio-temporal window is considered.
**Example:**
>>> r(ClimateData.SmallTestData().anomaly()[:,0])
array([-0.5 , -0.321 , -0.1106, 0.1106, 0.321 ,
0.5 , 0.321 , 0.1106, -0.1106, -0.321 ])
:rtype: 2D Numpy array [time, node index]
:return: the anomalized time series.
"""
if not self._flag_anomaly:
self._anomaly = self._calculate_anomaly()
self._flag_anomaly = True

return self._anomaly

def anomaly_selected_months(self, selected_months):
"""
Return anomaly time series from observable for selected months.
For further comments, see :meth:`_calculate_anomaly`.
For further comments, see :meth:`anomaly`.
.. note::
Only the currently selected spatio-temporal window is considered.
Expand Down Expand Up @@ -507,9 +437,8 @@ def set_window(self, window):
:arg window: The spatio-temporal window to select a view on the data.
"""
Data.set_window(self, window)

self._flag_phase_mean = False
self._flag_anomaly = False
# invalidate cache
self._mut_window += 1

def set_global_window(self):
"""
Expand All @@ -532,6 +461,5 @@ def set_global_window(self):
array([ 0., 5., 10., 15., 20., 25.], dtype=float32)
"""
Data.set_global_window(self)

self._flag_phase_mean = False
self._flag_anomaly = False
# invalidate cache
self._mut_window += 1
68 changes: 29 additions & 39 deletions src/pyunicorn/climate/climate_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,15 @@
Provides classes for generating and analyzing complex climate networks.
"""

#
# Import essential packages
#
from typing import Tuple, Hashable, Callable

# Import NumPy for the array object and fast numerics
import numpy as np

# Import iGraph for high performance graph theory tools written in pure ANSI-C
import igraph

# Import GeoNetwork and GeoGrid classes
from ..core.cache import Cached
from ..core import GeoNetwork, GeoGrid
from ..core.network import cached_const


#
# Define class ClimateNetwork
#

class ClimateNetwork(GeoNetwork):

"""
Expand All @@ -50,9 +40,9 @@ class ClimateNetwork(GeoNetwork):
# Definitions of internal methods
#

def __init__(self, grid, similarity_measure, threshold=None,
link_density=None, non_local=False, directed=False,
node_weight_type="surface", silence_level=0):
def __init__(self, grid: GeoGrid, similarity_measure: np.ndarray,
threshold=None, link_density=None, non_local=False,
directed=False, node_weight_type="surface", silence_level=0):
"""
Initialize an instance of :class:`ClimateNetwork`.
Expand Down Expand Up @@ -81,10 +71,17 @@ def __init__(self, grid, similarity_measure, threshold=None,
:arg int silence_level: The inverse level of verbosity of the object.
"""
# Initialize
self.grid = grid
assert isinstance(grid, GeoGrid)
self.grid: GeoGrid = grid
self.directed = directed
self.silence_level = silence_level

# mutation count
if not hasattr(self, "_mut_clim"):
self._mut_clim: int = 0
else:
self._mut_clim += 1

# FIXME: Is taking the absolute value by default OK?
self._similarity_measure = np.abs(similarity_measure.astype("float32"))
self._non_local = non_local
Expand All @@ -105,6 +102,12 @@ def __init__(self, grid, similarity_measure, threshold=None,
node_weight_type=self.node_weight_type,
silence_level=self.silence_level)

def __cache_state__(self) -> Tuple[Hashable, ...]:
return GeoNetwork.__cache_state__(self) + (self._mut_clim,)

def __rec_cache_state__(self) -> Tuple[object, ...]:
return (self.grid,)

def __str__(self):
"""
Return a string representation of the ClimateNetwork object.
Expand All @@ -126,26 +129,9 @@ def __str__(self):
f'Threshold: {self.threshold()}\n' +
f'Local connections filtered out: {self.non_local()}')

def clear_cache(self, irreversible=False):
"""
Clean up cache.
If irreversible=True, the network cannot be recalculated using a
different threshold, or link density.
:arg bool irreversible: The irreversibility of clearing the cache.
"""
GeoNetwork.clear_cache(self)

if irreversible:
try:
del self._similarity_measure
except AttributeError:
pass

def _regenerate_network(self):
"""
Regenerate the current climate network according to new similarity
Regenerate the current climate network according to a new similarity
measure.
"""
ClimateNetwork.__init__(self, grid=self.data.grid,
Expand Down Expand Up @@ -289,9 +275,8 @@ def Load(filename, fileformat=None, silence_level=0, *args, **kwds):
# Overwrite igraph Graph object in Network instance to restore link
# attributes/weights
net.graph = graph
# Restore link attributes/weights
net.clear_paths_cache()

# invalidate cache
net._mut_la += 1
return net

#
Expand Down Expand Up @@ -626,7 +611,7 @@ def set_link_density(self, link_density):
threshold = self.threshold_from_link_density(link_density)
self.set_threshold(threshold)

@cached_const('base', 'correlation_distance')
@Cached.method()
def correlation_distance(self):
"""
Return correlation weighted distances between nodes.
Expand Down Expand Up @@ -658,7 +643,7 @@ def correlation_distance(self):
"""
return self.similarity_measure() * self.grid.angular_distance()

@cached_const('base', 'inv_correlation_distance')
@Cached.method()
def inv_correlation_distance(self):
"""
Return correlation weighted distances between nodes.
Expand Down Expand Up @@ -719,3 +704,8 @@ def local_correlation_distance_weighted_vulnerability(self):
"""
self.inv_correlation_distance()
return self.local_vulnerability('inv_correlation_distance')

def _weighted_metric(self, attr: str, calc: Callable, metric: str):
if not self.find_link_attribute(attr):
self.set_link_attribute(attr, calc())
return getattr(self, metric)(attr)
Loading

0 comments on commit 93ab748

Please sign in to comment.