Skip to content

Commit

Permalink
feat: add text-generation as new model type (#205)
Browse files Browse the repository at this point in the history
* feat: add text-generation as new model type, handle the new model type, set schema fields as optional, edit test

* fix: ruff check

* feat: add optional schema fields to models definition (sdk)

* feat: set optional fields model schema (spark side)
  • Loading branch information
dtria91 authored Dec 11, 2024
1 parent 87488ee commit 40f0912
Show file tree
Hide file tree
Showing 10 changed files with 226 additions and 36 deletions.
57 changes: 43 additions & 14 deletions api/app/models/model_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class ModelType(str, Enum):
REGRESSION = 'REGRESSION'
BINARY = 'BINARY'
MULTI_CLASS = 'MULTI_CLASS'
TEXT_GENERATION = 'TEXT_GENERATION'


class DataType(str, Enum):
Expand Down Expand Up @@ -93,20 +94,40 @@ class ModelIn(BaseModel, validate_assignment=True):
model_type: ModelType
data_type: DataType
granularity: Granularity
features: List[ColumnDefinition]
outputs: OutputType
target: ColumnDefinition
timestamp: ColumnDefinition
features: Optional[List[ColumnDefinition]] = None
outputs: Optional[OutputType] = None
target: Optional[ColumnDefinition] = None
timestamp: Optional[ColumnDefinition] = None
frameworks: Optional[str] = None
algorithm: Optional[str] = None

model_config = ConfigDict(
populate_by_name=True, alias_generator=to_camel, protected_namespaces=()
)

@model_validator(mode='after')
def validate_fields(self) -> Self:
checked_model_type = self.model_type
if checked_model_type == ModelType.TEXT_GENERATION:
if any([self.target, self.features, self.outputs, self.timestamp]):
raise ValueError(
f'target, features, outputs and timestamp must not be provided for a {checked_model_type}'
)
return self
if not self.features:
raise ValueError(f'features must be provided for a {checked_model_type}')
if not self.outputs:
raise ValueError(f'outputs must be provided for a {checked_model_type}')
if not self.target:
raise ValueError(f'target must be provided for a {checked_model_type}')
if not self.timestamp:
raise ValueError(f'timestamp must be provided for a {checked_model_type}')

return self

@model_validator(mode='after')
def validate_target(self) -> Self:
checked_model_type: ModelType = self.model_type
checked_model_type = self.model_type
match checked_model_type:
case ModelType.BINARY:
if not is_number(self.target.type):
Expand All @@ -126,12 +147,14 @@ def validate_target(self) -> Self:
f'target must be a number for a {checked_model_type}, has been provided [{self.target}]'
)
return self
case ModelType.TEXT_GENERATION:
return self
case _:
raise ValueError('not supported type for model_type')

@model_validator(mode='after')
def validate_outputs(self) -> Self:
checked_model_type: ModelType = self.model_type
checked_model_type = self.model_type
match checked_model_type:
case ModelType.BINARY:
if not is_number(self.outputs.prediction.type):
Expand Down Expand Up @@ -169,11 +192,15 @@ def validate_outputs(self) -> Self:
f'prediction_proba must be None for a {checked_model_type}, has been provided [{self.outputs.prediction_proba}]'
)
return self
case ModelType.TEXT_GENERATION:
return self
case _:
raise ValueError('not supported type for model_type')

@model_validator(mode='after')
def timestamp_must_be_datetime(self) -> Self:
if self.model_type == ModelType.TEXT_GENERATION:
return self
if not self.timestamp.type == SupportedTypes.datetime:
raise ValueError('timestamp must be a datetime')
return self
Expand All @@ -187,10 +214,12 @@ def to_model(self) -> Model:
model_type=self.model_type.value,
data_type=self.data_type.value,
granularity=self.granularity.value,
features=[feature.to_dict() for feature in self.features],
outputs=self.outputs.to_dict(),
target=self.target.to_dict(),
timestamp=self.timestamp.to_dict(),
features=[feature.to_dict() for feature in self.features]
if self.features
else None,
outputs=self.outputs.to_dict() if self.outputs else None,
target=self.target.to_dict() if self.target else None,
timestamp=self.timestamp.to_dict() if self.timestamp else None,
frameworks=self.frameworks,
algorithm=self.algorithm,
created_at=now,
Expand All @@ -205,10 +234,10 @@ class ModelOut(BaseModel):
model_type: ModelType
data_type: DataType
granularity: Granularity
features: List[ColumnDefinition]
outputs: OutputType
target: ColumnDefinition
timestamp: ColumnDefinition
features: Optional[List[ColumnDefinition]]
outputs: Optional[OutputType]
target: Optional[ColumnDefinition]
timestamp: Optional[ColumnDefinition]
frameworks: Optional[str]
algorithm: Optional[str]
created_at: str
Expand Down
20 changes: 12 additions & 8 deletions api/tests/commons/db_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ def get_sample_model(
model_type: str = ModelType.BINARY.value,
data_type: str = DataType.TEXT.value,
granularity: str = Granularity.DAY.value,
features: List[Dict] = [
features: Optional[List[Dict]] = [
{'name': 'feature1', 'type': 'string', 'fieldType': 'categorical'}
],
outputs: Dict = {
outputs: Optional[Dict] = {
'prediction': {'name': 'pred1', 'type': 'int', 'fieldType': 'numerical'},
'prediction_proba': {
'name': 'prob1',
Expand All @@ -45,8 +45,12 @@ def get_sample_model(
},
'output': [{'name': 'output1', 'type': 'string', 'fieldType': 'categorical'}],
},
target: Dict = {'name': 'target1', 'type': 'string', 'fieldType': 'categorical'},
timestamp: Dict = {
target: Optional[Dict] = {
'name': 'target1',
'type': 'string',
'fieldType': 'categorical',
},
timestamp: Optional[Dict] = {
'name': 'timestamp',
'type': 'datetime',
'fieldType': 'datetime',
Expand Down Expand Up @@ -91,14 +95,14 @@ def get_sample_model_in(
model_type: str = ModelType.BINARY.value,
data_type: str = DataType.TEXT.value,
granularity: str = Granularity.DAY.value,
features: List[ColumnDefinition] = [
features: Optional[List[ColumnDefinition]] = [
ColumnDefinition(
name='feature1',
type=SupportedTypes.string,
field_type=FieldType.categorical,
)
],
outputs: OutputType = OutputType(
outputs: Optional[OutputType] = OutputType(
prediction=ColumnDefinition(
name='pred1', type=SupportedTypes.int, field_type=FieldType.numerical
),
Expand All @@ -113,10 +117,10 @@ def get_sample_model_in(
)
],
),
target: ColumnDefinition = ColumnDefinition(
target: Optional[ColumnDefinition] = ColumnDefinition(
name='target1', type=SupportedTypes.int, field_type=FieldType.numerical
),
timestamp: ColumnDefinition = ColumnDefinition(
timestamp: Optional[ColumnDefinition] = ColumnDefinition(
name='timestamp', type=SupportedTypes.datetime, field_type=FieldType.datetime
),
frameworks: Optional[str] = None,
Expand Down
58 changes: 58 additions & 0 deletions api/tests/commons/modelin_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,64 @@ def get_model_sample_wrong(fail_fields: List[str], model_type: ModelType):
name='timestamp', type=SupportedTypes.datetime, field_type=FieldType.datetime
)

if model_type == ModelType.TEXT_GENERATION:
if 'features' in fail_fields:
features = [
ColumnDefinition(
name='feature1',
type=SupportedTypes.string,
field_type=FieldType.categorical,
)
]
else:
features = None

if 'outputs' in fail_fields:
outputs = OutputType(
prediction=prediction,
prediction_proba=prediction_proba,
output=[
ColumnDefinition(
name='output1',
type=SupportedTypes.string,
field_type=FieldType.categorical,
)
],
)
else:
outputs = None

if 'target' in fail_fields:
target = ColumnDefinition(
name='target1',
type=SupportedTypes.string,
field_type=FieldType.categorical,
)
else:
target = None

if 'timestamp' in fail_fields:
timestamp = ColumnDefinition(
name='timestamp',
type=SupportedTypes.string,
field_type=FieldType.categorical,
)
else:
timestamp = None

return {
'name': 'text_generation_model',
'model_type': model_type,
'data_type': DataType.TEXT,
'granularity': Granularity.DAY,
'features': features,
'outputs': outputs,
'target': target,
'timestamp': timestamp,
'frameworks': None,
'algorithm': None,
}

if 'outputs.prediction' in fail_fields:
if model_type == ModelType.BINARY:
prediction = ColumnDefinition(
Expand Down
23 changes: 22 additions & 1 deletion api/tests/services/model_service_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from app.db.dao.reference_dataset_dao import ReferenceDatasetDAO
from app.models.alert_dto import AnomalyType
from app.models.exceptions import ModelError, ModelNotFoundError
from app.models.model_dto import ModelOut
from app.models.model_dto import ModelOut, ModelType
from app.models.model_order import OrderType
from app.services.model_service import ModelService
from tests.commons import db_mock
Expand Down Expand Up @@ -42,6 +42,27 @@ def test_create_model_ok(self):

assert res == ModelOut.from_model(model)

def test_create_model_with_empty_schema_ok(self):
model = db_mock.get_sample_model(
model_type=ModelType.TEXT_GENERATION,
features=None,
target=None,
outputs=None,
timestamp=None,
)
self.model_dao.insert = MagicMock(return_value=model)
model_in = db_mock.get_sample_model_in(
model_type=ModelType.TEXT_GENERATION,
features=None,
target=None,
outputs=None,
timestamp=None,
)
res = self.model_service.create_model(model_in)
self.model_dao.insert.assert_called_once()

assert res == ModelOut.from_model(model)

def test_get_model_by_uuid_ok(self):
model = db_mock.get_sample_model()
reference_dataset = db_mock.get_sample_reference_dataset(model_uuid=model.uuid)
Expand Down
34 changes: 33 additions & 1 deletion api/tests/validation/model_type_validator_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pydantic import ValidationError
import pytest

from app.models.model_dto import ModelIn, ModelType
from app.models.model_dto import DataType, Granularity, ModelIn, ModelType
from tests.commons.modelin_factory import get_model_sample_wrong


Expand Down Expand Up @@ -108,3 +108,35 @@ def test_prediction_proba_for_regression():
assert 'prediction_proba must be None for a ModelType.REGRESSION' in str(
excinfo.value
)


def test_text_generation_invalid_fields_provided():
"""Tests that TEXT_GENERATION fails if features, outputs, target, or timestamp are provided."""
with pytest.raises(ValidationError) as excinfo:
model_data = get_model_sample_wrong(
fail_fields=['features', 'outputs', 'target', 'timestamp'],
model_type=ModelType.TEXT_GENERATION,
)
ModelIn.model_validate(ModelIn(**model_data))
assert (
'target, features, outputs and timestamp must not be provided for a ModelType.TEXT_GENERATION'
in str(excinfo.value)
)


def test_text_generation_valid():
"""Tests that TEXT_GENERATION passes validation with no schema fields."""
model_data = {
'name': 'text_generation_model',
'model_type': ModelType.TEXT_GENERATION,
'data_type': DataType.TEXT,
'granularity': Granularity.DAY,
'frameworks': 'transformer',
'algorithm': 'gpt-like',
}
model = ModelIn.model_validate(ModelIn(**model_data))
assert model.model_type == ModelType.TEXT_GENERATION
assert model.features is None
assert model.outputs is None
assert model.target is None
assert model.timestamp is None
8 changes: 4 additions & 4 deletions sdk/radicalbit_platform_sdk/apis/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,16 @@ def data_type(self) -> DataType:
def granularity(self) -> Granularity:
return self.__granularity

def features(self) -> List[ColumnDefinition]:
def features(self) -> Optional[List[ColumnDefinition]]:
return self.__features

def target(self) -> ColumnDefinition:
def target(self) -> Optional[ColumnDefinition]:
return self.__target

def timestamp(self) -> ColumnDefinition:
def timestamp(self) -> Optional[ColumnDefinition]:
return self.__timestamp

def outputs(self) -> OutputType:
def outputs(self) -> Optional[OutputType]:
return self.__outputs

def frameworks(self) -> Optional[str]:
Expand Down
8 changes: 4 additions & 4 deletions sdk/radicalbit_platform_sdk/models/model_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ class BaseModelDefinition(BaseModel):
model_type: ModelType
data_type: DataType
granularity: Granularity
features: List[ColumnDefinition]
outputs: OutputType
target: ColumnDefinition
timestamp: ColumnDefinition
features: Optional[List[ColumnDefinition]] = None
outputs: Optional[OutputType] = None
target: Optional[ColumnDefinition] = None
timestamp: Optional[ColumnDefinition] = None
frameworks: Optional[str] = None
algorithm: Optional[str] = None

Expand Down
1 change: 1 addition & 0 deletions sdk/radicalbit_platform_sdk/models/model_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ class ModelType(str, Enum):
REGRESSION = 'REGRESSION'
BINARY = 'BINARY'
MULTI_CLASS = 'MULTI_CLASS'
TEXT_GENERATION = 'TEXT_GENERATION'
Loading

0 comments on commit 40f0912

Please sign in to comment.