Skip to content

Commit

Permalink
Merge pull request #62 from NNPDF/240307_improve_test_executor
Browse files Browse the repository at this point in the history
Improve test executor readability
  • Loading branch information
comane authored Mar 11, 2024
2 parents 93b94ad + f28609f commit 02acde3
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 44 deletions.
2 changes: 2 additions & 0 deletions src/reportengine/resourcebuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,8 @@ def execute_parallel(self, scheduler=None):
# gather futures once all jobs have been submitted
self.gather_results(leaf_callspecs, client)

return client


def set_future(self, future, callspec):
"""
Expand Down
4 changes: 3 additions & 1 deletion src/reportengine/tests/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,11 @@ def test_collect(self):
builder.resolve_fuzzytargets()
d = namespaces.resolve(builder.rootns, [('lists',1)])
assert d['restaurant_collect'] == list("123")
builder.execute_parallel()
client = builder.execute_parallel()
# since it is using dask it returns a future
assert namespaces.resolve(builder.rootns, ('UK',))['score'].result() == -1
# close the client
client.close()

def test_collect_raises(self):
with self.assertRaises(TypeError):
Expand Down
114 changes: 71 additions & 43 deletions src/reportengine/tests/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
"""
Created on Fri Nov 13 22:51:32 2015
Demonstrates a simple usage of the reportengine module for building
and executing a Directed Acyclic Graph (DAG) of functions.
The DAG is executed both sequentially and in parallel.
@author: zah
"""

Expand All @@ -11,87 +15,111 @@
from reportengine.dag import DAG
from reportengine.utils import ChainMap
from reportengine import namespaces
from reportengine.resourcebuilder import (ResourceExecutor, CallSpec)
from reportengine.resourcebuilder import ResourceExecutor, CallSpec

"""
Define some simple functions that will be used as nodes in the DAG.
"""

def f(param):
    """Root node (pre-rename version of ``node_1``).

    Prints a trace line, sleeps briefly to simulate work and returns a
    string tagging ``param``.
    """
    print("Executing f")
    time.sleep(0.1)
    # f-string instead of ``%``: ``"%s" % param`` misbehaves for tuple params.
    return f"fresult: {param}"


def node_1(param):
    """Root node of the diamond DAG.

    Prints a trace line, sleeps briefly to simulate work and returns a
    string tagging ``param``.
    """
    print("Executing node_1")
    time.sleep(0.1)
    # f-string instead of ``%``: ``"%s" % param`` misbehaves for tuple params.
    return f"node_1_result: {param}"

def g(fresult):
    """Return ``fresult`` repeated twice after a short simulated delay."""
    print("Executing g")
    time.sleep(0.2)
    return fresult * 2

def h(fresult):
    """Return ``fresult`` repeated three times after a short simulated delay."""
    print("Executing h")
    time.sleep(0.2)
    return fresult * 3


def node_2_1(node_1_result):
    """Left branch of the diamond: return ``node_1_result`` doubled.

    Sleeps briefly so that the two middle nodes have measurable duration.
    """
    print("Executing node_2_1")
    time.sleep(0.2)
    return node_1_result * 2

def m(gresult, hresult, param=None):
    """Join the two upstream results and repeat them ``param // 2`` times.

    Raises:
        TypeError: if ``param`` is left at its ``None`` default, since the
            floor division below needs a number.
    """
    print("executing m")
    if param is None:
        # Same exception type as the bare ``None // 2`` failure, clearer message.
        raise TypeError("m() requires a numeric 'param', got None")
    return (gresult + hresult) * (param // 2)

def n(mresult):
    """Return ``mresult`` unchanged (identity node)."""
    return mresult
def node_2_2(node_1_result):
    """Right branch of the diamond: return ``node_1_result`` tripled.

    Sleeps briefly so that the two middle nodes have measurable duration.
    """
    print("Executing node_2_2")
    time.sleep(0.2)
    return node_1_result * 3


def o(mresult):
    """Return ``mresult`` doubled."""
    return mresult * 2
def node_3(node_2_1_result, node_2_2_result, param=None):
    """Bottom of the diamond: join both branches, repeated ``param // 2`` times.

    Raises:
        TypeError: if ``param`` is left at its ``None`` default, since the
            floor division below needs a number.
    """
    print("Executing node_3")
    if param is None:
        # Same exception type as the bare ``None // 2`` failure, clearer message.
        raise TypeError("node_3() requires a numeric 'param', got None")
    return (node_2_1_result + node_2_2_result) * (param // 2)

def p(mresult):
    """Return ``mresult`` tripled."""
    return mresult * 3

class TestResourceExecutor(unittest.TestCase, ResourceExecutor):
    """Exercise ``ResourceExecutor`` on a small diamond-shaped DAG.

    The test case is itself the executor: ``setUp`` assigns ``rootns`` and
    ``graph`` on ``self`` and the tests call ``execute_sequential`` /
    ``execute_parallel`` directly.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Initialise the executor base explicitly with None placeholders
        # (presumably graph and root namespace -- confirm against
        # ResourceExecutor); setUp assigns rootns and graph before each test.
        ResourceExecutor.__init__(self, None, None)

    def setUp(self):
        # Raw docstring: the diagram contains backslashes, which in a normal
        # string would act as a line continuation / invalid escape sequence.
        r"""
        Creates a simple DAG of functions with diamond shape.

                node_1
               /      \
          node_2_1  node_2_2
               \      /
                node_3
        """
        self.rootns = ChainMap({"param": 4, "inner": {}})

        def nsspec(x, beginning=()):
            """Push a namespace level labelled after ``x`` and return its spec."""
            ns = namespaces.resolve(self.rootns, beginning)
            default_label = "_default" + str(x)
            namespaces.push_nslevel(ns, default_label)
            return beginning + (default_label,)

        self.graph = DAG()

        # CallSpec(function, input names, output name, namespace spec)
        node_1_call = CallSpec(node_1, ("param",), "node_1_result", nsspec(node_1))

        node_2_1_call = CallSpec(
            node_2_1, ("node_1_result",), "node_2_1_result", nsspec(node_2_1)
        )

        node_2_2_call = CallSpec(
            node_2_2, ("node_1_result",), "node_2_2_result", nsspec(node_2_2)
        )

        node_3_call = CallSpec(
            node_3,
            ("node_2_1_result", "node_2_2_result", "param"),
            "node_3_result",
            nsspec(node_3),
        )

        # Wire the diamond: node_1 feeds both middle nodes, which feed node_3.
        self.graph.add_node(node_1_call)
        self.graph.add_node(node_2_1_call, inputs={node_1_call})
        self.graph.add_node(node_2_2_call, inputs={node_1_call})
        self.graph.add_node(node_3_call, inputs={node_2_1_call, node_2_2_call})

    def _test_ns(self, promise=False):
        """
        Asserts that the namespace contains the expected results.

        With ``param == 4`` node_3 yields (2x + 3x) * (4 // 2), i.e. the
        node_1 string repeated ten times. When ``promise`` is true the stored
        value is a future and ``.result()`` is called on it first.
        """
        node_3_result = "node_1_result: 4" * 10
        namespace = self.rootns
        if promise:
            self.assertEqual(namespace["node_3_result"].result(), node_3_result)
        else:
            self.assertEqual(namespace["node_3_result"], node_3_result)

    def test_seq_execute(self):
        """
        This test will execute the DAG in sequence.
        """
        self.execute_sequential()
        self._test_ns()

    def test_parallel_execute(self):
        """
        This test will execute the DAG in parallel, using
        dask distributed scheduler.
        """
        client = self.execute_parallel()
        # Register the close as a cleanup so the client is released even if
        # the assertion below fails.
        self.addCleanup(client.close)
        self._test_ns(promise=True)


# Run the test suite when this module is executed as a script.
if __name__ == "__main__":
    unittest.main()

0 comments on commit 02acde3

Please sign in to comment.