Skip to content

Commit

Permalink
feat: add classification functions (facebookincubator#11792)
Browse files Browse the repository at this point in the history
Summary:

Add the classification functions from presto into velox: https://prestodb.io/docs/current/functions/aggregate.html#classification-metrics-aggregate-functions

Classification functions all use `FixedDoubleHistogram`, which is a data structure to represent the bucket of weights. The index of the bucket for the histogram is evenly distributed between the min and value values.

For all of the classification functions, the only difference is the extraction phase. All other steps will be the same.

At a high level:
- addRawInput will add a value into either the true or false weight bucket. The bucket to add the value to will depend on the prediction value. The prediction value is linearly mapped into a bucket based on (min, max and bucketCount) by normalizing the prediction between min and max.

- The schema of the intermediate states is [version header][bucket count][min][max][weights]

Reviewed By: Yuhta

Differential Revision: D66684198
  • Loading branch information
yuandagits authored and facebook-github-bot committed Dec 14, 2024
1 parent 1779351 commit 63d36d6
Show file tree
Hide file tree
Showing 11 changed files with 1,255 additions and 7 deletions.
195 changes: 195 additions & 0 deletions velox/docs/functions/presto/aggregate.rst
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,201 @@ __ https://www.cse.ust.hk/~raywong/comp5331/References/EfficientComputationOfFre
As ``approx_percentile(x, w, percentages)``, but with a maximum rank error
of ``accuracy``.

Classification Metrics Aggregate Functions
------------------------------------------

The following functions each measure how some metric of a binary
`confusion matrix <https://en.wikipedia.org/wiki/Confusion_matrix>`_ changes as a function of
classification thresholds. They are meant to be used in conjunction.

For example, to find the `precision-recall curve <https://en.wikipedia.org/wiki/Precision_and_recall>`_, use

.. code-block:: none
WITH
recall_precision AS (
SELECT
CLASSIFICATION_RECALL(10000, correct, pred) AS recalls,
CLASSIFICATION_PRECISION(10000, correct, pred) AS precisions
FROM
classification_dataset
)
SELECT
recall,
precision
FROM
recall_precision
CROSS JOIN UNNEST(recalls, precisions) AS t(recall, precision)
To get the corresponding thresholds for these values, use

.. code-block:: none
WITH
recall_precision AS (
SELECT
CLASSIFICATION_THRESHOLDS(10000, correct, pred) AS thresholds,
CLASSIFICATION_RECALL(10000, correct, pred) AS recalls,
CLASSIFICATION_PRECISION(10000, correct, pred) AS precisions
FROM
classification_dataset
)
SELECT
threshold,
recall,
precision
FROM
recall_precision
CROSS JOIN UNNEST(thresholds, recalls, precisions) AS t(threshold, recall, precision)
To find the `ROC curve <https://en.wikipedia.org/wiki/Receiver_operating_characteristic>`_, use

.. code-block:: none
WITH
fallout_recall AS (
SELECT
CLASSIFICATION_FALLOUT(10000, correct, pred) AS fallouts,
CLASSIFICATION_RECALL(10000, correct, pred) AS recalls
FROM
classification_dataset
)
SELECT
fallout
recall,
FROM
recall_fallout
CROSS JOIN UNNEST(fallouts, recalls) AS t(fallout, recall)
.. function:: classification_miss_rate(buckets, y, x, weight) -> array<double>

Computes the miss-rate with up to ``buckets`` number of buckets. Returns
an array of miss-rate values.

``y`` should be a boolean outcome value; ``x`` should be predictions, each
between 0 and 1; ``weight`` should be non-negative values, indicating the weight of the instance.

The
`miss-rate <https://en.wikipedia.org/wiki/Type_I_and_type_II_errors#False_positive_and_false_negative_rates>`_
is defined as a sequence whose :math:`j`-th entry is

.. math ::
{
\sum_{i \;|\; x_i \leq t_j \bigwedge y_i = 1} \left[ w_i \right]
\over
\sum_{i \;|\; x_i \leq t_j \bigwedge y_i = 1} \left[ w_i \right]
+
\sum_{i \;|\; x_i > t_j \bigwedge y_i = 1} \left[ w_i \right]
},
where :math:`t_j` is the :math:`j`-th smallest threshold,
and :math:`y_i`, :math:`x_i`, and :math:`w_i` are the :math:`i`-th
entries of ``y``, ``x``, and ``weight``, respectively.

.. function:: classification_miss_rate(buckets, y, x) -> array<double>

This function is equivalent to the variant of
:func:`!classification_miss_rate` that takes a ``weight``, with a per-item weight of ``1``.

.. function:: classification_fall_out(buckets, y, x, weight) -> array<double>

Computes the fall-out with up to ``buckets`` number of buckets. Returns
an array of fall-out values.

``y`` should be a boolean outcome value; ``x`` should be predictions, each
between 0 and 1; ``weight`` should be non-negative values, indicating the weight of the instance.

The
`fall-out <https://en.wikipedia.org/wiki/Information_retrieval#Fall-out>`_
is defined as a sequence whose :math:`j`-th entry is

.. math ::
{
\sum_{i \;|\; x_i > t_j \bigwedge y_i = 0} \left[ w_i \right]
\over
\sum_{i \;|\; y_i = 0} \left[ w_i \right]
},
where :math:`t_j` is the :math:`j`-th smallest threshold,
and :math:`y_i`, :math:`x_i`, and :math:`w_i` are the :math:`i`-th
entries of ``y``, ``x``, and ``weight``, respectively.

.. function:: classification_fall_out(buckets, y, x) -> array<double>

This function is equivalent to the variant of
:func:`!classification_fall_out` that takes a ``weight``, with a per-item weight of ``1``.

.. function:: classification_precision(buckets, y, x, weight) -> array<double>

Computes the precision with up to ``buckets`` number of buckets. Returns
an array of precision values.

``y`` should be a boolean outcome value; ``x`` should be predictions, each
between 0 and 1; ``weight`` should be non-negative values, indicating the weight of the instance.

The
`precision <https://en.wikipedia.org/wiki/Positive_and_negative_predictive_values>`_
is defined as a sequence whose :math:`j`-th entry is

.. math ::
{
\sum_{i \;|\; x_i > t_j \bigwedge y_i = 1} \left[ w_i \right]
\over
\sum_{i \;|\; x_i > t_j} \left[ w_i \right]
},
where :math:`t_j` is the :math:`j`-th smallest threshold,
and :math:`y_i`, :math:`x_i`, and :math:`w_i` are the :math:`i`-th
entries of ``y``, ``x``, and ``weight``, respectively.

.. function:: classification_precision(buckets, y, x) -> array<double>

This function is equivalent to the variant of
:func:`!classification_precision` that takes a ``weight``, with a per-item weight of ``1``.

.. function:: classification_recall(buckets, y, x, weight) -> array<double>

Computes the recall with up to ``buckets`` number of buckets. Returns
an array of recall values.

``y`` should be a boolean outcome value; ``x`` should be predictions, each
between 0 and 1; ``weight`` should be non-negative values, indicating the weight of the instance.

The
`recall <https://en.wikipedia.org/wiki/Precision_and_recall#Recall>`_
is defined as a sequence whose :math:`j`-th entry is

.. math ::
{
\sum_{i \;|\; x_i > t_j \bigwedge y_i = 1} \left[ w_i \right]
\over
\sum_{i \;|\; y_i = 1} \left[ w_i \right]
},
where :math:`t_j` is the :math:`j`-th smallest threshold,
and :math:`y_i`, :math:`x_i`, and :math:`w_i` are the :math:`i`-th
entries of ``y``, ``x``, and ``weight``, respectively.

.. function:: classification_recall(buckets, y, x) -> array<double>

This function is equivalent to the variant of
:func:`!classification_recall` that takes a ``weight``, with a per-item weight of ``1``.

.. function:: classification_thresholds(buckets, y, x) -> array<double>

Computes the thresholds with up to ``buckets`` number of buckets. Returns
an array of threshold values.

``y`` should be a boolean outcome value; ``x`` should be predictions, each
between 0 and 1.

The thresholds are defined as a sequence whose :math:`j`-th entry is the :math:`j`-th smallest threshold.

Statistical Aggregate Functions
-------------------------------

Expand Down
10 changes: 5 additions & 5 deletions velox/docs/functions/presto/coverage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -325,11 +325,11 @@ Here is a list of all scalar and aggregate Presto functions with functions that
:func:`array_duplicates` :func:`dow` :func:`json_extract` :func:`repeat` st_union :func:`bool_and` :func:`rank`
:func:`array_except` :func:`doy` :func:`json_extract_scalar` :func:`replace` st_within :func:`bool_or` :func:`row_number`
:func:`array_frequency` :func:`e` :func:`json_format` replace_first st_x :func:`checksum`
:func:`array_has_duplicates` :func:`element_at` :func:`json_parse` :func:`reverse` st_xmax classification_fall_out
:func:`array_intersect` :func:`empty_approx_set` :func:`json_size` rgb st_xmin classification_miss_rate
:func:`array_join` :func:`ends_with` key_sampling_percent :func:`round` st_y classification_precision
array_least_frequent enum_key :func:`laplace_cdf` :func:`rpad` st_ymax classification_recall
:func:`array_max` :func:`exp` :func:`last_day_of_month` :func:`rtrim` st_ymin classification_thresholds
:func:`array_has_duplicates` :func:`element_at` :func:`json_parse` :func:`reverse` st_xmax :func: `classification_fall_out`
:func:`array_intersect` :func:`empty_approx_set` :func:`json_size` rgb st_xmin :func: `classification_miss_rate`
:func:`array_join` :func:`ends_with` key_sampling_percent :func:`round` st_y :func: `classification_precision`
array_least_frequent enum_key :func:`laplace_cdf` :func:`rpad` st_ymax :func: `classification_recall`
:func:`array_max` :func:`exp` :func:`last_day_of_month` :func:`rtrim` st_ymin :func: `classification_thresholds`
array_max_by expand_envelope :func:`least` scale_qdigest :func:`starts_with` convex_hull_agg
:func:`array_min` :func:`f_cdf` :func:`length` :func:`second` :func:`strpos` :func:`corr`
array_min_by features :func:`levenshtein_distance` secure_rand :func:`strrpos` :func:`count`
Expand Down
5 changes: 5 additions & 0 deletions velox/functions/prestosql/aggregates/AggregateNames.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ const char* const kBitwiseXor = "bitwise_xor_agg";
const char* const kBoolAnd = "bool_and";
const char* const kBoolOr = "bool_or";
const char* const kChecksum = "checksum";
const char* const kClassificationFallout = "classification_fall_out";
const char* const kClassificationPrecision = "classification_precision";
const char* const kClassificationRecall = "classification_recall";
const char* const kClassificationMissRate = "classification_miss_rate";
const char* const kClassificationThreshold = "classification_thresholds";
const char* const kCorr = "corr";
const char* const kCount = "count";
const char* const kCountIf = "count_if";
Expand Down
1 change: 1 addition & 0 deletions velox/functions/prestosql/aggregates/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ velox_add_library(
CountIfAggregate.cpp
CovarianceAggregates.cpp
ChecksumAggregate.cpp
ClassificationAggregation.cpp
EntropyAggregates.cpp
GeometricMeanAggregate.cpp
HistogramAggregate.cpp
Expand Down
Loading

0 comments on commit 63d36d6

Please sign in to comment.