From 5ad5dfa2a7a84038ec40fc3ab5b9a34d8d6b18c3 Mon Sep 17 00:00:00 2001 From: Matt Stone Date: Sun, 5 May 2024 15:51:30 -0400 Subject: [PATCH] feat: Add MetricWriter class wip wip wip refactor: reorder refactor: types fix: PR suggestions fix: typeerror fix: format values chore: typeguard refactor: make assertions private fix: Union and Optional to support 3.8,3.9 doc: update dataclass/Metric references fix: use Type to support 3.8 fix: use List to support 3.8 --- fgpyo/util/metric.py | 197 ++++++++++++++++++++++++++++++-- fgpyo/util/tests/test_metric.py | 182 +++++++++++++++++++++++++++++ 2 files changed, 371 insertions(+), 8 deletions(-) diff --git a/fgpyo/util/metric.py b/fgpyo/util/metric.py index 203090e5..a0ed7ab1 100644 --- a/fgpyo/util/metric.py +++ b/fgpyo/util/metric.py @@ -119,19 +119,25 @@ import dataclasses import sys from abc import ABC +from contextlib import AbstractContextManager +from csv import DictWriter from dataclasses import dataclass from enum import Enum from inspect import isclass from io import TextIOWrapper from pathlib import Path +from types import TracebackType from typing import Any from typing import Callable from typing import Dict from typing import Generic +from typing import Iterable from typing import Iterator from typing import List +from typing import Optional from typing import Type from typing import TypeVar +from typing import Union if sys.version_info >= (3, 10): from typing import TypeGuard @@ -479,6 +485,140 @@ def asdict(metric: Metric) -> Dict[str, Any]: ) +class MetricWriter(Generic[MetricType], AbstractContextManager): + _metric_class: Type[Metric] + _fieldnames: List[str] + _fout: TextIOWrapper + _writer: DictWriter + + def __init__( + self, + filename: Union[Path, str], + metric_class: Type[Metric], + append: bool = False, + delimiter: str = "\t", + include_fields: Optional[List[str]] = None, + exclude_fields: Optional[List[str]] = None, + ) -> None: + """ + Args: + path: Path to the file to write. + metric_class: Metric class. + append: If `True`, the file will be appended to. Otherwise, the specified file will be + overwritten. + delimiter: The output file delimiter. + include_fields: If specified, only the listed fieldnames will be included when writing + records to file. Fields will be written in the order provided. + May not be used together with `exclude_fields`. + exclude_fields: If specified, any listed fieldnames will be excluded when writing + records to file. + May not be used together with `include_fields`. + + Raises: + TypeError: If the provided metric class is not a dataclass- or attr-decorated + subclass of `Metric`. + AssertionError: If the provided filepath is not writable. + AssertionError: If `append=True` and the provided file is not readable. (When appending, + we check to ensure that the header matches the specified metric class. The file must + be readable to get the header.) + ValueError: If `append=True` and the provided file does not include a header. + ValueError: If `append=True` and the header of the provided file does not match the + specified metric class and the specified include/exclude fields. + """ + + filepath: Path = Path(filename) + ordered_fieldnames: List[str] = _validate_output_fieldnames( + metric_class=metric_class, + include_fields=include_fields, + exclude_fields=exclude_fields, + ) + + _assert_is_metric_class(metric_class) + io.assert_path_is_writeable(filepath) + if append: + io.assert_path_is_readable(filepath) + _assert_file_header_matches_metric( + path=filepath, + metric_class=metric_class, + ordered_fieldnames=ordered_fieldnames, + delimiter=delimiter, + ) + + self._metric_class = metric_class + self._fieldnames = ordered_fieldnames + self._fout = io.to_writer(filepath, append=append) + self._writer = DictWriter( + f=self._fout, + fieldnames=self._fieldnames, + delimiter=delimiter, + ) + + # If we aren't appending to an existing file, write the header before any rows + if not append: + self._writer.writeheader() + + def __enter__(self) -> "MetricWriter": + return self + + def __exit__( + self, + exc_type: Type[BaseException], + exc_value: BaseException, + traceback: TracebackType, + ) -> None: + self.close() + super().__exit__(exc_type, exc_value, traceback) + + def close(self) -> None: + """Close the underlying file handle.""" + self._fout.close() + + def write(self, metric: Metric) -> None: + """ + Write a single Metric instance to file. + + The Metric is converted to a dictionary and then written using the underlying + `csv.DictWriter`. If the `MetricWriter` was created using the `include_fields` or + `exclude_fields` arguments, the fields of the Metric are subset and/or reordered + accordingly before writing. + + Args: + metric: An instance of the specified Metric. + + Raises: + TypeError: If the provided `metric` is not an instance of the Metric class used to + parametrize the writer. + """ + if not isinstance(metric, self._metric_class): + raise TypeError(f"Must provide instances of {self._metric_class.__name__}") + + # Serialize the Metric to a dict for writing by the underlying `DictWriter` + row = asdict(metric) + + # Filter and/or re-order output fields if necessary + row = {fieldname: row[fieldname] for fieldname in self._fieldnames} + + # Format values + row = {fieldname: self._metric_class.format_value(val) for fieldname, val in row.items()} + + self._writer.writerow(row) + + def writeall(self, metrics: Iterable[Metric]) -> None: + """ + Write multiple Metric instances to file. + + Each Metric is converted to a dictionary and then written using the underlying + `csv.DictWriter`. If the `MetricWriter` was created using the `include_fields` or + `exclude_fields` arguments, the attributes of each Metric are subset and/or reordered + accordingly before writing. + + Args: + metrics: A sequence of instances of the specified Metric. + """ + for metric in metrics: + self.write(metric) + + def _get_fieldnames(metric_class: Type[Metric]) -> List[str]: """ Get the fieldnames of the specified metric class. @@ -499,10 +639,47 @@ def _get_fieldnames(metric_class: Type[Metric]) -> List[str]: assert False, "Unreachable" +def _validate_output_fieldnames( + metric_class: Type[MetricType], + include_fields: Optional[List[str]] = None, + exclude_fields: Optional[List[str]] = None, +) -> List[str]: + """ + Subset and/or re-order the Metric's fieldnames based on the specified include/exclude lists. + + * Only one of `include_fields` and `exclude_fields` may be specified. + * All fieldnames specified in `include_fields` must be fields on `metric_class`. If this + argument is specified, fields will be returned in the order they appear in the list. + * All fieldnames specified in `exclude_fields` must be fields on `metric_class`. (This is + technically unnecessary, but is a safeguard against passing an incorrect list.) + * If neither `include_fields` or `exclude_fields` are specified, return the `metric_class`'s + fieldnames, in the order they are defined on the `metric_class`. + + Raises: + ValueError: If both `include_fields` and `exclude_fields` are specified. + """ + + if include_fields is not None and exclude_fields is not None: + raise ValueError( + "Only one of `include_fields` and `exclude_fields` may be specified, not both." + ) + elif exclude_fields is not None: + _assert_fieldnames_are_metric_attributes(exclude_fields, metric_class) + output_fieldnames = [f for f in _get_fieldnames(metric_class) if f not in exclude_fields] + elif include_fields is not None: + _assert_fieldnames_are_metric_attributes(include_fields, metric_class) + output_fieldnames = include_fields + else: + output_fieldnames = _get_fieldnames(metric_class) + + return output_fieldnames + + def _assert_file_header_matches_metric( path: Path, metric_class: Type[MetricType], delimiter: str, + ordered_fieldnames: Optional[List[str]] = None, ) -> None: """ Check that the specified file has a header and its fields match those of the provided Metric. @@ -511,20 +688,24 @@ def _assert_file_header_matches_metric( path: A path to a `Metric` file. metric_class: The `Metric` class to validate against. delimiter: The delimiter to use when reading the header. + ordered_fieldnames: An optional ordering of the fieldnames in the header. Raises: ValueError: If the provided file does not include a header. - ValueError: If the header of the provided file does not match the provided Metric. + ValueError: If the header of the provided file does not match the provided Metric (or list + of ordered fieldnames, if provided). """ - # NB: _get_fieldnames() will validate that `metric_class` is a valid Metric class. - fieldnames: List[str] = _get_fieldnames(metric_class) + _assert_is_metric_class(metric_class) - header: MetricFileHeader with path.open("r") as fin: - try: - header = metric_class._read_header(fin, delimiter=delimiter) - except ValueError: - raise ValueError(f"Could not find a header in the provided file: {path}") + header: MetricFileHeader = metric_class._read_header(fin, delimiter=delimiter) + + if header is None: + raise ValueError(f"Could not find a header in the provided file: {path}") + + fieldnames: List[str] = ( + ordered_fieldnames if ordered_fieldnames is not None else _get_fieldnames(metric_class) + ) if header.fieldnames != fieldnames: raise ValueError( diff --git a/fgpyo/util/tests/test_metric.py b/fgpyo/util/tests/test_metric.py index 05be06e8..2c0be846 100644 --- a/fgpyo/util/tests/test_metric.py +++ b/fgpyo/util/tests/test_metric.py @@ -32,6 +32,7 @@ from fgpyo.util.metric import Metric from fgpyo.util.metric import _assert_fieldnames_are_metric_attributes from fgpyo.util.metric import _assert_file_header_matches_metric +from fgpyo.util.metric import MetricWriter from fgpyo.util.metric import _assert_is_metric_class from fgpyo.util.metric import _get_fieldnames from fgpyo.util.metric import _is_attrs_instance @@ -597,6 +598,187 @@ def test_read_header_can_read_picard(tmp_path: Path) -> None: assert header.fieldnames == ["SAMPLE", "FOO", "BAR"] +@dataclass +class FakeMetric(Metric["FakeMetric"]): + foo: str + bar: int + + +def test_writer(tmp_path: Path) -> None: + fpath = tmp_path / "test.txt" + + with MetricWriter(filename=fpath, append=False, metric_class=FakeMetric) as writer: + writer.write(FakeMetric(foo="abc", bar=1)) + writer.write(FakeMetric(foo="def", bar=2)) + + with fpath.open("r") as f: + assert next(f) == "foo\tbar\n" + assert next(f) == "abc\t1\n" + assert next(f) == "def\t2\n" + with pytest.raises(StopIteration): + next(f) + + +def test_writer_from_str(tmp_path: Path) -> None: + """Test that we can create a writer when `filename` is a `str`.""" + fpath = tmp_path / "test.txt" + + with MetricWriter(filename=str(fpath), append=False, metric_class=FakeMetric) as writer: + writer.write(FakeMetric(foo="abc", bar=1)) + + with fpath.open("r") as f: + assert next(f) == "foo\tbar\n" + assert next(f) == "abc\t1\n" + with pytest.raises(StopIteration): + next(f) + + +def test_writer_writeall(tmp_path: Path) -> None: + fpath = tmp_path / "test.txt" + + data = [ + FakeMetric(foo="abc", bar=1), + FakeMetric(foo="def", bar=2), + ] + with MetricWriter(filename=fpath, append=False, metric_class=FakeMetric) as writer: + writer.writeall(data) + + with fpath.open("r") as f: + assert next(f) == "foo\tbar\n" + assert next(f) == "abc\t1\n" + assert next(f) == "def\t2\n" + with pytest.raises(StopIteration): + next(f) + + +def test_writer_append(tmp_path: Path) -> None: + """Test that we can append to a file.""" + fpath = tmp_path / "test.txt" + + with fpath.open("w") as fout: + fout.write("foo\tbar\n") + + with MetricWriter(filename=fpath, append=True, metric_class=FakeMetric) as writer: + writer.write(FakeMetric(foo="abc", bar=1)) + writer.write(FakeMetric(foo="def", bar=2)) + + with fpath.open("r") as f: + assert next(f) == "foo\tbar\n" + assert next(f) == "abc\t1\n" + assert next(f) == "def\t2\n" + with pytest.raises(StopIteration): + next(f) + + +def test_writer_append_raises_if_empty(tmp_path: Path) -> None: + """Test that we raise an error if we try to append to an empty file.""" + fpath = tmp_path / "test.txt" + fpath.touch() + + with pytest.raises(ValueError, match=f"File {fpath} did not contain a header line"): + with MetricWriter(filename=fpath, append=True, metric_class=FakeMetric) as writer: + writer.write(FakeMetric(foo="abc", bar=1)) + + +def test_writer_append_raises_if_no_header(tmp_path: Path) -> None: + """Test that we raise an error if we try to append to a file with no header.""" + fpath = tmp_path / "test.txt" + with fpath.open("w") as fout: + fout.write("abc\t1\n") + + with pytest.raises(ValueError, match="The provided file does not have the same field names"): + with MetricWriter(filename=fpath, append=True, metric_class=FakeMetric) as writer: + writer.write(FakeMetric(foo="abc", bar=1)) + + +def test_writer_append_raises_if_header_does_not_match(tmp_path: Path) -> None: + """ + Test that we raise an error if we try to append to a file whose header doesn't match our + dataclass. + """ + fpath = tmp_path / "test.txt" + + with fpath.open("w") as fout: + fout.write("foo\tbar\tbaz\n") + + with pytest.raises(ValueError, match="The provided file does not have the same field names"): + with MetricWriter(filename=fpath, append=True, metric_class=FakeMetric) as writer: + writer.write(FakeMetric(foo="abc", bar=1)) + + +def test_writer_include_fields(tmp_path: Path) -> None: + """Test that we can include only a subset of fields.""" + fpath = tmp_path / "test.txt" + + data = [ + FakeMetric(foo="abc", bar=1), + FakeMetric(foo="def", bar=2), + ] + with MetricWriter( + filename=fpath, + append=False, + metric_class=FakeMetric, + include_fields=["foo"], + ) as writer: + writer.writeall(data) + + with fpath.open("r") as f: + assert next(f) == "foo\n" + assert next(f) == "abc\n" + assert next(f) == "def\n" + with pytest.raises(StopIteration): + next(f) + + +def test_writer_include_fields_reorders(tmp_path: Path) -> None: + """Test that we can reorder the output fields.""" + fpath = tmp_path / "test.txt" + + data = [ + FakeMetric(foo="abc", bar=1), + FakeMetric(foo="def", bar=2), + ] + with MetricWriter( + filename=fpath, + append=False, + metric_class=FakeMetric, + include_fields=["bar", "foo"], + ) as writer: + writer.writeall(data) + + with fpath.open("r") as f: + assert next(f) == "bar\tfoo\n" + assert next(f) == "1\tabc\n" + assert next(f) == "2\tdef\n" + with pytest.raises(StopIteration): + next(f) + + +def test_writer_exclude_fields(tmp_path: Path) -> None: + """Test that we can exclude fields from being written.""" + + fpath = tmp_path / "test.txt" + + data = [ + FakeMetric(foo="abc", bar=1), + FakeMetric(foo="def", bar=2), + ] + with MetricWriter( + filename=fpath, + append=False, + metric_class=FakeMetric, + exclude_fields=["bar"], + ) as writer: + writer.writeall(data) + + with fpath.open("r") as f: + assert next(f) == "foo\n" + assert next(f) == "abc\n" + assert next(f) == "def\n" + with pytest.raises(StopIteration): + next(f) + + @pytest.mark.parametrize("data_and_classes", (attr_data_and_classes, dataclasses_data_and_classes)) def test_get_fieldnames(data_and_classes: DataBuilder) -> None: """Test we can get the fieldnames of a metric."""