Commit

Merge branch 'main' into roi
SimonHeybrock authored Feb 7, 2025
2 parents 5e25b53 + b7cda71 · commit 1aa5e92
Showing 4 changed files with 155 additions and 5 deletions.
2 changes: 1 addition & 1 deletion requirements/base.txt
@@ -50,7 +50,7 @@ python-dateutil==2.9.0.post0
# scippnexus
sciline==24.10.0
# via -r base.in
scipp==25.1.0
scipp==25.2.0
# via
# -r base.in
# scippneutron
2 changes: 1 addition & 1 deletion requirements/basetest.txt
@@ -88,7 +88,7 @@ python-dateutil==2.9.0.post0
# via matplotlib
requests==2.32.3
# via pooch
scipp==25.1.0
scipp==25.2.0
# via tof
scipy==1.15.1
# via
56 changes: 53 additions & 3 deletions src/ess/reduce/streaming.py
@@ -4,6 +4,7 @@

from abc import ABC, abstractmethod
from collections.abc import Callable
from copy import deepcopy
from typing import Any, Generic, TypeVar

import networkx as nx
@@ -29,6 +30,8 @@ def maybe_hist(value: T) -> T:
:
Histogram.
"""
if not isinstance(value, sc.Variable | sc.DataArray):
return value
return value if value.bins is None else value.hist()
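
The new isinstance guard lets maybe_hist pass non-scipp values through untouched instead of failing on the `.bins` attribute lookup. A minimal sketch of that behaviour, assuming maybe_hist is imported directly from ess.reduce.streaming and using illustrative values:

    import scipp as sc
    from ess.reduce.streaming import maybe_hist

    dense = sc.array(dims=['x'], values=[1.0, 2.0])
    assert maybe_hist(dense) is dense   # dense scipp data: value.bins is None, returned as-is
    assert maybe_hist(3.14) == 3.14     # non-scipp values now pass through unchanged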


@@ -90,11 +93,11 @@ def __init__(self, **kwargs: Any) -> None:

@property
def value(self) -> T:
return self._value.copy()
return deepcopy(self._value)

def _do_push(self, value: T) -> None:
if self._value is None:
self._value = value.copy()
self._value = deepcopy(value)
else:
self._value += value
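
Switching from .copy() to deepcopy presumably serves the same purpose as the maybe_hist change: accumulated values need not be scipp objects with a .copy() method, while .value still hands out an independent snapshot. A rough sketch of the copy semantics, using EternalAccumulator directly with illustrative values:

    import scipp as sc
    from ess.reduce.streaming import EternalAccumulator

    acc = EternalAccumulator()
    acc.push(sc.scalar(1.0))
    acc.push(sc.scalar(2.0))
    snapshot = acc.value                  # deep copy, detached from the accumulator
    acc.push(sc.scalar(10.0))             # later pushes do not alter the snapshot
    assert sc.identical(snapshot, sc.scalar(3.0))
    assert sc.identical(acc.value, sc.scalar(13.0))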

@@ -146,6 +149,7 @@ def __init__(
target_keys: tuple[sciline.typing.Key, ...],
accumulators: dict[sciline.typing.Key, Accumulator | Callable[..., Accumulator]]
| tuple[sciline.typing.Key, ...],
allow_bypass: bool = False,
) -> None:
"""
Create a stream processor.
@@ -163,6 +167,12 @@ def __init__(
passed, :py:class:`EternalAccumulator` is used for all keys. Otherwise, a
dict mapping keys to accumulator instances can be passed. If a dict value is
a callable, base_workflow.bind_and_call(value) is used to make an instance.
allow_bypass:
If True, allow bypassing accumulators for keys that are not in the
accumulators dict. This is useful for dynamic keys that are not "terminated"
in any accumulator. USE WITH CARE! This will lead to incorrect results
unless the values for these keys are valid for all chunks contained in the
final accumulators at the point where :py:meth:`finalize` is called.
"""
workflow = sciline.Pipeline()
for key in target_keys:
@@ -201,19 +211,59 @@ def __init__(
for key, value in self._accumulators.items()
}
self._target_keys = target_keys
self._allow_bypass = allow_bypass

def add_chunk(
self, chunks: dict[sciline.typing.Key, Any]
) -> dict[sciline.typing.Key, Any]:
"""
Legacy interface for accumulating values from chunks and finalizing the result.

It is recommended to use :py:meth:`accumulate` and :py:meth:`finalize` instead.

Parameters
----------
chunks:
Chunks to be processed.

Returns
-------
:
Finalized result.
"""
self.accumulate(chunks)
return self.finalize()

def accumulate(self, chunks: dict[sciline.typing.Key, Any]) -> None:
"""
Accumulate values from chunks without finalizing the result.
Parameters
----------
chunks:
Chunks to be processed.
"""
for key, value in chunks.items():
self._process_chunk_workflow[key] = value
# There can be dynamic keys that do not "terminate" in any accumulator. In
# that case, we need to make sure they can be and are used when computing
# the target keys.
self._finalize_workflow[key] = value
if self._allow_bypass:
self._finalize_workflow[key] = value
to_accumulate = self._process_chunk_workflow.compute(self._accumulators)
for key, processed in to_accumulate.items():
self._accumulators[key].push(processed)

def finalize(self) -> dict[sciline.typing.Key, Any]:
"""
Get the final result by computing the target keys based on accumulated values.

Returns
-------
:
Finalized result.
"""
for key in self._accumulators:
self._finalize_workflow[key] = self._accumulators[key].value
return self._finalize_workflow.compute(self._target_keys)
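
Putting the pieces together, here is a minimal sketch of the accumulate/finalize pattern with the new allow_bypass flag, mirroring the tests below. The key types (Events, Norm, CleanEvents, Result) and the two providers are hypothetical names invented for this illustration only:

    from typing import NewType

    import sciline
    import scipp as sc

    from ess.reduce.streaming import StreamProcessor

    Events = NewType('Events', sc.Variable)        # dynamic, accumulated per chunk
    Norm = NewType('Norm', sc.Variable)            # dynamic, not accumulated (bypassed)
    CleanEvents = NewType('CleanEvents', sc.Variable)
    Result = NewType('Result', sc.Variable)

    def clean(events: Events) -> CleanEvents:
        return CleanEvents(events)

    def normalize(clean_events: CleanEvents, norm: Norm) -> Result:
        return Result(clean_events / norm)

    processor = StreamProcessor(
        base_workflow=sciline.Pipeline((clean, normalize)),
        dynamic_keys=(Events, Norm),
        target_keys=(Result,),
        accumulators=(CleanEvents,),  # CleanEvents is summed over chunks; Norm is not
        allow_bypass=True,            # Norm flows straight into finalize()
    )
    processor.accumulate({Events: sc.scalar(1.0), Norm: sc.scalar(2.0)})
    processor.accumulate({Events: sc.scalar(3.0), Norm: sc.scalar(2.0)})
    result = processor.finalize()
    # CleanEvents has accumulated 1 + 3 = 4; Norm is taken from the latest chunk (2).
    assert sc.identical(result[Result], sc.scalar(2.0))

Without allow_bypass=True the same setup would fail in finalize(), since Norm never reaches the finalize workflow (see test_StreamProcessor_without_bypass_raises below).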

100 changes: 100 additions & 0 deletions tests/streaming_test.py
@@ -3,6 +3,7 @@

from typing import NewType

import pytest
import sciline
import scipp as sc

@@ -214,6 +215,7 @@ def test_StreamProcess_with_zero_accumulators_for_buffered_workflow_calls() -> N
dynamic_keys=(DynamicA, DynamicB),
target_keys=(Target,),
accumulators=(),
allow_bypass=True,
)
result = streaming_wf.add_chunk({DynamicA: sc.scalar(1), DynamicB: sc.scalar(4)})
assert sc.identical(result[Target], sc.scalar(2 * 1.0 / 4.0))
@@ -222,3 +224,101 @@ def test_StreamProcess_with_zero_accumulators_for_buffered_workflow_calls() -> N
result = streaming_wf.add_chunk({DynamicA: sc.scalar(3), DynamicB: sc.scalar(6)})
assert sc.identical(result[Target], sc.scalar(2 * 3.0 / 6.0))
assert make_static_a.call_count == 1


def test_StreamProcessor_with_bypass() -> None:
def _make_static_a() -> StaticA:
_make_static_a.call_count += 1
return StaticA(2.0)

_make_static_a.call_count = 0

base_workflow = sciline.Pipeline(
(_make_static_a, make_accum_a, make_accum_b, make_target)
)
orig_workflow = base_workflow.copy()

streaming_wf = streaming.StreamProcessor(
base_workflow=base_workflow,
dynamic_keys=(DynamicA, DynamicB),
target_keys=(Target,),
accumulators=(AccumA,), # Note: No AccumB
allow_bypass=True,
)
streaming_wf.accumulate({DynamicA: sc.scalar(1), DynamicB: sc.scalar(4)})
result = streaming_wf.finalize()
assert sc.identical(result[Target], sc.scalar(2 * 1.0 / 4.0))
streaming_wf.accumulate({DynamicA: sc.scalar(2), DynamicB: sc.scalar(5)})
result = streaming_wf.finalize()
# Note denominator is 5, not 9
assert sc.identical(result[Target], sc.scalar(2 * 3.0 / 5.0))
streaming_wf.accumulate({DynamicA: sc.scalar(3), DynamicB: sc.scalar(6)})
result = streaming_wf.finalize()
# Note denominator is 6, not 15
assert sc.identical(result[Target], sc.scalar(2 * 6.0 / 6.0))
assert _make_static_a.call_count == 1

# Consistency check: Run the original workflow with the same inputs, all at once
orig_workflow[DynamicA] = sc.scalar(1 + 2 + 3)
orig_workflow[DynamicB] = sc.scalar(6)
expected = orig_workflow.compute(Target)
assert sc.identical(expected, result[Target])


def test_StreamProcessor_without_bypass_raises() -> None:
def _make_static_a() -> StaticA:
_make_static_a.call_count += 1
return StaticA(2.0)

_make_static_a.call_count = 0

base_workflow = sciline.Pipeline(
(_make_static_a, make_accum_a, make_accum_b, make_target)
)

streaming_wf = streaming.StreamProcessor(
base_workflow=base_workflow,
dynamic_keys=(DynamicA, DynamicB),
target_keys=(Target,),
accumulators=(AccumA,), # Note: No AccumB
)
streaming_wf.accumulate({DynamicA: 1, DynamicB: 4})
# Sciline passes `None` to the provider that needs AccumB.
with pytest.raises(TypeError, match='unsupported operand type'):
_ = streaming_wf.finalize()


def test_StreamProcessor_calls_providers_after_accumulators_only_when_finalizing() -> (
None
):
def _make_target(accum_a: AccumA, accum_b: AccumB) -> Target:
_make_target.call_count += 1
return Target(accum_a / accum_b)

_make_target.call_count = 0

base_workflow = sciline.Pipeline(
(make_accum_a, make_accum_b, _make_target), params={StaticA: 2.0}
)

streaming_wf = streaming.StreamProcessor(
base_workflow=base_workflow,
dynamic_keys=(DynamicA, DynamicB),
target_keys=(Target,),
accumulators=(AccumA, AccumB),
)
streaming_wf.accumulate({DynamicA: sc.scalar(1), DynamicB: sc.scalar(4)})
streaming_wf.accumulate({DynamicA: sc.scalar(2), DynamicB: sc.scalar(5)})
assert _make_target.call_count == 0
result = streaming_wf.finalize()
assert sc.identical(result[Target], sc.scalar(2 * 3.0 / 9.0))
assert _make_target.call_count == 1
streaming_wf.accumulate({DynamicA: sc.scalar(3), DynamicB: sc.scalar(6)})
assert _make_target.call_count == 1
result = streaming_wf.finalize()
assert sc.identical(result[Target], sc.scalar(2 * 6.0 / 15.0))
assert _make_target.call_count == 2
result = streaming_wf.finalize()
assert sc.identical(result[Target], sc.scalar(2 * 6.0 / 15.0))
# Outputs are not cached.
assert _make_target.call_count == 3
