Skip to content
This repository has been archived by the owner on Sep 20, 2022. It is now read-only.

[HIVEMALL-233] RandomForest regressor accepts sparse vector input #178

Closed
wants to merge 8 commits into from

Conversation

takuti
Copy link
Member

@takuti takuti commented Jan 18, 2019

What changes were proposed in this pull request?

Enable RandomForestRegressor to accept sparse vector input as RandomForestClassifier already does.

#51 made incomplete modifications on RandomForestRegressor, while its classifier counterpart has been properly updated.

What type of PR is it?

Improvement

What is the Jira issue?

https://issues.apache.org/jira/browse/HIVEMALL-233

How was this patch tested?

Unit tests and manual test with tiny sample data:

with customers as (
  select 1 as id, "male" as gender, 23 as age, "Japan" as country, 12 as num_purchases
  union all
  select 2 as id, "female" as gender, 43 as age, "US" as country, 4 as num_purchases
  union all
  select 3 as id, "other" as gender, 19 as age, "UK" as country, 2 as num_purchases
  union all
  select 4 as id, "male" as gender, 31 as age, "US" as country, 20 as num_purchases
  union all
  select 5 as id, "female" as gender, 37 as age, "Australia" as country, 9 as num_purchases
),
training as (
  select
    array_concat(
      quantitative_features(
        array("age"),
        age
      ),
      categorical_features(
        array("country", "gender"),
        country, gender
      )
    ) as features,
    num_purchases
  from
    customers
)
select
  train_randomforest_regressor(
    feature_hashing(features), -- feature vector
    num_purchases, -- target value
    '-trees 40 -seed 31' -- hyper-parameters
  )
from
  training
;

Checklist

  • Did you apply source code formatter, i.e., ./bin/format_code.sh, for your commit?
  • Did you run system tests on Hive (or Spark)?

@@ -70,7 +70,7 @@ public DoubleArrayList add(@Nonnull double[] values) {
private void expand(int max) {
while (data.length < max) {
final int len = data.length;
double[] newArray = new double[len * 2];
double[] newArray = new double[(len + 1) * 2];
Copy link
Member

@myui myui Jan 18, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@takuti Why this change has been made?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See d7695d4

In case # of samples is less than 10, SmileExtUtils#sort calls DoubleArrayList(0), and zero-sized array list falls into infinite loop in the expand method as newArray = new double[0 * 2]. The change prevents this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, this PR does not resolves a potential bug in expand. Returning array should be >= max and max should be minCapacity where expand's argument is expected to be >=1.

https://github.com/karussell/fastutil/blob/master/src/it/unimi/dsi/fastutil/doubles/DoubleArrayList.java#L203
https://github.com/karussell/fastutil/blob/master/src/it/unimi/dsi/fastutil/doubles/DoubleArrays.java#L136

Let me fix this in another PR.

Copy link
Member

@myui myui left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://github.com/apache/incubator-hivemall/pull/178/files#diff-c4461aa6ca93a5747f6326ce6769f810L230

@takuti OutputObjectInspector should be configured as follows:

        if (denseInput) {
            fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
                PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
        } else {
            fieldOIs.add(ObjectInspectorFactory.getStandardMapObjectInspector(
                PrimitiveObjectInspectorFactory.writableIntObjectInspector,
                PrimitiveObjectInspectorFactory.writableDoubleObjectInspector));
        }

I'll fix it another PR.

@asfgit asfgit closed this in 2e1104c Feb 5, 2019
@takuti takuti deleted the rf-reg-stability branch February 5, 2019 08:03
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants