Skip to content

Commit

Permalink
feat: add CorrelationWriter (#96)
Browse files Browse the repository at this point in the history
Created CorrelationWriter class for the analyze portion of the pipeline.

I think there will eventually be a FeatureSetWriter that this should
probably inherit from, but I need the Correlation one now for Aerts.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Introduced a `CorrelationWriter` class for managing the writing of
correlation data to files with customizable paths and filenames.
- Added support for saving correlation data in both CSV and Excel
formats.

- **Bug Fixes**
- Implemented error handling for invalid correlation data, existing
files, and filename format validation.

- **Tests**
- Added a comprehensive suite of unit tests for the `CorrelationWriter`
class, covering various scenarios for saving correlation data.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Jermiah Joseph <[email protected]>
  • Loading branch information
strixy16 and jjjermiah authored Dec 18, 2024
1 parent c27d40f commit b241c42
Show file tree
Hide file tree
Showing 6 changed files with 564 additions and 136 deletions.
510 changes: 374 additions & 136 deletions pixi.lock

Large diffs are not rendered by default.

131 changes: 131 additions & 0 deletions src/readii/io/writers/correlation_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
from dataclasses import dataclass, field
from pathlib import Path
from typing import ClassVar

from pandas import DataFrame

from readii.io.writers.base_writer import BaseWriter
from readii.utils import logger


class CorrelationWriterError(Exception):
    """Root of the CorrelationWriter exception hierarchy.

    Catch this type to handle any error raised by CorrelationWriter.
    """


class CorrelationWriterIOError(CorrelationWriterError):
    """Signal that reading from or writing to the filesystem failed."""


class CorrelationWriterValidationError(CorrelationWriterError):
    """Signal that the writer configuration or its input failed validation."""

@dataclass
class CorrelationWriter(BaseWriter):
"""Class for managing file writing with customizable paths and filenames for plot figure files."""

overwrite: bool = field(
default=False,
metadata={
"help": "If True, allows overwriting existing files. If False, raises CorrelationWriterIOError."
},
)

# Make extensions immutable
VALID_EXTENSIONS: ClassVar[list[str]] = (
".csv",
".xlsx"
)

def __post_init__(self) -> None:
"""Validate writer configuration."""
super().__post_init__()

if not any(self.filename_format.endswith(ext) for ext in self.VALID_EXTENSIONS):
msg = f"Invalid filename format {self.filename_format}. Must end with one of {self.VALID_EXTENSIONS}."
raise CorrelationWriterValidationError(msg)

def save(self, correlation_df:DataFrame, **kwargs: str) -> Path:
"""Save the correlation dataframe to a .csv file.
Parameters
----------
correlation_df : DataFrame
The correlation dataframe to save.
**kwargs : str
Additional keyword arguments to pass to the filename format.
Returns
-------
Path
The path to the saved file.
Raises
------
CorrelationWriterIOError
If an error occurs during file writing.
CorrelationWriterValidationError
If the filename format is invalid.
"""
logger.debug("Saving.", kwargs=kwargs)

# Generate the output path
out_path = self.resolve_path(**kwargs)

# Check if the output path already exists
if out_path.exists():
if not self.overwrite:
msg = f"File {out_path} already exists. \nSet {self.__class__.__name__}.overwrite to True to overwrite."
raise CorrelationWriterIOError(msg)
else:
logger.warning(f"File {out_path} already exists. Overwriting.")

# Check if the correlation dataframe is a DataFrame
if not isinstance(correlation_df, DataFrame):
msg = f"Correlation dataframe must be a pandas DataFrame, got {type(correlation_df)}"
raise CorrelationWriterValidationError(msg)

# Check if the correlation dataframe is empty
if correlation_df.empty:
msg = "Correlation dataframe is empty"
raise CorrelationWriterValidationError(msg)

# Check that the columns and index of the correlation dataframe are the same
if not correlation_df.columns.equals(correlation_df.index):
msg = "Correlation dataframe columns and index are not the same"
raise CorrelationWriterValidationError(msg)

logger.debug("Saving correlation dataframe to file", out_path=out_path)
try:
match out_path.suffix:
case ".csv":
correlation_df.to_csv(out_path, index=True, index_label="")
case ".xlsx":
correlation_df.to_excel(out_path, index=True, index_label="")
case _:
msg = f"Invalid file extension {out_path.suffix}. Must be one of {self.VALID_EXTENSIONS}."
raise CorrelationWriterValidationError(msg)
except Exception as e:
msg = f"Error saving correlation dataframe to file {out_path}: {e}"
raise CorrelationWriterIOError(msg) from e
else:
logger.info("Correlation dataframe saved successfully.", out_path=out_path)
return out_path


if __name__ == "__main__":  # pragma: no cover
    # Minimal manual demo: build a writer and pretty-print its configuration.
    from rich import print  # noqa

    example_root = Path("TRASH", "correlation_writer_examples")
    corr_writer = CorrelationWriter(
        root_directory=example_root,
        filename_format="{DatasetName}_{VerticalFeatureType}_{HorizontalFeatureType}_{CorrelationType}_correlations.csv",
        overwrite=True,
        create_dirs=True,
    )

    print(corr_writer)
File renamed without changes.
File renamed without changes.
59 changes: 59 additions & 0 deletions tests/io/writers/test_correlation_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import pytest

import pandas as pd
import numpy as np
from pathlib import Path
from readii.analyze.correlation import getFeatureCorrelations
from readii.io.writers.correlation_writer import CorrelationWriter, CorrelationWriterValidationError, CorrelationWriterError, CorrelationWriterIOError # type: ignore

@pytest.fixture
def random_feature_correlations():
    """Return feature correlations computed from a seeded random 10x10 table.

    The table's columns are named ``feature_1`` through ``feature_10`` and the
    same table is passed as both arguments to ``getFeatureCorrelations``.
    """
    # A fixed seed keeps the generated values reproducible across test runs.
    rng = np.random.default_rng(seed=10)
    feature_names = [f"feature_{number}" for number in range(1, 11)]
    feature_table = pd.DataFrame(rng.random((10, 10)), columns=feature_names)
    return getFeatureCorrelations(feature_table, feature_table)


@pytest.fixture
def corr_writer(tmp_path):
    """Provide a CorrelationWriter rooted at pytest's temporary directory.

    The writer targets CSV output and starts with overwriting disabled.
    """
    writer_options = {
        "root_directory": tmp_path,
        "filename_format": "{CorrelationType}_correlation_matrix.csv",
        "overwrite": False,
        "create_dirs": True,
    }
    return CorrelationWriter(**writer_options)

@pytest.mark.parametrize(
    "correlation_df",
    [
        "not_a_correlation_df",
        12345,
        pd.DataFrame(),
    ],
)
def test_save_invalid_correlation(corr_writer, correlation_df):
    """Saving a non-DataFrame or an empty DataFrame raises a validation error."""
    expected_error = CorrelationWriterValidationError
    with pytest.raises(expected_error):
        corr_writer.save(correlation_df, CorrelationType="Pearson")

@pytest.mark.parametrize(
    "correlation_df",
    ["random_feature_correlations"],
)
def test_save_valid_correlation(corr_writer, request, correlation_df):
    """Saving a valid correlation dataframe produces a file on disk."""
    # The parameter is a fixture *name*; look up the dataframe it refers to.
    resolved_df = request.getfixturevalue(correlation_df)
    saved_file = corr_writer.save(resolved_df, CorrelationType="Pearson")
    assert saved_file.exists()

def test_save_existing_file_without_overwrite(corr_writer, random_feature_correlations):
    """A second save to the same path fails when overwrite is disabled."""
    save_args = {"CorrelationType": "Pearson"}

    # First save creates the file; the repeat must be rejected.
    corr_writer.save(random_feature_correlations, **save_args)
    with pytest.raises(CorrelationWriterIOError):
        corr_writer.save(random_feature_correlations, **save_args)

def test_save_existing_file_with_overwrite(corr_writer, random_feature_correlations):
    """A second save to the same path succeeds once overwrite is enabled."""
    corr_writer.overwrite = True

    # First save creates the file; the second one replaces it.
    corr_writer.save(random_feature_correlations, CorrelationType="Pearson")
    rewritten_file = corr_writer.save(random_feature_correlations, CorrelationType="Pearson")
    assert rewritten_file.exists()

@pytest.mark.parametrize(
    "filename_format",
    [
        "{CorrelationType}_correlation_matrix.csv",
        "{CorrelationType}_correlation_matrix.xlsx",
    ],
)
def test_save_with_different_filename_formats(corr_writer, random_feature_correlations, filename_format):
    """Saving works for each supported filename format (CSV and Excel)."""
    # Swap the format on the already-built writer before saving.
    corr_writer.filename_format = filename_format
    saved_file = corr_writer.save(random_feature_correlations, CorrelationType="Pearson")
    assert saved_file.exists()
File renamed without changes.

0 comments on commit b241c42

Please sign in to comment.