fix implementations in trace_metadata, and add some tests
vreuter committed Jan 24, 2025
1 parent 9af764f commit 23a656e
Showing 3 changed files with 132 additions and 36 deletions.
2 changes: 1 addition & 1 deletion looptrace/Tracer.py
@@ -718,7 +718,7 @@ def get_locus_times(t: TimepointFrom0) -> Result[Times, TimepointFrom0]:
case option.Option(tag="none", none=_):
raise ValueError(f"Failed to lookup group times for trace group {trace_group}")
case option.Option(tag="some", some=group_regional_times):
match traverse_through_either(lambda rt: get_locus_times(rt).map(lambda lts: (rt, lts)))(group_regional_times):
match traverse_through_either(lambda rt: get_locus_times(rt).map(lambda lts: (rt, lts)))(group_regional_times.get):
case result.Result(tag="error", error=unfound_regional_times):
raise ValueError(f"Failed to find locus times for {len(unfound_regional_times)} regional timepoint(s) in trace group {trace_group}: {unfound_regional_times}")
case result.Result(tag="ok", ok=pairs):
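The Tracer.py tweak follows from the reworked trace_metadata API below: get_group_times now yields an Option wrapping a TraceGroupTimes value rather than a bare set of timepoints, so the underlying frozenset has to be reached through the wrapper's .get attribute before it can be traversed. A minimal sketch of the new access pattern (the group name and timepoints are invented for illustration):

from expression import option
from gertils.types import TimepointFrom0
from looptrace.trace_metadata import (
    PotentialTraceMetadata,
    TraceGroup,
    TraceGroupName,
    TraceGroupTimes,
)

# Hypothetical metadata with a single two-timepoint trace group.
metadata = PotentialTraceMetadata(frozenset([
    TraceGroup(
        name=TraceGroupName("demo"),
        times=TraceGroupTimes(frozenset([TimepointFrom0(1), TimepointFrom0(2)])),
    ),
]))

match metadata.get_group_times(TraceGroupName("demo")):
    case option.Option(tag="some", some=group_times):
        # group_times is the TraceGroupTimes wrapper; the frozenset sits in .get.
        regional_timepoints = group_times.get
    case option.Option(tag="none", none=_):
        raise ValueError("No such trace group")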
85 changes: 50 additions & 35 deletions looptrace/trace_metadata.py
@@ -78,36 +78,42 @@ def trace_group_option_to_string(maybe_name: Option[TraceGroupName]) -> str:


@curry_flip(1)
def _is_all_of_type(xs: Iterable[_A], t: _T) -> bool:
return all(isinstance(x, t) for x in xs)
def _check_all_of_type(xs: Iterable[_A], t: _T) -> None:
for x in tee(xs, 1)[0]:
if not isinstance(x, t):
raise TypeError(f"First item not of type {t.__name__} is of type {type(x).__name__}")


def _is_homogeneous(items: Iterable[_A], t: Optional[_T] = None) -> bool:
xs = lambda: tee(items, 1)
def _check_homogeneous(items: Iterable[_A], t: Optional[_T] = None) -> None:
xs = tee(items, 1)[0] # Caution for if input is an iterator.
if t is None:
try:
t = type(next(xs))
except StopIteration:
return True
return _is_all_of_type(t)(xs)
# Empty collection is trivially homogeneous.
return
_check_all_of_type(t)(xs)
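The switch from predicate-style helpers (_is_all_of_type, _is_homogeneous returning bool) to check-style helpers that raise on the first offending element lets them sit directly in an attrs validator list, where return values are ignored and only exceptions signal failure. A small sketch of the intended behavior; _check_homogeneous is a module-private helper, imported here purely for illustration:

from gertils.types import TimepointFrom0
from looptrace.trace_metadata import _check_homogeneous

# Mixed-type input: the first element that isn't a TimepointFrom0 raises TypeError.
try:
    _check_homogeneous([TimepointFrom0(0), "not-a-timepoint"], TimepointFrom0)
except TypeError as error:
    print(f"rejected: {error}")

# Homogeneous input passes silently (returning None), as does an empty collection.
_check_homogeneous([TimepointFrom0(0), TimepointFrom0(1)], TimepointFrom0)
_check_homogeneous([])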


@attrs.define(frozen=True)
class TraceGroupTimes:
get = attrs.field(validator=[
attrs.validators.instance_of(frozenset),
attrs.validators.min_len(2),
lambda _1, _2, times: _is_homogeneous(times, TimepointFrom0),
lambda _1, _2, times: _check_homogeneous(times, TimepointFrom0),
]) # type: frozenset[TimepointFrom0]

def __iter__(self) -> Iterable[TimepointFrom0]:
return iter(self.get)

@classmethod
def from_list(cls, times: list[TimepointFrom0]) -> Result["TraceGroupTimes", str]:
match list(find_counts_of_repeats(times)):
case []:
@wrap_error_message("Trace group from times list")
@wrap_exception((TypeError, ValueError))
def safe_build(ts: list[TimepointFrom0]) -> Result["TraceGroupTimes", str]:
return cls(set(ts))
return cls(frozenset(ts))

return safe_build(times)
case repeated:
@@ -120,25 +126,30 @@ class TraceGroup:
times = attrs.field(validator=attrs.validators.instance_of(TraceGroupTimes)) # type: TraceGroupTimes


@dataclass(frozen=True, kw_only=True)
def _validate_trace_groups_content(groups: Iterable[TraceGroup]) -> None:
repeat_names = list(find_counts_of_repeats(g.name for g in groups))
if len(repeat_names) != 0:
raise ValueError(f"Repeated name(s) among trace groups; counts: {repeat_names}")
repeat_times = list(find_counts_of_repeats(t for g in groups for t in g.times))
if len(repeat_times) > 0:
raise ValueError(f"Repeated time(s) among trace groups; counts: {repeat_times}")


@attrs.define(frozen=True)
class PotentialTraceMetadata:
groups: frozenset[TraceGroup]

def __post_init__(self) -> None:
if not isinstance(self.groups, set):
raise TypeError(f"Wrapped object for potential trace metadata must be set, not {type(self.groups).__name__}")
if len(self.groups) == 0:
raise ValueError("Empty trace groups for trace metadata")
repeat_names = list(find_counts_of_repeats(g.name for g in self.groups))
if len(repeat_names) != 0:
raise ValueError(f"Repeated name(s) among trace groups; counts: {repeat_names}")
repeat_times = list(find_counts_of_repeats(t for g in self.groups for t in g.times))
if len(repeat_times) > 0:
raise ValueError(f"Repeated time(s) among trace groups; counts: {repeat_times}")

groups = attrs.field(validator=[
attrs.validators.instance_of(frozenset),
attrs.validators.min_len(1),
lambda _1, _2, values: _check_homogeneous(values, TraceGroup),
lambda _1, _2, values: _validate_trace_groups_content(values),
]) # type: frozenset[TraceGroup]
_times_by_group = attrs.field(init=False) # type: Mapping[TraceGroupName, frozenset[TimepointFrom0]]
_trace_group_name_by_times = attrs.field(init=False) # type: Mapping[frozenset[TimepointFrom0], TraceGroupName]

def __attrs_post_init__(self) -> None:
# Here, finally, we establish the data structures to back our own code's desired queries.
self._times_by_group: dict[TraceGroupName, frozenset[TimepointFrom0]] = {g.name: g.times for g in self.groups}
self._trace_group_name_by_times: Mapping[frozenset[TimepointFrom0], TraceGroupName] = {}
object.__setattr__(self, "_times_by_group", {g.name: g.times for g in self.groups})
object.__setattr__(self, "_trace_group_name_by_times", {})
for g in self.groups:
ts = g.times
try:
@@ -148,40 +159,44 @@ def __post_init__(self) -> None:
else:
raise ValueError(f"Already mapped times {ts} to name {name}; tried to re-map to {g.name}")

def get_group_times(self, group: TraceGroupName) -> Option[set[TimepointFrom0]]:
def get_group_times(self, group: TraceGroupName) -> Option[TraceGroupTimes]:
if not isinstance(group, TraceGroupName):
raise TypeError(f"Query isn't a {TraceGroupName.__name__}, but a {type(group).__name__}")
return Option.of_optional(self._times_by_group.get(group))

@classmethod
def from_mapping(cls, m: Mapping[str, object]) -> Result["PotentialTraceMetadata", list[str]]:
def proc1(key: str, value: object) -> Result[TraceGroup, Errors]:
name_result = read_trace_group_name(key)
group_result = parse_trace_group_times(value)
match name_result, group_result:
name_result: Result[TraceGroupName, str] = read_trace_group_name(key)
times_result: Result[TraceGroupTimes, str] = parse_trace_group_times(value)
match name_result, times_result:
case result.Result(tag="error", error=err_name), result.Result(tag="error", error=err_times):
return Result.error(Seq.of(err_name, err_times))
return Result.Error(Seq.of(err_name, err_times))
case _, _:
return name_result.map2(group_result, TraceGroup)
return name_result\
.map2(times_result, lambda name, times: TraceGroup(name=name, times=times))\
.map_error(Seq.of)

def combine(state: Result[set[TraceGroup], Errors], new_result: Result[TraceGroup, Errors]) -> Result[set[TraceGroup], Errors]:
match state, new_result:
case result.Result(tag="error", error=old_messages), result.Result(tag="error", error=new_messages):
return Result(Seq.of_iterable(concat(old_messages, new_messages)))
return Result.Error(Seq.of_iterable(concat(old_messages, new_messages)))
case _, _:
return state.map2(new_result, lambda groups, new_group: groups + new_group)
return state.map2(new_result, lambda groups, new_group: {new_group, *groups})

def step(acc: Result[set[TraceGroup], Errors], kv: tuple[str, object]) -> Result[set[TraceGroup], Errors]:
return combine(state=acc, new_result=proc1(*kv))

return Seq.of_iterable(m.items())\
.fold(step, Result.Ok(set()))\
.map(compose(frozenset, cls))\
.map_error(lambda es: es.to_list())
.map_error(lambda errors: errors.to_list())


def _check_homogeneous_list(t: _T, *, attr_name: str, xs: Any) -> None:
match xs:
case list():
if not _is_all_of_type(t)(xs):
if not _check_all_of_type(t)(xs):
raise TypeError(f"Not all values for attribute '{attr_name}' are of type {t.__name__}")
case _:
raise TypeError(f"Value for attribute {attr_name} isn't list, but {type(xs).__name__}")
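Taken together, the trace_metadata.py changes move PotentialTraceMetadata from a dataclass with a hand-rolled __post_init__ to an attrs class: the structural requirements (a frozenset of TraceGroup, non-empty, no repeated names or timepoints) become field validators, and the lookup tables are built in __attrs_post_init__ via object.__setattr__ because the instance is frozen. from_mapping accumulates parse errors across all entries rather than stopping at the first. A usage sketch with invented group names and timepoints, mirroring the new tests:

from expression import result
from looptrace.trace_metadata import PotentialTraceMetadata, TraceGroupName

# Each key names a trace group; each value lists its regional timepoints.
raw_groups = {"g1": [1, 2], "g2": [3, 4]}

match PotentialTraceMetadata.from_mapping(raw_groups):
    case result.Result(tag="ok", ok=metadata):
        maybe_times = metadata.get_group_times(TraceGroupName("g1"))  # Option[TraceGroupTimes]
    case result.Result(tag="error", error=messages):
        # Every problem across the mapping is collected, not just the first.
        raise ValueError(f"{len(messages)} problem(s) parsing trace groups: {messages}")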
81 changes: 81 additions & 0 deletions tests/test_potential_trace_metadata.py
@@ -0,0 +1,81 @@
"""Tests for the PotentialTraceMetadata data type, which stores the merge rules for tracing"""

from typing import Callable, Iterable

from expression import Option, result
import pytest

from gertils.types import TimepointFrom0
from looptrace.trace_metadata import PotentialTraceMetadata, TraceGroup, TraceGroupName, TraceGroupTimes


def build_times(ts: Iterable[int]) -> TraceGroupTimes:
return TraceGroupTimes(frozenset(map(TimepointFrom0, ts)))


GROUP_NAME: TraceGroupName = TraceGroupName("dummy")

TIMES: TraceGroupTimes = build_times([1, 2])

DUMMY_GROUP: TraceGroup = TraceGroup(name=GROUP_NAME, times=TIMES)


@pytest.mark.parametrize("wrap", [set, frozenset])
def test_groups_must_be_frozenset(wrap):
groups = wrap([DUMMY_GROUP])
if wrap == set:
with pytest.raises(TypeError):
PotentialTraceMetadata(groups)
elif wrap == frozenset:
md: PotentialTraceMetadata = PotentialTraceMetadata(groups)
assert list(md.groups) == [DUMMY_GROUP]
else:
pytest.fail(f"Unexpected wrapper: {wrap}")


def test_groups_must_be_nonempty():
with pytest.raises(ValueError):
PotentialTraceMetadata(frozenset())


def test_group_name_repetition_is_prohibited():
with pytest.raises(ValueError) as error_context:
PotentialTraceMetadata(frozenset([
TraceGroup(name=GROUP_NAME, times=build_times([1, 2])),
TraceGroup(name=GROUP_NAME, times=build_times([3, 4])),
]))
assert "Repeated name(s) among trace groups" in str(error_context)


def test_timepoint_repetition_is_prohibited():
with pytest.raises(ValueError) as error_context:
PotentialTraceMetadata(frozenset([
TraceGroup(name=TraceGroupName("a"), times=build_times([1, 2])),
TraceGroup(name=TraceGroupName("b"), times=build_times([2, 3])),
]))
assert "Repeated time(s) among trace groups" in str(error_context)


@pytest.mark.parametrize(
["arg", "expected"],
[
(TraceGroupName(GROUP_NAME.get.upper()), Option.Nothing()),
(GROUP_NAME, Option.Some(TIMES))
],
)
def test_get_group_times_is_correct(arg, expected):
md: PotentialTraceMetadata = PotentialTraceMetadata(frozenset([DUMMY_GROUP]))
assert md.get_group_times(arg) == expected


def test_roundtrip_through_mapping():
lift_times: Callable[[list[int]], list[TimepointFrom0]] = lambda ts: [TimepointFrom0(t) for t in ts]
a_times: list[int] = [1, 2]
b_times: list[int] = [3, 4]
match PotentialTraceMetadata.from_mapping({"A": a_times, "B": b_times}):
case result.Result(tag="ok", ok=md):
groups: list[TraceGroup] = list(sorted(md.groups, key=lambda g: g.name))
assert [g.name.get for g in groups] == ["A", "B"]
assert [list(sorted(g.times)) for g in groups] == [lift_times(a_times), lift_times(b_times)]
case result.Result(tag="error", error=messages):
pytest.fail(f"{len(messages)} problem(s) building potential trace metadata: {messages}")
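The new tests cover exactly these constraints (frozenset-only wrapping, non-emptiness, no repeated names or timepoints, and the mapping round trip). Assuming the repository's usual pytest setup, the module can be run on its own with something like:

pytest tests/test_potential_trace_metadata.py -v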
