Alternate cached array implementation #219

Open · wants to merge 1 commit into master
159 changes: 42 additions & 117 deletions daskms/optimisation.py
@@ -1,113 +1,44 @@
 # -*- coding: utf-8 -*-
 
 from threading import Lock
-import uuid
-from weakref import WeakValueDictionary, WeakKeyDictionary
+from itertools import product
 
+import dask
 import dask.array as da
-from dask.core import flatten, _execute_task
+from dask.core import flatten, get
 from dask.highlevelgraph import HighLevelGraph
 from dask.optimization import cull, inline
+import numpy as np
 
 
-_key_cache = WeakValueDictionary()
-_key_cache_lock = Lock()
-
-
-class KeyMetaClass(type):
-    """
-    Ensures that Key identities are the same,
-    given the same constructor arguments
-    """
-    def __call__(cls, key):
-        try:
-            return _key_cache[key]
-        except KeyError:
-            pass
-
-        with _key_cache_lock:
-            try:
-                return _key_cache[key]
-            except KeyError:
-                _key_cache[key] = instance = type.__call__(cls, key)
-                return instance
-
-
-class Key(metaclass=KeyMetaClass):
-    """
-    Suitable for storing a tuple
-    (or other dask key type) in a WeakKeyDictionary.
-    Uniques of key identity guaranteed by KeyMetaClass
-    """
-    __slots__ = ("key", "__weakref__")
-
-    def __init__(self, key):
-        self.key = key
-
-    def __hash__(self):
-        return hash(self.key)
-
-    def __repr__(self):
-        return f"Key{self.key}"
-
-    def __reduce__(self):
-        return (Key, (self.key,))
-
-    __str__ = __repr__
-
-
-def cache_entry(cache, key, task):
-    with cache.lock:
-        try:
-            return cache.cache[key]
-        except KeyError:
-            cache.cache[key] = value = _execute_task(task, {})
-            return value
-
-
-_array_cache_cache = WeakValueDictionary()
-_array_cache_lock = Lock()
-
-
-class ArrayCacheMetaClass(type):
-    """
-    Ensures that Array Cache identities are the same,
-    given the same constructor arguments
-    """
-    def __call__(cls, token):
-        key = (cls, token)
-
-        try:
-            return _array_cache_cache[key]
-        except KeyError:
-            pass
-
-        with _array_cache_lock:
-            try:
-                return _array_cache_cache[key]
-            except KeyError:
-                instance = type.__call__(cls, token)
-                _array_cache_cache[key] = instance
-                return instance
-
-
-class ArrayCache(metaclass=ArrayCacheMetaClass):
-    """
-    Thread-safe array data cache. token makes this picklable.
-
-    Cached on a WeakKeyDictionary with ``Key`` objects.
-    """
-
-    def __init__(self, token):
-        self.token = token
-        self.cache = WeakKeyDictionary()
-        self.lock = Lock()
-
-    def __reduce__(self):
-        return (ArrayCache, (self.token,))
-
-    def __repr__(self):
-        return f"ArrayCache[{self.token}]"
+class GraphCache:
+    def __init__(self, collection):
+        self.collection = collection
+        self.lock = Lock()
+        self.cache = {}
+
+    def __reduce__(self):
+        return (GraphCache, (self.collection,))
+
+    def __call__(self, block_id):
+        try:
+            dsk = self.dsk
+        except AttributeError:
+            with self.lock:
+                try:
+                    dsk = self.dsk
+                except AttributeError:
+                    self.dsk = dsk = dict(self.collection.__dask_graph__())
+
+        key = (self.collection.name,) + block_id
+
+        with self.lock:
+            try:
+                return self.cache[key]
+            except KeyError:
+                return get(dsk, key, self.cache)
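The new GraphCache resolves its collection's graph lazily, with double-checked locking around both the graph materialisation and the block cache, and keys each block on (collection.name,) + block_id. Below is a minimal sketch of driving it by hand; it assumes this branch is importable, and note that dask.core.get is the simple reference scheduler, so the first miss executes the whole graph and memoises every key it computes.

import dask.array as da

from daskms.optimisation import GraphCache

A = da.arange(6, chunks=2) + 1          # three blocks of two elements
cache = GraphCache(A)

b0 = cache((0,))        # miss: dask.core.get runs the graph, filling cache.cache
b0_again = cache((0,))  # hit: a plain dict lookup under the lock
assert b0 is b0_again   # the memoised object itself is returned
print(b0)               # [1 2]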


@@ -123,33 +54,27 @@ def cached_array(array, token=None):
     ----------
     array : :class:`dask.array.Array`
         dask array to cache.
-    token : optional, str
-        A unique token for identifying the internal cache.
-        If None, it will be automatically generated.
     """
-    dsk = dict(array.__dask_graph__())
-    keys = set(flatten(array.__dask_keys__()))
-
-    if token is None:
-        token = uuid.uuid4().hex
-
-    # Inline + cull everything except the current array
-    inline_keys = set(dsk.keys() - keys)
-    dsk2 = inline(dsk, inline_keys, inline_constants=True)
-    dsk3, _ = cull(dsk2, keys)
-
-    # Create a cache used to store array values
-    cache = ArrayCache(token)
-
-    assert len(dsk3) == len(keys)
-
-    for k in keys:
-        dsk3[k] = (cache_entry, cache, Key(k), dsk3.pop(k))
-
-    graph = HighLevelGraph.from_collections(array.name, dsk3, [])
-
-    return da.Array(graph, array.name, array.chunks, array.dtype)
+    assert isinstance(array, da.Array)
+    name = f"block-id-{array.name}"
+    dsk = {(name,) + block_id: block_id
+           for block_id in product(*(range(len(c)) for c in array.chunks))}
+    assert all(all(isinstance(e, int) for e in bid) for bid in dsk.values())
+    block_id_array = da.Array(dsk, name,
+                              chunks=tuple((1,)*len(c) for c in array.chunks),
+                              dtype=np.object_)
+
+    assert array.ndim == block_id_array.ndim
+    idx = list(range(array.ndim))
+    adjust_chunks = dict(zip(idx, array.chunks))
+    cache = GraphCache(array)
+    token = f"GraphCache-{dask.base.tokenize(cache, block_id_array)}"
+
+    return da.blockwise(cache, idx,
+                        block_id_array, idx,
+                        adjust_chunks=adjust_chunks,
+                        meta=array._meta,
+                        name=token)
 
 
 def inlined_array(a, inline_arrays=None):
     """ Flatten underlying graph """
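The subtle part of the new cached_array is the block-id array: its graph maps each key straight to a chunk-index tuple, so da.blockwise hands the GraphCache callable a plain tuple such as (0, 1) rather than an ndarray block. The following standalone sketch shows the same pattern with a hypothetical describe_block standing in for the cache; no daskms is required.

from itertools import product

import dask.array as da
import numpy as np

source = da.zeros((4, 6), chunks=(2, 3))    # a 2 x 2 grid of blocks

# One graph entry per block: the "data" stored under each key is the
# chunk-index tuple itself, e.g. (0, 1), not a numpy array
name = f"block-id-{source.name}"
layer = {(name,) + bid: bid
         for bid in product(*(range(len(c)) for c in source.chunks))}
block_ids = da.Array(layer, name,
                     chunks=tuple((1,) * len(c) for c in source.chunks),
                     dtype=np.object_)


def describe_block(block_id):
    # Receives the raw tuple from the graph; the real GraphCache would
    # look up (source.name,) + block_id and memoise the result instead
    return np.full((2, 3), block_id[0] * 2 + block_id[1], dtype=np.float64)


idx = list(range(source.ndim))
result = da.blockwise(describe_block, idx,
                      block_ids, idx,
                      adjust_chunks=dict(zip(idx, source.chunks)),
                      meta=source._meta)

print(result.compute())   # each 2 x 3 tile is constant: its flat block index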
58 changes: 15 additions & 43 deletions daskms/tests/test_optimisation.py
@@ -11,22 +11,9 @@
 
 from daskms import xds_from_ms
 from daskms.optimisation import (inlined_array,
-                                 cached_array,
-                                 ArrayCache,
-                                 Key,
-                                 _key_cache,
-                                 _array_cache_cache)
+                                 cached_array)
 
 
-def test_optimisation_identity():
-    # Test identity
-    assert Key((0, 1, 2)) is Key((0, 1, 2))
-    assert ArrayCache(1) is ArrayCache(1)
-
-    # Test pickling
-    assert pickle.loads(pickle.dumps(Key((0, 1, 2)))) is Key((0, 1, 2))
-    assert pickle.loads(pickle.dumps(ArrayCache(1))) is ArrayCache(1)
-
-
 def test_inlined_array():
     A = da.ones((10, 10), chunks=(2, 2), dtype=np.float64)
@@ -86,45 +73,30 @@ def test_inlined_array():
     assert_array_equal(D, E)
 
 
+def test_blah_cached():
+    A = da.arange(5, chunks=1)
+
+    def f(a):
+        print(f"Got {a}")
+        return a
+
+    B = da.blockwise(f, 'a', A, 'a', meta=A._meta)
+    B = cached_array(inlined_array(B))
+
+    C = da.ones(20, chunks=5)[None, :] * B[:, None]
+    C.compute()
+
+
 def test_cached_array(ms):
     ds = xds_from_ms(ms, group_cols=[], chunks={'row': 1, 'chan': 4})[0]
 
     data = ds.DATA.data
     cached_data = cached_array(data)
     assert_array_almost_equal(cached_data, data)
 
-    # 2 x row blocks + row x chan x corr blocks
-    assert len(_key_cache) == data.numblocks[0] * 2 + data.npartitions
-    # rows, row runs and data array cache's
-    assert len(_array_cache_cache) == 3
-
     # Pickling works
     pickled_data = pickle.loads(pickle.dumps(cached_data))
     assert_array_almost_equal(pickled_data, data)
-
-    # Same underlying caching is re-used
-    # 2 x row blocks + row x chan x corr blocks
-    assert len(_key_cache) == data.numblocks[0] * 2 + data.npartitions
-    # rows, row runs and data array cache's
-    assert len(_array_cache_cache) == 3
-
-    del pickled_data, cached_data, data, ds
-    gc.collect()
-
-    assert len(_key_cache) == 0
-    assert len(_array_cache_cache) == 0
-
-
-@pytest.mark.parametrize("token", ["0xdeadbeaf", None])
-def test_cached_data_token(token):
-    zeros = da.zeros(1000, chunks=100)
-    carray = cached_array(zeros, token)
-
-    dsk = dict(carray.__dask_graph__())
-    k, v = dsk.popitem()
-    cache = v[1]
-
-    if token is None:
-        assert cache.token is not None
-    else:
-        assert cache.token == token
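To see the cache doing observable work, here is a hedged variant of test_blah_cached; the expensive function and the calls list are illustrative rather than part of the PR. Because the GraphCache instance lives inside the returned array's graph, blocks materialised by the first compute() are served from its dict on the second.

import dask.array as da

from daskms.optimisation import cached_array, inlined_array

calls = []


def expensive(block):
    calls.append(block.item())    # record every real materialisation
    return block * 2


A = da.arange(5, chunks=1)
B = da.blockwise(expensive, 'a', A, 'a', meta=A._meta)
B = cached_array(inlined_array(B))

B.compute()
B.compute()                       # served from GraphCache.cache, no new calls
assert sorted(calls) == [0, 1, 2, 3, 4]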