Skip to content

Commit

Permalink
feat: Add MetricWriter class
Browse files Browse the repository at this point in the history
wip

wip

wip

refactor: reorder

refactor: types

fix: PR suggestions

fix: typeerror

fix: format values

chore: typeguard

refactor: make assertions private

fix: Union and Optional to support 3.8,3.9

doc: update dataclass/Metric references

fix: use Type to support 3.8

fix: use List to support 3.8
  • Loading branch information
msto committed Oct 17, 2024
1 parent d25ff12 commit 5ad5dfa
Show file tree
Hide file tree
Showing 2 changed files with 371 additions and 8 deletions.
197 changes: 189 additions & 8 deletions fgpyo/util/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,19 +119,25 @@
import dataclasses
import sys
from abc import ABC
from contextlib import AbstractContextManager
from csv import DictWriter
from dataclasses import dataclass
from enum import Enum
from inspect import isclass
from io import TextIOWrapper
from pathlib import Path
from types import TracebackType
from typing import Any
from typing import Callable
from typing import Dict
from typing import Generic
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import Type
from typing import TypeVar
from typing import Union

if sys.version_info >= (3, 10):
from typing import TypeGuard
Expand Down Expand Up @@ -479,6 +485,140 @@ def asdict(metric: Metric) -> Dict[str, Any]:
)


class MetricWriter(Generic[MetricType], AbstractContextManager):
    """
    Write instances of a `Metric` class to a delimited text file, one row per instance.

    The output fields (and their order) are fixed when the writer is constructed. Unless the
    writer is opened in append mode, a header row is written immediately on construction.

    The writer is a context manager; the underlying file handle is closed on exit:

        with MetricWriter(filename="out.tsv", metric_class=MyMetric) as writer:
            writer.write(my_metric)
    """

    # The Metric class whose instances this writer accepts.
    _metric_class: Type[Metric]
    # The names of the fields written to file, in output order.
    _fieldnames: List[str]
    # The open handle of the output file.
    _fout: TextIOWrapper
    # The `csv.DictWriter` that serializes each row to `_fout`.
    _writer: DictWriter

    def __init__(
        self,
        filename: Union[Path, str],
        metric_class: Type[Metric],
        append: bool = False,
        delimiter: str = "\t",
        include_fields: Optional[List[str]] = None,
        exclude_fields: Optional[List[str]] = None,
    ) -> None:
        """
        Args:
            filename: Path to the file to write.
            metric_class: Metric class.
            append: If `True`, the file will be appended to. Otherwise, the specified file will be
                overwritten.
            delimiter: The output file delimiter.
            include_fields: If specified, only the listed fieldnames will be included when writing
                records to file. Fields will be written in the order provided.
                May not be used together with `exclude_fields`.
            exclude_fields: If specified, any listed fieldnames will be excluded when writing
                records to file.
                May not be used together with `include_fields`.

        Raises:
            TypeError: If the provided metric class is not a dataclass- or attr-decorated
                subclass of `Metric`.
            AssertionError: If the provided filepath is not writable.
            AssertionError: If `append=True` and the provided file is not readable. (When
                appending, we check to ensure that the header matches the specified metric class.
                The file must be readable to get the header.)
            ValueError: If both `include_fields` and `exclude_fields` are specified.
            ValueError: If `append=True` and the provided file does not include a header.
            ValueError: If `append=True` and the header of the provided file does not match the
                specified metric class and the specified include/exclude fields.
        """
        filepath: Path = Path(filename)
        ordered_fieldnames: List[str] = _validate_output_fieldnames(
            metric_class=metric_class,
            include_fields=include_fields,
            exclude_fields=exclude_fields,
        )

        _assert_is_metric_class(metric_class)
        io.assert_path_is_writeable(filepath)
        if append:
            # The existing header must be read back and compared before anything is appended.
            io.assert_path_is_readable(filepath)
            _assert_file_header_matches_metric(
                path=filepath,
                metric_class=metric_class,
                ordered_fieldnames=ordered_fieldnames,
                delimiter=delimiter,
            )

        self._metric_class = metric_class
        self._fieldnames = ordered_fieldnames
        self._fout = io.to_writer(filepath, append=append)

        # From here on a file handle is open: if setting up the `DictWriter` or writing the header
        # fails (e.g. an invalid delimiter), close the handle before propagating the error.
        try:
            self._writer = DictWriter(
                f=self._fout,
                fieldnames=self._fieldnames,
                delimiter=delimiter,
            )

            # If we aren't appending to an existing file, write the header before any rows
            if not append:
                self._writer.writeheader()
        except Exception:
            self._fout.close()
            raise

    def __enter__(self) -> "MetricWriter":
        """Return the writer itself so it can be bound by a `with` statement."""
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        """
        Close the underlying file handle when leaving the `with` block.

        All three arguments are `None` when the block exits without an exception. Nothing truthy
        is returned, so an exception raised inside the block is not suppressed.
        """
        self.close()
        super().__exit__(exc_type, exc_value, traceback)

    def close(self) -> None:
        """Close the underlying file handle."""
        self._fout.close()

    def write(self, metric: Metric) -> None:
        """
        Write a single Metric instance to file.

        The Metric is converted to a dictionary and then written using the underlying
        `csv.DictWriter`. If the `MetricWriter` was created using the `include_fields` or
        `exclude_fields` arguments, the fields of the Metric are subset and/or reordered
        accordingly before writing.

        Args:
            metric: An instance of the specified Metric.

        Raises:
            TypeError: If the provided `metric` is not an instance of the Metric class used to
                parametrize the writer.
        """
        if not isinstance(metric, self._metric_class):
            raise TypeError(f"Must provide instances of {self._metric_class.__name__}")

        # Serialize the Metric to a dict for writing by the underlying `DictWriter`
        values = asdict(metric)

        # Keep only the output fields, in output order, formatting each value as it is selected
        row = {
            fieldname: self._metric_class.format_value(values[fieldname])
            for fieldname in self._fieldnames
        }

        self._writer.writerow(row)

    def writeall(self, metrics: Iterable[Metric]) -> None:
        """
        Write multiple Metric instances to file.

        Each Metric is written with `write()`, in iteration order, so the same type check,
        field subsetting/reordering, and value formatting apply to every instance.

        Args:
            metrics: An iterable of instances of the specified Metric.
        """
        for metric in metrics:
            self.write(metric)


def _get_fieldnames(metric_class: Type[Metric]) -> List[str]:
"""
Get the fieldnames of the specified metric class.
Expand All @@ -499,10 +639,47 @@ def _get_fieldnames(metric_class: Type[Metric]) -> List[str]:
assert False, "Unreachable"


def _validate_output_fieldnames(
    metric_class: Type[MetricType],
    include_fields: Optional[List[str]] = None,
    exclude_fields: Optional[List[str]] = None,
) -> List[str]:
    """
    Subset and/or re-order the Metric's fieldnames based on the specified include/exclude lists.

    * Only one of `include_fields` and `exclude_fields` may be specified.
    * All fieldnames specified in `include_fields` must be fields on `metric_class`. If this
      argument is specified, fields will be returned in the order they appear in the list.
    * All fieldnames specified in `exclude_fields` must be fields on `metric_class`. (This is
      technically unnecessary, but is a safeguard against passing an incorrect list.)
    * If neither `include_fields` or `exclude_fields` are specified, return the `metric_class`'s
      fieldnames, in the order they are defined on the `metric_class`.

    Args:
        metric_class: The `Metric` class whose fieldnames are being selected.
        include_fields: The fieldnames to keep, in the desired output order.
        exclude_fields: The fieldnames to drop. The remaining fields keep the order in which
            `_get_fieldnames()` reports them.

    Returns:
        The output fieldnames, in output order. When `include_fields` is given, the result is a
        new list, so later changes to the caller's list do not affect it.

    Raises:
        ValueError: If both `include_fields` and `exclude_fields` are specified.

    Any error raised by `_assert_fieldnames_are_metric_attributes()` for an unknown fieldname
    propagates unchanged.
    """
    if include_fields is not None and exclude_fields is not None:
        raise ValueError(
            "Only one of `include_fields` and `exclude_fields` may be specified, not both."
        )

    if exclude_fields is not None:
        _assert_fieldnames_are_metric_attributes(exclude_fields, metric_class)
        # Build the lookup once rather than scanning the exclude list for every field.
        excluded = set(exclude_fields)
        return [f for f in _get_fieldnames(metric_class) if f not in excluded]

    if include_fields is not None:
        _assert_fieldnames_are_metric_attributes(include_fields, metric_class)
        # Copy so the returned list is not an alias of the caller's argument.
        return list(include_fields)

    return _get_fieldnames(metric_class)


def _assert_file_header_matches_metric(
path: Path,
metric_class: Type[MetricType],
delimiter: str,
ordered_fieldnames: Optional[List[str]] = None,
) -> None:
"""
Check that the specified file has a header and its fields match those of the provided Metric.
Expand All @@ -511,20 +688,24 @@ def _assert_file_header_matches_metric(
path: A path to a `Metric` file.
metric_class: The `Metric` class to validate against.
delimiter: The delimiter to use when reading the header.
ordered_fieldnames: An optional ordering of the fieldnames in the header.
Raises:
ValueError: If the provided file does not include a header.
ValueError: If the header of the provided file does not match the provided Metric.
ValueError: If the header of the provided file does not match the provided Metric (or list
of ordered fieldnames, if provided).
"""
# NB: _get_fieldnames() will validate that `metric_class` is a valid Metric class.
fieldnames: List[str] = _get_fieldnames(metric_class)
_assert_is_metric_class(metric_class)

header: MetricFileHeader
with path.open("r") as fin:
try:
header = metric_class._read_header(fin, delimiter=delimiter)
except ValueError:
raise ValueError(f"Could not find a header in the provided file: {path}")
header: MetricFileHeader = metric_class._read_header(fin, delimiter=delimiter)

if header is None:
raise ValueError(f"Could not find a header in the provided file: {path}")

fieldnames: List[str] = (
ordered_fieldnames if ordered_fieldnames is not None else _get_fieldnames(metric_class)
)

if header.fieldnames != fieldnames:
raise ValueError(
Expand Down
Loading

0 comments on commit 5ad5dfa

Please sign in to comment.