Skip to content

Commit

Permalink
Improve unscaling test
Browse files Browse the repository at this point in the history
  • Loading branch information
samaloney committed Dec 7, 2023
1 parent c93db4c commit 71771aa
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 10 deletions.
1 change: 1 addition & 0 deletions stixcore/products/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ def unscale_triggers(scaled_triggers, *, integration, detector_masks, ssid, fact
tuple
Unscaled triggers, Scaling Variance
"""
# TODO extract this to a separate function and test
detector_to_trigger_group_map = np.array(
[[1, 6, 5, 12, 14, 10, 8, 3, 31, 26, 22, 20, 18, 17, 24, 29],
[2, 7, 11, 13, 15, 16, 9, 4, 32, 27, 28, 21, 19, 23, 25, 30]]).T
Expand Down
31 changes: 21 additions & 10 deletions stixcore/products/tests/test_scaling.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,34 @@
import numpy as np
import pytest
from numpy.testing import assert_array_equal

import astropy.units as u

from stixcore.products.common import unscale_triggers
from stixcore.time import SCETimeDelta


def test_unscale():
@pytest.mark.parametrize('factor', [20, 30, 40])
@pytest.mark.parametrize('n_int', [10, 20, 30])
@pytest.mark.parametrize('ssid', [21, 22, 23, 24])
def test_unscale(factor, n_int, ssid):
dmask = np.ones((1, 32))
duration = 10*u.s
triggers_in = np.full((1, 16), 123456)
duration = SCETimeDelta(n_int * u.ds)

n_int = 100
n_groups = 1
if ssid == 24:
n_groups = 16

triggers_scaled = np.floor(triggers_in / (n_int * n_groups * 30))
trigger_unscaled_var = 0.5 * n_int * n_groups * 30
triggers_unscaled = triggers_scaled * n_int * n_groups * 30 + trigger_unscaled_var
norm = n_groups * n_int * factor

triggers_in = np.tile(np.arange(255*n_int*factor).reshape(-1, 1), 16, )
if ssid == 24:
triggers_in = triggers_in.sum(axis=1)

triggers_scaled = np.floor(triggers_in / norm)
trigger_unscaled_var = 0.5 * norm
triggers_unscaled = triggers_scaled * norm + trigger_unscaled_var
triggers_out, trigger_out_var = unscale_triggers(triggers_scaled, integration=duration,
detector_masks=dmask, ssid=21)
assert_array_equal(triggers_out[0, 0], triggers_unscaled)
assert_array_equal(trigger_out_var[0, 0], trigger_unscaled_var**2)
detector_masks=dmask, ssid=ssid, factor=factor)
assert_array_equal(triggers_out, triggers_unscaled)
assert_array_equal(trigger_out_var, trigger_unscaled_var**2)

0 comments on commit 71771aa

Please sign in to comment.