Skip to content

Commit

Permalink
add h2o schema (#413)
Browse files Browse the repository at this point in the history
* add `H2OGradientBoostingEstimator` schema and h2o schema boilerplate
* add `H2OGeneralizedLinearEstimator` schema
* add `H2ORandomForestEstimator` schema
* add `H2OTargetEncoderEstimator` schema
* add `H2OXGBoostEstimator` schema
* log models as artifacts
* test custom artifact logging
* fix other test
  • Loading branch information
ryanSoley authored Mar 5, 2024
1 parent 6168388 commit 02e9842
Show file tree
Hide file tree
Showing 9 changed files with 666 additions and 5 deletions.
12 changes: 9 additions & 3 deletions rubicon_ml/schema/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,15 @@ def log_with_schema(
if artifact == "self":
experiment.log_artifact(name=obj.__class__.__name__, data_object=obj)
elif isinstance(artifact, dict):
data_object = _get_data_object(obj, artifact)
if data_object is not None:
experiment.log_artifact(name=artifact["name"], data_object=data_object)
if "self" in artifact:
logging_func_name = artifact["self"]
logging_func = getattr(experiment, logging_func_name)
logging_func(obj)
else:
data_object = _get_data_object(obj, artifact)

if data_object is not None:
experiment.log_artifact(name=artifact["name"], data_object=data_object)

for dataframe in self.schema_.get("dataframes", []):
df_value = _get_df(obj, dataframe)
Expand Down
15 changes: 15 additions & 0 deletions rubicon_ml/schema/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,21 @@
import yaml

RUBICON_SCHEMA_REGISTRY = {
"h2o__H2OGeneralizedLinearEstimator": lambda: _load_schema(
os.path.join("schema", "h2o__H2OGeneralizedLinearEstimator.yaml")
),
"h2o__H2OGradientBoostingEstimator": lambda: _load_schema(
os.path.join("schema", "h2o__H2OGradientBoostingEstimator.yaml")
),
"h2o__H2ORandomForestEstimator": lambda: _load_schema(
os.path.join("schema", "h2o__H2ORandomForestEstimator.yaml")
),
"h2o__H2OTargetEncoderEstimator": lambda: _load_schema(
os.path.join("schema", "h2o__H2OTargetEncoderEstimator.yaml")
),
"h2o__H2OXGBoostEstimator": lambda: _load_schema(
os.path.join("schema", "h2o__H2OXGBoostEstimator.yaml")
),
"lightgbm__LGBMModel": lambda: _load_schema(os.path.join("schema", "lightgbm__LGBMModel.yaml")),
"lightgbm__LGBMClassifier": lambda: _load_schema(
os.path.join("schema", "lightgbm__LGBMClassifier.yaml")
Expand Down
158 changes: 158 additions & 0 deletions rubicon_ml/schema/schema/h2o__H2OGeneralizedLinearEstimator.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# Schema header. `name` matches the key this file is registered under in
# RUBICON_SCHEMA_REGISTRY.
name: h2o__H2OGeneralizedLinearEstimator
version: 1.0.0

# The library this schema targets. The entry was keyed `lightgbm` (copied from
# the LightGBM schema); 3.44.0.1 is an h2o release, so the key must be `h2o`.
# `max_version` is intentionally left empty.
compatibility:
  h2o:
    max_version:
    min_version: 3.44.0.1
docs_url: https://docs.h2o.ai/h2o/latest-stable/h2o-py/docs/modeling.html#h2ogeneralizedlinearestimator

# Artifact entries. A dict entry with a `self` key names a method on the
# experiment; the schema logger looks it up with getattr and calls it with the
# object being logged (here: experiment.log_h2o_model(obj)).
artifacts:
- self: log_h2o_model
# Parameter entries: each pairs a `name` with a `value_attr`. Every
# `value_attr` in this list is identical to its `name`.
# NOTE(review): presumably `value_attr` is the estimator attribute read when
# the parameter is logged - confirm against the schema logger.
parameters:
- name: alpha
value_attr: alpha
- name: auc_type
value_attr: auc_type
- name: balance_classes
value_attr: balance_classes
- name: beta_constraints
value_attr: beta_constraints
- name: beta_epsilon
value_attr: beta_epsilon
- name: build_null_model
value_attr: build_null_model
- name: calc_like
value_attr: calc_like
- name: class_sampling_factors
value_attr: class_sampling_factors
- name: cold_start
value_attr: cold_start
- name: compute_p_values
value_attr: compute_p_values
- name: custom_metric_func
value_attr: custom_metric_func
- name: dispersion_epsilon
value_attr: dispersion_epsilon
- name: dispersion_learning_rate
value_attr: dispersion_learning_rate
- name: dispersion_parameter_method
value_attr: dispersion_parameter_method
- name: early_stopping
value_attr: early_stopping
- name: export_checkpoints_dir
value_attr: export_checkpoints_dir
- name: family
value_attr: family
- name: fix_dispersion_parameter
value_attr: fix_dispersion_parameter
- name: fix_tweedie_variance_power
value_attr: fix_tweedie_variance_power
- name: fold_assignment
value_attr: fold_assignment
- name: fold_column
value_attr: fold_column
- name: gainslift_bins
value_attr: gainslift_bins
- name: generate_scoring_history
value_attr: generate_scoring_history
- name: generate_variable_inflation_factors
value_attr: generate_variable_inflation_factors
- name: gradient_epsilon
value_attr: gradient_epsilon
- name: HGLM
value_attr: HGLM
- name: ignore_const_cols
value_attr: ignore_const_cols
- name: ignored_columns
value_attr: ignored_columns
- name: influence
value_attr: influence
- name: init_dispersion_parameter
value_attr: init_dispersion_parameter
- name: interaction_pairs
value_attr: interaction_pairs
- name: interactions
value_attr: interactions
- name: intercept
value_attr: intercept
- name: keep_cross_validation_fold_assignment
value_attr: keep_cross_validation_fold_assignment
- name: keep_cross_validation_models
value_attr: keep_cross_validation_models
- name: keep_cross_validation_predictions
value_attr: keep_cross_validation_predictions
- name: lambda_
value_attr: lambda_
- name: lambda_min_ratio
value_attr: lambda_min_ratio
- name: lambda_search
value_attr: lambda_search
- name: link
value_attr: link
- name: max_active_predictors
value_attr: max_active_predictors
- name: max_after_balance_size
value_attr: max_after_balance_size
- name: max_confusion_matrix_size
value_attr: max_confusion_matrix_size
- name: max_iterations
value_attr: max_iterations
- name: max_iterations_dispersion
value_attr: max_iterations_dispersion
- name: max_runtime_secs
value_attr: max_runtime_secs
- name: missing_values_handling
value_attr: missing_values_handling
- name: nfolds
value_attr: nfolds
- name: nlambdas
value_attr: nlambdas
- name: non_negative
value_attr: non_negative
- name: obj_reg
value_attr: obj_reg
- name: objective_epsilon
value_attr: objective_epsilon
- name: offset_column
value_attr: offset_column
- name: prior
value_attr: prior
- name: rand_family
value_attr: rand_family
- name: rand_link
value_attr: rand_link
- name: random_columns
value_attr: random_columns
- name: remove_collinear_columns
value_attr: remove_collinear_columns
- name: response_column
value_attr: response_column
- name: score_each_iteration
value_attr: score_each_iteration
- name: score_iteration_interval
value_attr: score_iteration_interval
- name: seed
value_attr: seed
- name: solver
value_attr: solver
- name: standardize
value_attr: standardize
- name: startval
value_attr: startval
- name: stopping_metric
value_attr: stopping_metric
- name: stopping_rounds
value_attr: stopping_rounds
- name: stopping_tolerance
value_attr: stopping_tolerance
- name: theta
value_attr: theta
- name: tweedie_epsilon
value_attr: tweedie_epsilon
- name: tweedie_link_power
value_attr: tweedie_link_power
- name: tweedie_variance_power
value_attr: tweedie_variance_power
- name: weights_column
value_attr: weights_column
130 changes: 130 additions & 0 deletions rubicon_ml/schema/schema/h2o__H2OGradientBoostingEstimator.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Schema header. `name` matches the key this file is registered under in
# RUBICON_SCHEMA_REGISTRY.
name: h2o__H2OGradientBoostingEstimator
version: 1.0.0

# The library this schema targets. The entry was keyed `lightgbm` (copied from
# the LightGBM schema); 3.44.0.1 is an h2o release, so the key must be `h2o`.
# `max_version` is intentionally left empty.
compatibility:
  h2o:
    max_version:
    min_version: 3.44.0.1
docs_url: https://docs.h2o.ai/h2o/latest-stable/h2o-py/docs/modeling.html#h2ogradientboostingestimator

# Artifact entries. A dict entry with a `self` key names a method on the
# experiment; the schema logger looks it up with getattr and calls it with the
# object being logged (here: experiment.log_h2o_model(obj)).
artifacts:
- self: log_h2o_model
# Parameter entries: each pairs a `name` with a `value_attr`. Every
# `value_attr` in this list is identical to its `name`.
# NOTE(review): presumably `value_attr` is the estimator attribute read when
# the parameter is logged - confirm against the schema logger.
parameters:
- name: auc_type
value_attr: auc_type
- name: auto_rebalance
value_attr: auto_rebalance
- name: balance_classes
value_attr: balance_classes
- name: build_tree_one_node
value_attr: build_tree_one_node
- name: calibrate_model
value_attr: calibrate_model
- name: calibration_method
value_attr: calibration_method
- name: categorical_encoding
value_attr: categorical_encoding
- name: check_constant_response
value_attr: check_constant_response
- name: class_sampling_factors
value_attr: class_sampling_factors
- name: col_sample_rate
value_attr: col_sample_rate
- name: col_sample_rate_change_per_level
value_attr: col_sample_rate_change_per_level
- name: col_sample_rate_per_tree
value_attr: col_sample_rate_per_tree
- name: custom_distribution_func
value_attr: custom_distribution_func
- name: custom_metric_func
value_attr: custom_metric_func
- name: distribution
value_attr: distribution
- name: export_checkpoints_dir
value_attr: export_checkpoints_dir
- name: fold_assignment
value_attr: fold_assignment
- name: fold_column
value_attr: fold_column
- name: gainslift_bins
value_attr: gainslift_bins
- name: histogram_type
value_attr: histogram_type
- name: huber_alpha
value_attr: huber_alpha
- name: ignore_const_cols
value_attr: ignore_const_cols
- name: ignored_columns
value_attr: ignored_columns
- name: in_training_checkpoints_dir
value_attr: in_training_checkpoints_dir
- name: in_training_checkpoints_tree_interval
value_attr: in_training_checkpoints_tree_interval
- name: interaction_constraints
value_attr: interaction_constraints
- name: keep_cross_validation_fold_assignment
value_attr: keep_cross_validation_fold_assignment
- name: keep_cross_validation_models
value_attr: keep_cross_validation_models
- name: keep_cross_validation_predictions
value_attr: keep_cross_validation_predictions
- name: learn_rate
value_attr: learn_rate
- name: learn_rate_annealing
value_attr: learn_rate_annealing
- name: max_abs_leafnode_pred
value_attr: max_abs_leafnode_pred
- name: max_after_balance_size
value_attr: max_after_balance_size
- name: max_confusion_matrix_size
value_attr: max_confusion_matrix_size
- name: max_depth
value_attr: max_depth
- name: max_runtime_secs
value_attr: max_runtime_secs
- name: min_rows
value_attr: min_rows
- name: min_split_improvement
value_attr: min_split_improvement
- name: monotone_constraints
value_attr: monotone_constraints
- name: nbins
value_attr: nbins
- name: nbins_cats
value_attr: nbins_cats
- name: nbins_top_level
value_attr: nbins_top_level
- name: nfolds
value_attr: nfolds
- name: ntrees
value_attr: ntrees
- name: offset_column
value_attr: offset_column
- name: pred_noise_bandwidth
value_attr: pred_noise_bandwidth
- name: quantile_alpha
value_attr: quantile_alpha
- name: r2_stopping
value_attr: r2_stopping
- name: response_column
value_attr: response_column
- name: sample_rate
value_attr: sample_rate
- name: sample_rate_per_class
value_attr: sample_rate_per_class
- name: score_each_iteration
value_attr: score_each_iteration
- name: score_tree_interval
value_attr: score_tree_interval
- name: seed
value_attr: seed
- name: stopping_metric
value_attr: stopping_metric
- name: stopping_rounds
value_attr: stopping_rounds
- name: stopping_tolerance
value_attr: stopping_tolerance
- name: tweedie_power
value_attr: tweedie_power
- name: weights_column
value_attr: weights_column
Loading

0 comments on commit 02e9842

Please sign in to comment.