[Feat] Implement a simple DNN model for testing #28

Merged · 19 commits · Jan 17, 2025
26 changes: 6 additions & 20 deletions CATS/models/basemodel.py
@@ -9,6 +9,7 @@
import torch.nn.functional as F
from sklearn.metrics import *
from tensorflow.keras.callbacks import Callback
from tensorflow.python.keras.callbacks import CallbackList
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

@@ -17,15 +18,12 @@
build_input_features, create_embedding_matrix)
from ..layers import PredictionLayer

from tensorflow.python.keras.callbacks import CallbackList


class BaseModel(nn.Module):
def __init__(
self,
linear_feature_columns: List[Union[SparseFeat, DenseFeat, VarLenSparseFeat]],
dnn_feature_columns: List[Union[SparseFeat, DenseFeat, VarLenSparseFeat]],
l2_reg_linear: float = 1e-5,
l2_reg_embedding: float = 1e-5,
init_std: float = 0.0001,
seed: int = 1024,
@@ -59,16 +57,11 @@ def __init__(
dnn_feature_columns, init_std, sparse=False, device=device
)

self.linear_model = nn.Linear(
self._compute_input_dim(linear_feature_columns), 1, bias=False
).to(device)

self.regularization_weight = []

self.add_regularization_weight(
self.embedding_dict.parameters(), l2=l2_reg_embedding
)
self.add_regularization_weight(self.linear_model.parameters(), l2=l2_reg_linear)

self.out = PredictionLayer(task)
self.to(device)
@@ -230,14 +223,6 @@ def fit(
for name in self.metrics:
eval_str += " - " + name + ": {0: .4f}".format(epoch_logs[name])

if do_validation:
for name in self.metrics:
eval_str += (
" - "
+ "val_"
+ name
+ ": {0: .4f}".format(epoch_logs["val_" + name])
)
logging.info(eval_str)
callbacks.on_epoch_end(epoch, epoch_logs)
if self.stop_training:
@@ -289,9 +274,7 @@ def _compute_input_dim(
input_dim = 0

sparse_feature_columns = list(
filter(
lambda x: isinstance(x, (SparseFeat, VarLenSparseFeat)), feature_columns
)
filter(lambda x: isinstance(x, SparseFeat), feature_columns)
if len(feature_columns)
else []
)
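Note on the hunk above: narrowing the isinstance filter from (SparseFeat, VarLenSparseFeat) to SparseFeat means variable-length sparse columns no longer count toward the computed input dimension, which lines up with DNNModel.forward below consuming only SparseFeat and DenseFeat columns. A standalone sketch of the resulting behavior, assuming DeepCTR-style attribute names (embedding_dim on SparseFeat, dimension on DenseFeat — verify against CATS/inputs):

# Standalone sketch, not the repo's method.
def compute_input_dim_sketch(feature_columns):
    dim = 0
    for feat in feature_columns:
        if isinstance(feat, SparseFeat):   # VarLenSparseFeat no longer matches
            dim += feat.embedding_dim      # assumed attribute name
        elif isinstance(feat, DenseFeat):
            dim += feat.dimension          # assumed attribute name
    return dim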
@@ -432,7 +415,10 @@ def add_regularization_weight(
:param l1: The lambda value determining the strength of L1 regularization.
:param l2: The lambda value determining the strength of L2 regularization.
"""
weight_list = [weight_list]
if isinstance(weight_list, torch.nn.parameter.Parameter):
weight_list = [weight_list]
else:
weight_list = list(weight_list)
self.regularization_weight.append((weight_list, l1, l2))

def get_regularization_loss(self) -> torch.Tensor:
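The add_regularization_weight change above lets the method accept both a bare Parameter and an iterable of Parameters: previously weight_list = [weight_list] wrapped whatever was passed, so a generator from named_parameters() ended up nested inside a list instead of being materialized. A minimal standalone sketch of the new normalization, assuming only torch:

import torch.nn as nn

# Mirrors the diff above: wrap a bare Parameter in a one-element list,
# materialize any other iterable so it can be traversed more than once.
def normalize_weight_list(weight_list):
    if isinstance(weight_list, nn.parameter.Parameter):
        return [weight_list]
    return list(weight_list)

layer = nn.Linear(4, 1, bias=False)
single = normalize_weight_list(layer.weight)      # bare Parameter
many = normalize_weight_list(layer.parameters())  # generator (bias=False -> weight only)
assert single[0] is layer.weight and many[0] is layer.weight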
126 changes: 126 additions & 0 deletions CATS/models/dnn_model.py
@@ -0,0 +1,126 @@
from typing import List, Literal, Union

import torch
import torch.nn as nn

from ..inputs import (DenseFeat, SparseFeat, VarLenSparseFeat,
create_embedding_matrix)
from ..layers import DNN
from .basemodel import BaseModel


class DNNModel(BaseModel):
def __init__(
self,
linear_feature_columns: List[Union[SparseFeat, DenseFeat, VarLenSparseFeat]],
dnn_feature_columns: List[Union[SparseFeat, DenseFeat, VarLenSparseFeat]],
dnn_hidden_units=(128, 128),
l2_reg_linear: float = 1e-5,
l2_reg_embedding: float = 1e-5,
l2_reg_dnn: float = 0,
init_std: float = 0.0001,
seed: int = 1024,
dnn_dropout: float = 0,
dnn_activation: Union[
Literal["sigmoid", "relu", "prelu", "identity"], nn.Module
] = "relu",
dnn_use_bn: bool = False,
task: Literal["binary", "multiclass", "regression"] = "binary",
device: Literal["cpu", "cuda", "mps"] = "cpu",
):
"""
simple dnn model.
:param linear_feature_columns: list of features attributes for linear model.
:param dnn_feature_columns: list of features attributes for dnn model.
:param dnn_hidden_units: dnn hidden unit's output and input size
:param l2_reg_linear: L2 regularization for linear features
:param l2_reg_embedding: L2 regularization for embedding features
:param l2_reg_dnn: L2 regularization for dnn parameters
:param init_std: initialize standard deviation
:param seed: random seed value
:param dnn_dropout: dnn's dropout rate
:param dnn_activation: dnn's activation function
:param dnn_use_bn: if dnn using bn, it's true else false
:param task: object task
:param device: target device
"""
super(DNNModel, self).__init__(
linear_feature_columns=linear_feature_columns,
dnn_feature_columns=dnn_feature_columns,
l2_reg_embedding=l2_reg_embedding,
init_std=init_std,
seed=seed,
task=task,
device=device,
)

self.dnn_hidden_units = dnn_hidden_units
self.dnn = DNN(
self._compute_input_dim(dnn_feature_columns),
dnn_hidden_units,
activation=dnn_activation,
use_bn=dnn_use_bn,
l2_reg=l2_reg_dnn,
dropout_rate=dnn_dropout,
init_std=init_std,
device=device,
)

self.sparse_feature_columns = (
list(filter(lambda x: isinstance(x, SparseFeat), dnn_feature_columns))
if len(dnn_feature_columns)
else []
)
self.dense_feature_columns = (
list(filter(lambda x: isinstance(x, DenseFeat), dnn_feature_columns))
if len(dnn_feature_columns)
else []
)

dnn_linear_in_feature = dnn_hidden_units[-1]
self.dnn_linear = nn.Linear(dnn_linear_in_feature, 1, bias=False).to(device)
self.add_regularization_weight(
filter(
lambda x: "weight" in x[0] and "bn" not in x[0],
self.dnn.named_parameters(),
),
l2=l2_reg_dnn,
)
self.add_regularization_weight(self.dnn_linear.weight, l2=l2_reg_linear)
self.to(device)

def forward(self, inputs: torch.Tensor) -> torch.Tensor:
"""
feed-forward
:param inputs: inputs batch train data
:return: predict value
"""
sparse_embedding_list = [
self.embedding_dict[feat.embedding_name](
inputs[
:,
self.feature_index[feat.name][0] : self.feature_index[feat.name][1],
].long()
)
for feat in self.sparse_feature_columns
]

dense_value_list = [
inputs[
:, self.feature_index[feat.name][0] : self.feature_index[feat.name][1]
]
for feat in self.dense_feature_columns
]

sparse_dnn_input = torch.flatten(
torch.cat(sparse_embedding_list, dim=-1), start_dim=1
)
dense_dnn_input = torch.flatten(
torch.cat(dense_value_list, dim=-1), start_dim=1
)
dnn_input = torch.cat([sparse_dnn_input, dense_dnn_input], dim=-1)

dnn_out = self.dnn(dnn_input)
logit = self.dnn_linear(dnn_out)
y_pred = self.out(logit)
return y_pred
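For reviewers who want to exercise the new model end to end, a hedged usage sketch. The feature-column constructors follow the DeepCTR-Torch conventions this codebase mirrors; the exact SparseFeat/DenseFeat signatures in CATS/inputs may differ, so treat them as assumptions:

import torch

from CATS.inputs import DenseFeat, SparseFeat
from CATS.models.dnn_model import DNNModel

# Assumed DeepCTR-style constructors: SparseFeat(name, vocabulary_size,
# embedding_dim) and DenseFeat(name, dimension). Verify against CATS/inputs.
feature_columns = [
    SparseFeat("user_id", 100, 8),
    SparseFeat("item_id", 200, 8),
    DenseFeat("price", 1),
]

model = DNNModel(
    linear_feature_columns=feature_columns,
    dnn_feature_columns=feature_columns,
    task="binary",
    device="cpu",
)

# One flat input tensor, columns laid out per model.feature_index
# (assumed order here: user_id, item_id, price -> one column each).
batch = torch.cat(
    [
        torch.randint(0, 100, (32, 1)).float(),  # user_id indices
        torch.randint(0, 200, (32, 1)).float(),  # item_id indices
        torch.rand(32, 1),                       # price values
    ],
    dim=-1,
)

y_pred = model(batch)  # expected shape (32, 1): sigmoid scores for the "binary" task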