Skip to content

Commit

Permalink
Merge pull request #58 from evalott100/make_conversion_to_string_enum…
Browse files Browse the repository at this point in the history
…s_optional

Made conversion on enum fields to List[str] optional
  • Loading branch information
evalott100 authored Oct 11, 2023
2 parents 3e086c9 + 659a25f commit 4cc1e30
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 77 deletions.
33 changes: 20 additions & 13 deletions pandablocks/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Iterable, List, Sequence, Union, cast
from typing import Dict, Iterable, List, Union, cast

import numpy as np
import numpy.typing as npt
Expand All @@ -8,12 +8,14 @@
UnpackedArray = Union[
npt.NDArray[np.int32],
npt.NDArray[np.uint32],
Sequence[str],
List[str],
]


def words_to_table(
words: Iterable[str], table_field_info: TableFieldInfo
words: Iterable[str],
table_field_info: TableFieldInfo,
convert_enum_indices: bool = False,
) -> Dict[str, UnpackedArray]:
"""Unpacks the given `packed` data based on the fields provided.
Returns the unpacked data in {column_name: column_data} column-indexed format
Expand All @@ -23,6 +25,8 @@ def words_to_table(
expected to be the string representation of a uint32.
table_fields_info: The info for tables, containing the number of words per row,
and the bit information for fields.
convert_enum_indices: If True, convert all enum values to their string
representation. Otherwise return enums as integer values
Returns:
unpacked: A dict containing record information, where keys are field names
and values are numpy arrays or a sequence of strings of record values
Expand Down Expand Up @@ -56,7 +60,8 @@ def words_to_table(
if field_info.subtype == "int":
# First convert from 2's complement to offset, then add in offset.
packing_value = (value ^ (1 << (bit_length - 1))) + (-1 << (bit_length - 1))
elif field_info.labels:
elif field_info.subtype == "enum" and convert_enum_indices:
assert field_info.labels, f"Enum field {field_name} has no labels"
packing_value = [field_info.labels[x] for x in value]
else:
packing_value = value
Expand All @@ -67,7 +72,7 @@ def words_to_table(


def table_to_words(
table: Dict[str, Iterable], table_field_info: TableFieldInfo
table: Dict[str, UnpackedArray], table_field_info: TableFieldInfo
) -> List[str]:
"""Convert records based on the field definitions into the format PandA expects
for table writes.
Expand All @@ -88,18 +93,19 @@ def table_to_words(

for column_name, column in table.items():
field_details = table_field_info.fields[column_name]
if field_details.labels:
# Must convert the list of ints into strings
column = [field_details.labels.index(x) for x in column]

# PandA always handles tables in uint32 format
column_value = np.array(column, dtype=np.uint32)
if field_details.labels and len(column) and isinstance(column[0], str):
# Must convert the list of strings to list of ints
column_value = np.array(
[field_details.labels.index(x) for x in column], dtype=np.uint32
)
else:
# PandA always handles tables in uint32 format
column_value = np.array(column, dtype=np.uint32)

if packed is None:
# Create 1-D array sufficiently long to exactly hold the entire table, cast
# to prevent type error, this will still work if column is another iterable
# e.g numpy array
column = cast(List, column)
packed = np.zeros((len(column), row_words), dtype=np.uint32)
else:
assert len(packed) == len(column), (
Expand All @@ -117,7 +123,8 @@ def table_to_words(

# Slice to get the column to apply the values to.
# bit shift the value to the relevant bits of the word
packed[:, word_offset] |= column_value << bit_offset

packed[:, word_offset] |= cast(np.unsignedinteger, column_value) << bit_offset

assert isinstance(packed, np.ndarray), "Table has no columns" # Squash mypy warning

Expand Down
122 changes: 58 additions & 64 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Iterable, List, OrderedDict
from typing import Dict, List, OrderedDict

import numpy as np
import pytest
Expand Down Expand Up @@ -157,32 +157,34 @@ def table_field_info(table_fields) -> TableFieldInfo:


@pytest.fixture
def table_1() -> OrderedDict[str, Iterable]:
def table_1_np_arrays() -> OrderedDict[str, UnpackedArray]:
# Intentionally not in panda order. Whatever types the np arrays are,
# the outputs from words_to_table will be uint32 or int32.
return OrderedDict(
{
"REPEATS": [5, 0, 50000],
"REPEATS": np.array([5, 0, 50000], dtype=np.uint32),
"TRIGGER": ["Immediate", "BITC=1", "Immediate"],
"POSITION": [-5, 678, 0],
"TIME1": [100, 0, 9],
"OUTA1": [0, 1, 1],
"OUTB1": [0, 0, 1],
"OUTC1": [0, 0, 1],
"OUTD1": [1, 0, 1],
"OUTE1": [0, 0, 1],
"OUTF1": [1, 0, 1],
"TIME2": [0, 55, 9999],
"OUTA2": [0, 0, 1],
"OUTB2": [0, 0, 1],
"OUTC2": [1, 1, 1],
"OUTD2": [0, 0, 1],
"OUTE2": [0, 0, 1],
"OUTF2": [1, 0, 1],
"POSITION": np.array([-5, 678, 0], dtype=np.int32),
"TIME1": np.array([100, 0, 9], dtype=np.uint32),
"OUTA1": np.array([0, 1, 1], dtype=np.uint8),
"OUTB1": np.array([0, 0, 1], dtype=np.uint8),
"OUTD1": np.array([1, 0, 1], dtype=np.uint8),
"OUTE1": np.array([0, 0, 1], dtype=np.uint8),
"OUTC1": np.array([0, 0, 1], dtype=np.uint8),
"OUTF1": np.array([1, 0, 1], dtype=np.uint8),
"TIME2": np.array([0, 55, 9999], dtype=np.uint32),
"OUTA2": np.array([0, 0, 1], dtype=np.uint8),
"OUTB2": np.array([0, 0, 1], dtype=np.uint8),
"OUTC2": np.array([1, 1, 1], dtype=np.uint8),
"OUTD2": np.array([0, 0, 1], dtype=np.uint8),
"OUTE2": np.array([0, 0, 1], dtype=np.uint8),
"OUTF2": np.array([1, 0, 1], dtype=np.uint8),
}
)


@pytest.fixture
def table_1_np_arrays() -> OrderedDict[str, Iterable]:
def table_1_np_arrays_int_enums() -> OrderedDict[str, UnpackedArray]:
# Intentionally not in panda order. Whatever types the np arrays are,
# the outputs from words_to_table will be uint32 or int32.
return OrderedDict(
Expand All @@ -203,32 +205,7 @@ def table_1_np_arrays() -> OrderedDict[str, Iterable]:
"OUTD2": np.array([0, 0, 1], dtype=np.uint8),
"OUTE2": np.array([0, 0, 1], dtype=np.uint8),
"OUTF2": np.array([1, 0, 1], dtype=np.uint8),
"TRIGGER": np.array(["Immediate", "BITC=1", "Immediate"], dtype="<U9"),
}
)


@pytest.fixture
def table_1_not_in_panda_order() -> OrderedDict[str, Iterable]:
return OrderedDict(
{
"REPEATS": [5, 0, 50000],
"TRIGGER": ["Immediate", "BITC=1", "Immediate"],
"POSITION": [-5, 678, 0],
"TIME1": [100, 0, 9],
"OUTA1": [0, 1, 1],
"OUTB1": [0, 0, 1],
"OUTC1": [0, 0, 1],
"OUTD1": [1, 0, 1],
"OUTF1": [1, 0, 1],
"OUTE1": [0, 0, 1],
"TIME2": [0, 55, 9999],
"OUTA2": [0, 0, 1],
"OUTC2": [1, 1, 1],
"OUTB2": [0, 0, 1],
"OUTD2": [0, 0, 1],
"OUTE2": [0, 0, 1],
"OUTF2": [1, 0, 1],
"TRIGGER": np.array([0, 6, 0], dtype=np.uint8),
}
)

Expand All @@ -252,19 +229,19 @@ def table_data_1() -> List[str]:


@pytest.fixture
def table_2() -> Dict[str, Iterable]:
table: Dict[str, Iterable] = dict(
REPEATS=[1, 0],
def table_2_np_arrays() -> Dict[str, UnpackedArray]:
table: Dict[str, UnpackedArray] = dict(
REPEATS=np.array([1, 0], dtype=np.uint32),
TRIGGER=["Immediate", "Immediate"],
POSITION=[-20, 2**31 - 1],
TIME1=[12, 2**32 - 1],
TIME2=[32, 1],
POSITION=np.array([-20, 2**31 - 1], dtype=np.int32),
TIME1=np.array([12, 2**32 - 1], dtype=np.uint32),
TIME2=np.array([32, 1], dtype=np.uint32),
)

table["OUTA1"] = [False, True]
table["OUTA2"] = [True, False]
table["OUTA1"] = np.array([0, 1], dtype=np.uint8)
table["OUTA2"] = np.array([1, 0], dtype=np.uint8)
for key in "BCDEF":
table[f"OUT{key}1"] = table[f"OUT{key}2"] = [False, False]
table[f"OUT{key}1"] = table[f"OUT{key}2"] = np.array([0, 0], dtype=np.uint8)

return table

Expand All @@ -284,25 +261,24 @@ def table_data_2() -> List[str]:


def test_table_packing_pack_length_mismatched(
table_1: OrderedDict[str, Iterable],
table_1_np_arrays: OrderedDict[str, UnpackedArray],
table_field_info: TableFieldInfo,
):
assert table_field_info.row_words

# Adjust one of the record lengths so it mismatches
field_info = table_field_info.fields[("OUTC1")]
assert field_info
table_1["OUTC1"] = np.array([1, 2, 3, 4, 5, 6, 7, 8])
table_1_np_arrays["OUTC1"] = np.array([1, 2, 3, 4, 5, 6, 7, 8])

with pytest.raises(AssertionError):
table_to_words(table_1, table_field_info)
table_to_words(table_1_np_arrays, table_field_info)


@pytest.mark.parametrize(
"table_fixture_name,table_data_fixture_name",
[
("table_1_not_in_panda_order", "table_data_1"),
("table_2", "table_data_2"),
("table_2_np_arrays", "table_data_2"),
("table_1_np_arrays", "table_data_1"),
],
)
Expand All @@ -312,12 +288,14 @@ def test_table_to_words_and_words_to_table(
table_field_info: TableFieldInfo,
request,
):
table: Dict[str, Iterable] = request.getfixturevalue(table_fixture_name)
table: Dict[str, UnpackedArray] = request.getfixturevalue(table_fixture_name)
table_data: List[str] = request.getfixturevalue(table_data_fixture_name)

output_data = table_to_words(table, table_field_info)
assert output_data == table_data
output_table = words_to_table(output_data, table_field_info)
output_table = words_to_table(
output_data, table_field_info, convert_enum_indices=True
)

# Test the correct keys are outputted
assert output_table.keys() == table.keys()
Expand All @@ -337,21 +315,37 @@ def test_table_packing_unpack(
table_data_1: List[str],
):
assert table_field_info.row_words
output_table = words_to_table(table_data_1, table_field_info)
output_table = words_to_table(
table_data_1, table_field_info, convert_enum_indices=True
)

actual: UnpackedArray
for field_name, actual in output_table.items():
expected = table_1_np_arrays[str(field_name)]
np.testing.assert_array_equal(actual, expected)


def test_table_packing_unpack_no_convert_enum(
table_1_np_arrays_int_enums: OrderedDict[str, UnpackedArray],
table_field_info: TableFieldInfo,
table_data_1: List[str],
):
assert table_field_info.row_words
output_table = words_to_table(table_data_1, table_field_info)

actual: UnpackedArray
for field_name, actual in output_table.items():
expected = table_1_np_arrays_int_enums[str(field_name)]
np.testing.assert_array_equal(actual, expected)


def test_table_packing_pack(
table_1: Dict[str, Iterable],
table_1_np_arrays: Dict[str, UnpackedArray],
table_field_info: TableFieldInfo,
table_data_1: List[str],
):
assert table_field_info.row_words
unpacked = table_to_words(table_1, table_field_info)
unpacked = table_to_words(table_1_np_arrays, table_field_info)

for actual, expected in zip(unpacked, table_data_1):
assert actual == expected

0 comments on commit 4cc1e30

Please sign in to comment.