Skip to content

Commit

Permalink
using dataclass for inputs (#472)
Browse files Browse the repository at this point in the history
* using dataclass for inputs

* error correction

* comments as per reviews
  • Loading branch information
joker2411 authored Oct 17, 2024
1 parent 0aa29ae commit 251c94b
Show file tree
Hide file tree
Showing 13 changed files with 389 additions and 261 deletions.
29 changes: 17 additions & 12 deletions src/predictions/profiles_mlcorelib/connectors/Connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Iterable, List, Tuple, Union, Sequence, Optional, Dict, Set

from ..utils.logger import logger
from ..utils import utils


class Connector(ABC):
Expand Down Expand Up @@ -59,11 +60,13 @@ def _validate_common_columns(
f"Common columns are present in 2 or more inputs. Please correct the inputs in config."
)

def get_input_columns(self, trainer_obj, inputs):
def get_input_columns(
self, trainer_obj, inputs: List[utils.InputsConfig]
) -> List[str]:
columns_per_input = list()

for input_ in inputs:
query = input_["selector_sql"] + " LIMIT 1"
query = input_.selector_sql + " LIMIT 1"
ind_input_columns = set(self.run_query(query)[0]._fields)
ind_input_columns.difference_update(
{
Expand All @@ -81,16 +84,16 @@ def get_input_columns(self, trainer_obj, inputs):
input_columns = set.union(*columns_per_input)
return list(input_columns)

def _get_table_info(self, inputs):
def _get_table_info(self, inputs: List[utils.InputsConfig]):
tables = OrderedDict()
for input_ in inputs:
table_name = input_["table_name"]
table_name = input_.table_name
if table_name not in tables:
tables[table_name] = {
"column_name": [],
}
if input_["column_name"]:
tables[table_name]["column_name"].append(input_["column_name"])
if input_.column_name:
tables[table_name]["column_name"].append(input_.column_name)

return tables

Expand Down Expand Up @@ -120,7 +123,7 @@ def _construct_join_query(self, entity_column, input_columns, tables):

def join_input_tables(
self,
inputs: List[Dict],
inputs: List[utils.InputsConfig],
input_columns: List[str],
entity_column: str,
temp_joined_input_table_name: str,
Expand All @@ -139,7 +142,7 @@ def get_input_column_types(
self,
trainer_obj,
input_columns: List[str],
inputs: List[dict],
inputs: List[utils.InputsConfig],
table_name: str,
) -> Tuple:
"""Returns a dictionary containing the input column types with keys (numeric, categorical, arraytype, timestamp, booleantype) for a given table."""
Expand Down Expand Up @@ -242,16 +245,18 @@ def get_all_columns_of_a_type(
trainer_obj.prep.ignore_features = ignore_features
return updated_input_column_types

def check_arraytype_conflicts(self, updated_input_column_types, inputs):
def check_arraytype_conflicts(
self, updated_input_column_types, inputs: List[utils.InputsConfig]
):
arraytype_columns = updated_input_column_types.get("arraytype", [])

for column in arraytype_columns:
column_lower = column.lower()

for input in inputs:
for input_ in inputs:
if (
input["column_name"] is not None
and column_lower == input["column_name"].lower()
input_.column_name is not None
and column_lower == input_.column_name.lower()
):
raise Exception(
f"Array type features are not supported. Please remove '{column_lower}' and any other array type features from inputs."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,11 +346,19 @@ def predict_scores_rs(df: pd.DataFrame) -> pd.DataFrame:

model_path = os.path.join(output_dir, args.json_output_filename)

inputs_info: List[utils.InputsConfig] = []
try:
for input_ in args.inputs:
inputs_info.append(utils.InputsConfig(**input_))
except Exception as e:
logger.get().error(f"Error while parsing inputs: {e}")
raise Exception(f"Error while parsing inputs: {e}")

_ = preprocess_and_predict(
wh_creds,
args.s3_config,
model_path,
args.inputs,
inputs_info,
args.end_ts,
args.output_tablename,
connector=connector,
Expand Down
2 changes: 1 addition & 1 deletion src/predictions/profiles_mlcorelib/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
def _predict(
creds: dict,
model_path: str,
inputs: List[dict],
inputs: List[utils.InputsConfig],
output_tablename: str,
config: dict,
runtime_info: dict,
Expand Down
6 changes: 4 additions & 2 deletions src/predictions/profiles_mlcorelib/processors/K8sProcessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import uuid
import time
from typing import List
from dataclasses import asdict
from kubernetes import client, config, watch
import base64
import sys
Expand All @@ -12,6 +13,7 @@
from ..utils.logger import logger
from ..utils.S3Utils import S3Utils
from ..utils.constants import TrainTablesInfo
from ..utils import utils


class K8sProcessor(Processor):
Expand Down Expand Up @@ -269,7 +271,7 @@ def predict(
wh_creds: dict,
s3_config: dict,
model_path: str,
inputs: List[dict],
inputs: List[utils.InputsConfig],
end_ts: str,
output_tablename: str,
merged_config: dict,
Expand Down Expand Up @@ -305,7 +307,7 @@ def predict(
"--json_output_filename",
json_output_filename,
"--inputs",
json.dumps(inputs),
json.dumps(inputs, default=asdict),
"--end_ts",
end_ts,
"--output_tablename",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import json
from typing import List
from dataclasses import asdict
import sys

from ..utils import utils
Expand Down Expand Up @@ -85,7 +86,7 @@ def predict(
"--json_output_filename",
json_output_filename,
"--inputs",
json.dumps(inputs),
json.dumps(inputs, default=asdict),
"--end_ts",
end_ts,
"--output_tablename",
Expand Down
4 changes: 2 additions & 2 deletions src/predictions/profiles_mlcorelib/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@

def _train(
creds: dict,
inputs: List[dict],
inputs: List[utils.InputsConfig],
output_filename: str,
config: dict,
site_config_path: str,
Expand All @@ -55,7 +55,7 @@ def _train(
Args:
creds (dict): credentials to access the data warehouse - in same format as site_config.yaml from profiles
inputs (List[dict]): list of input models
inputs (List[utils.InputsConfig]): list of input models
output_filename (str): path to the file where the model details including model id etc are written. Used in prediction step.
config (dict): configs from profiles.yaml which should overwrite corresponding values from model_configs.yaml file
Expand Down
15 changes: 14 additions & 1 deletion src/predictions/profiles_mlcorelib/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
from typing import Tuple, List, Dict
from typing import Tuple, List, Dict, Optional

import snowflake.snowpark
from snowflake.snowpark.session import Session
Expand Down Expand Up @@ -76,6 +76,19 @@ class OutputsConfig:
feature_meta_data: List[dict]


@dataclass
class InputsConfig:
    """Configuration for a single input model.

    One instance describes one input; callers pass these around as a list
    and read the fields as attributes (``input_.table_name``) rather than
    by dictionary key. Instances can be converted back to plain dicts with
    ``dataclasses.asdict`` when they need to be JSON-serialised.
    """

    # Name of the table (material) that backs this input.
    table_name: str
    # Reference string identifying the input model.
    model_ref: str
    # Kind of model behind the input, e.g. "sql_template".
    model_type: str
    # SQL used to select this input's rows.
    selector_sql: str
    # Model-name component of the material name.
    model_name: str
    # Model-hash component of the material name.
    model_hash: str
    # Column name when the input refers to a single column; None when the
    # input is a whole table. Defaults to None so it may be omitted.
    column_name: Optional[str] = None


def split_train_test(
feature_df: pd.DataFrame,
label_column: str,
Expand Down
48 changes: 25 additions & 23 deletions src/predictions/profiles_mlcorelib/wht/pyNativeWHT.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def get_material_names(
start_date: str,
end_date: str,
prediction_horizon_days: int,
inputs: List[dict],
inputs: List[utils.InputsConfig],
input_columns: List[str],
entity_column: str,
feature_data_min_date_diff: int,
Expand Down Expand Up @@ -122,17 +122,17 @@ def compute_material_name(
def get_registry_table_name(self) -> str:
return self.pythonWHT.get_registry_table_name()

def get_latest_seq_no(self, inputs: List[dict]) -> int:
def get_latest_seq_no(self, inputs: List[utils.InputsConfig]) -> int:
return self.pythonWHT.get_latest_seq_no(inputs)

def get_inputs(self, input_model_refs: List[str]) -> List[dict]:
inputs = []
for input in input_model_refs:
material = self.whtMaterial.de_ref(input)
def get_inputs(self, input_model_refs: List[str]) -> List[utils.InputsConfig]:
inputs: List[utils.InputsConfig] = []
for input_model_ref in input_model_refs:
material = self.whtMaterial.de_ref(input_model_ref)
# if material.model.model_type() == "sql_template":
# material = self.whtMaterial.de_ref(input + "/var_table")
# material = self.whtMaterial.de_ref(input_model_ref + "/var_table")
# id_column_name = self.whtMaterial.model.entity()["IdColumnName"]
# self.whtMaterial.de_ref(input + f"/var_table/{id_column_name}")
# self.whtMaterial.de_ref(input_model_ref + f"/var_table/{id_column_name}")
column_name = None
if material.model.materialization()["output_type"] == "column":
column_name = material.model.db_object_name_prefix()
Expand All @@ -141,24 +141,26 @@ def get_inputs(self, input_model_refs: List[str]) -> List[dict]:
else:
table_material = material
material_name_dict = self.pythonWHT.split_material_name(material.name())
inputs.append(
{
"table_name": table_material.name(),
"model_ref": material.model.model_ref(),
"model_type": material.model.model_type(),
"selector_sql": material.get_selector_sql(),
"column_name": column_name,
"model_name": material_name_dict["model_name"],
"model_hash": material_name_dict["model_hash"],
}
)

input = {
"table_name": table_material.name(),
"model_ref": material.model.model_ref(),
"model_type": material.model.model_type(),
"selector_sql": material.get_selector_sql(),
"column_name": column_name,
"model_name": material_name_dict["model_name"],
"model_hash": material_name_dict["model_hash"],
}
inputs.append(utils.InputsConfig(**input))
return inputs

def validate_sql_table(self, inputs, entity_column) -> None:
def validate_sql_table(
self, inputs: List[utils.InputsConfig], entity_column: str
) -> None:
for input in inputs:
if input["model_type"] == "sql_template":
if input.model_type == "sql_template":
self.pythonWHT.connector.validate_sql_table(
input["table_name"], entity_column
input.table_name, entity_column
)

def get_credentials(self, project_path: str, site_config_path: str) -> str:
Expand All @@ -172,7 +174,7 @@ def check_and_generate_more_materials(
self,
get_material_func: callable,
materials: List[TrainTablesInfo],
inputs: List[dict],
inputs: List[utils.InputsConfig],
trainer: MLTrainer,
):
return self.pythonWHT.check_and_generate_more_materials(
Expand Down
Loading

0 comments on commit 251c94b

Please sign in to comment.