-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtest_encode.py
62 lines (49 loc) · 1.86 KB
/
test_encode.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
from typing import Callable
import numpy as np
import pytest
from gfloat import decode_float, encode_float, encode_ndarray
from gfloat.formats import *
@pytest.mark.parametrize("fi", all_formats)
def test_encode(fi: FormatInfo) -> None:
dec = lambda v: decode_float(fi, v).fval
if fi.bits <= 8:
step = 1
elif fi.bits <= 16:
step = 13
elif fi.bits <= 32:
step = 73013
elif fi.bits <= 64:
step = (73013 << 32) + 39
for i in range(0, 2**fi.bits, step):
fv = decode_float(fi, i)
code = encode_float(fi, fv.fval)
assert (i == code) or (np.isnan(fv.fval) and code == fi.code_of_nan)
fv2 = decode_float(fi, code)
np.testing.assert_equal(fv2.fval, fv.fval)
codes = np.arange(0, 2**fi.bits, step, dtype=np.uint64)
fvals = np.array([decode_float(fi, int(i)).fval for i in codes])
enc_codes = encode_ndarray(fi, fvals)
if fi.num_nans == 0:
assert not np.any(np.isnan(fvals))
expected_codes = codes
else:
expected_codes = np.where(np.isnan(fvals), fi.code_of_nan, codes)
np.testing.assert_equal(enc_codes, expected_codes)
@pytest.mark.parametrize("fi", all_formats)
@pytest.mark.parametrize("enc", (encode_float, encode_ndarray))
def test_encode_edges(fi: FormatInfo, enc: Callable) -> None:
if enc == encode_ndarray:
enc = lambda fi, x: encode_ndarray(fi, np.array([x])).item()
assert enc(fi, fi.max) == fi.code_of_max
assert enc(fi, fi.max * 1.25) == (
fi.code_of_posinf
if fi.has_infs
else fi.code_of_nan if fi.num_nans > 0 else fi.code_of_max
)
if fi.is_signed:
assert enc(fi, fi.min * 1.25) == (
fi.code_of_neginf
if fi.has_infs
else fi.code_of_nan if fi.num_nans > 0 else fi.code_of_min
)