Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

project dataclass - warn on extra kwargs #474

Merged
merged 4 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 71 additions & 9 deletions rubicon_ml/domain/project.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,80 @@
from __future__ import annotations
import datetime
import logging
import uuid
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional
if TYPE_CHECKING:
from rubicon_ml.domain.utils import TrainingMetadata

from rubicon_ml.domain.utils import TrainingMetadata, uuid
LOGGER = logging.getLogger()


@dataclass
@dataclass(init=False)
class Project:
"""A domain-level project.

Parameters
----------
name : str
The project's name.
created_at : datetime, optional
The date and time the project was created. Defaults to `None` and uses
`datetime.datetime.now` to generate a UTC timestamp. `created_at` should be
left as `None` to allow for automatic generation.
description : str, optional
A description of the project. Defaults to `None`.
github_url : str, optional
The URL of the GitHub repository associated with this project. Defaults to
`None`.
id : str, optional
The project's unique identifier. Defaults to `None` and uses `uuid.uuid4`
to generate a unique ID. `id` should be left as `None` to allow for automatic
generation.
training_metadata : rubicon_ml.domain.utils.TrainingMetadata, optional
Additional metadata pertaining to any data this project was trained on.
Defaults to `None`.
"""

name: str

id: str = field(default_factory=uuid.uuid4)
created_at: Optional[datetime.datetime] = None
description: Optional[str] = None
github_url: Optional[str] = None
training_metadata: Optional[TrainingMetadata] = None
created_at: datetime = field(default_factory=datetime.utcnow)
id: Optional[str] = None
training_metadata: Optional["TrainingMetadata"] = None

def __init__(
self,
name: str,
created_at: Optional[datetime.datetime] = None,
description: Optional[str] = None,
github_url: Optional[str] = None,
id: Optional[str] = None,
training_metadata: Optional["TrainingMetadata"] = None,
**kwargs,
):
"""Initialize this domain project."""

self.name = name

self.created_at = created_at
self.description = description
self.github_url = github_url
self.id = id
self.training_metadata = training_metadata

if self.created_at is None:
try: # `datetime.UTC` added & `datetime.utcnow` deprecated in Python 3.11
self.created_at = datetime.datetime.now(datetime.UTC)
except AttributeError:
self.created_at = datetime.datetime.utcnow()

if self.id is None:
self.id = str(uuid.uuid4())

if kwargs: # replaces `dataclass` behavior of erroring on unexpected kwargs
LOGGER.warning(
f"{self.__class__.__name__}.__init__() got an unexpected keyword "
f"argument(s): `{'`, `'.join([key for key in kwargs])}`"
)
24 changes: 24 additions & 0 deletions tests/unit/domain/test_domain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from unittest import mock

import pytest

from rubicon_ml.domain.project import Project


@pytest.mark.parametrize(
["domain_cls", "required_kwargs"],
[(Project, {"name": "test_domain_extra_kwargs"})],
)
def test_domain_extra_kwargs(domain_cls, required_kwargs):
with mock.patch(
f"rubicon_ml.domain.{domain_cls.__name__.lower()}.LOGGER.warning"
) as mock_logger_warning:
domain = domain_cls(extra="extra", **required_kwargs)

mock_logger_warning.assert_called_once_with(
f"{domain_cls.__name__}.__init__() got an unexpected keyword argument(s): `extra`",
)

assert "extra" not in domain.__dict__
for key, value in required_kwargs.items():
assert getattr(domain, key) == value
Loading