Improve the accuracy of complex atanh and atan (#2513)
As in the title.

The accuracy improvements are as follows, measured on roughly 1,000,000 samples
covering the entire complex plane:

- complex64 atanh:
  ```
                           current main  -> this PR
  ULP difference == 0 count is 257275    -> 868955
  ULP difference == 1 count is 24034     -> 130206
  ULP difference == 2 count is 1280      -> 2832
  ULP difference == 3 count is 2903      -> 12
  ULP difference >= 4 count is 716513    -> 0
  ```
- complex64 atan:
  ```
                         current main  -> this PR
  ULP difference == 0 count is 3326      -> 868953
  ULP difference == 1 count is 6970      -> 130204
  ULP difference == 2 count is 3384      -> 2832
  ULP difference == 3 count is 3418      -> 12
  ULP difference >= 4 count is 984903    -> 0
  ```
- complex128 atanh:
  ```
                         current main  -> this PR
  ULP difference == 0 count is 239712    -> 941283
  ULP difference == 1 count is 2853      -> 60322
  ULP difference == 2 count is 616       -> 400
  ULP difference == 3 count is 8         -> 0
  ULP difference >= 4 count is 758816    -> 0
  ```
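
The ULP difference between a computed value and a high-precision reference
can be counted roughly as in the sketch below. This is an illustrative sketch
only, not the harness used by `generate_tests.py`/`functional_algorithms`;
the `results` and `references` arrays are placeholders.

```python
import collections
import numpy as np

def ulp_diff(value, reference, cap=4):
    # Count how many representable floats separate value from reference,
    # capping the count at `cap` (the ">= 4" bucket above).
    if np.isnan(value) or np.isnan(reference):
        return 0 if np.isnan(value) and np.isnan(reference) else cap
    lo, hi = (value, reference) if value <= reference else (reference, value)
    count = 0
    while lo < hi and count < cap:
        lo = np.nextafter(lo, hi)
        count += 1
    return count

def complex_ulp_diff(z, w, cap=4):
    # For complex values, take the worse of the real and imaginary parts.
    return max(ulp_diff(z.real, w.real, cap), ulp_diff(z.imag, w.imag, cap))

# results: outputs of the op in the target dtype; references: mpmath-based
# values rounded to the same dtype (both are placeholders here).
counts = collections.Counter(
    complex_ulp_diff(z, w) for z, w in zip(results, references))
```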

This PR requires functional_algorithms 0.10.1 or newer.
pearu authored Sep 5, 2024
1 parent 4f31b2e commit 932b32f
Showing 14 changed files with 523 additions and 84 deletions.
41 changes: 39 additions & 2 deletions build_tools/math/README.md
@@ -31,7 +31,7 @@ following requirements:

- Python 3.11 or newer
- mpmath 1.3 or newer
- functional_algorithms 0.9.1 or newer
- functional_algorithms 0.10.1 or newer

that can be installed via pypi:

@@ -63,7 +63,7 @@ To execute generated tests from a `build` directory, use:
```sh
for t in $(ls ../stablehlo/tests/math/*.mlir); \
do echo $t && ( bin/stablehlo-opt --chlo-legalize-to-stablehlo $t \
| bin/stablehlo-translate --interpret ) ; done
| bin/stablehlo-translate --interpret 2>&1 | grep "^ULP difference" ) ; done
```
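
With the added `grep` filter, only the per-dtype ULP summary lines remain in
the output; presumably lines of the form shown below (the counts here are
illustrative, taken from this PR's description):

```
ULP difference == 0 count is 868955
ULP difference == 1 count is 130206
ULP difference == 2 count is 2832
ULP difference == 3 count is 12
```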

When new implementations are generated, one likely needs to update
@@ -76,3 +76,40 @@ build/bin/stablehlo-opt --chlo-legalize-to-stablehlo --split-input-file --verify
```

and copy relevant checks to `chlo_legalize_to_stablehlo.mlir`.

## A procedure for adding a new algorithm to an existing operation

1. Implement a new algorithm in
[functional_algorithms](https://github.com/pearu/functional_algorithms)
and publish it by creating a new release of
`functional_algorithms`.
2. Build stablehlo on top of its main branch.
3. Update the version requirement of `functional_algorithms` in this
`README.md` and install the latest version of
`functional_algorithms`.
4. Add a record of the operation to the `generate_tests.py:operations`
   list, using `size=1000` and `max_ulp_difference=0` (see the sketch
   after this list).
5. Generate new tests by running `generate_tests.py`.
6. Run the generated tests (see the previous section for instructions).
   They will print the ULP difference statistics of the current
   implementation to stdout; save this output for comparison later.
   Note that test failures are expected because of the
   `max_ulp_difference=0` specified in step 4.
7. Add a record of the operation to
   `generate_ChloDecompositionPatternsMath.py`; see the for-loop in the
   `main` function and the sketch after this list.
8. Generate new implementations by running
`generate_ChloDecompositionPatternsMath.py` and remove existing
implementations in
`stablehlo/transforms/ChloDecompositionPatterns.td` as needed.
9. Re-build stablehlo.
10. Re-run the generated tests and compare the ULP difference statistics
    of the new implementation with those obtained in step 6.
11. If the new implementation improves the ULP difference statistics,
    prepare a PR for stablehlo. When submitting the PR, don't forget
    to apply the following steps:
    - remove the `max_ulp_difference=0` setting from
      `generate_tests.py` and re-generate the tests with
      `size=default_size`,
    - update `chlo_legalize_to_stablehlo.mlir` (see the previous
      section for instructions).
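
For concreteness, for the `atan` operation handled in this PR, the records
added in steps 4 and 7 would look roughly as follows. This is only a sketch:
the `size=1000`/`max_ulp_difference=0` values are the temporary settings from
step 4 and are reverted in step 11.

```python
# Step 4: temporary entry in the generate_tests.py `operations` list.
# size=1000 gives size**2 = 1,000,000 samples; max_ulp_difference=0 makes
# mismatches fail loudly while the ULP statistics of the current
# implementation are printed (step 6).
dict(name="atan", mpmath_name="arctan", size=1000, max_ulp_difference=0)

# Step 7: entry appended to the for-loop list in
# generate_ChloDecompositionPatternsMath.py:main():
#   (<op CHLO name>, <op name in fa.algorithms>, <a tuple of op arguments>)
("CHLO_AtanOp", "complex_atan", ("z:complex",))
```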
7 changes: 7 additions & 0 deletions build_tools/math/generate_ChloDecompositionPatternsMath.py
@@ -80,6 +80,11 @@ def main():
sources = []
target = fa.targets.stablehlo
for chloname, fname, args in [
# Important: new items to this list must be added to the end,
# otherwise, git diff may end up being unnecessarily large.
#
# (<op CHLO name>, <op name in fa.algorithms>, <a tuple of op arguments>)
#
("CHLO_AsinAcosKernelOp", "asin_acos_kernel", ("z:complex",)),
("CHLO_AsinOp", "complex_asin", ("z:complex",)),
("CHLO_AsinOp", "real_asin", ("x:float",)),
@@ -89,6 +94,8 @@
("CHLO_AcoshOp", "real_acosh", ("x:float",)),
("CHLO_AsinhOp", "complex_asinh", ("z:complex",)),
("CHLO_AsinhOp", "real_asinh", ("x:float",)),
("CHLO_AtanOp", "complex_atan", ("z:complex",)),
("CHLO_AtanhOp", "complex_atanh", ("z:complex",)),
]:
func = getattr(fa.algorithms, fname, None)
if func is None:
72 changes: 32 additions & 40 deletions build_tools/math/generate_tests.py
@@ -43,38 +43,26 @@
default_max_ulp_difference = 1

operations = [
dict(
name="asin",
mpmath_name="arcsin",
size=13,
# TODO(pearu): reduce to 1 after a fix to mpmath/mpmath#787 becomes available
extra_prec_multiplier=20,
max_ulp_difference=3,
),
dict(
name="acos",
mpmath_name="arccos",
size=13,
# TODO(pearu): reduce to 1 after a fix to mpmath/mpmath#787 becomes available
extra_prec_multiplier=20,
max_ulp_difference=3,
),
dict(
name="asinh",
mpmath_name="arcsinh",
size=13,
# TODO(pearu): reduce to 1 after a fix to mpmath/mpmath#787 becomes available
extra_prec_multiplier=20,
max_ulp_difference=3,
),
dict(
name="acosh",
mpmath_name="arccosh",
size=13,
# TODO(pearu): reduce to 1 after a fix to mpmath/mpmath#787 becomes available
extra_prec_multiplier=20,
max_ulp_difference=3,
),
# The following dictionaries may have additional keys like
#
# size - defines the number of samples: size ** 2
#
# max_ulp_difference - the maximal allowed ULP difference between
# function and reference values
#
# extra_prec_multiplier - the precision multiplier for mpmath.mp
# that defines the precision of computing reference values:
# mpmath.mp.prec * extra_prec_multiplier
#
# When unspecified, these parameters are retrieved from
# functional_algorithms database of support functions.
#
dict(name="asin", mpmath_name="arcsin"),
dict(name="acos", mpmath_name="arccos"),
dict(name="atan", mpmath_name="arctan"),
dict(name="asinh", mpmath_name="arcsinh"),
dict(name="acosh", mpmath_name="arccosh"),
dict(name="atanh", mpmath_name="arctanh"),
]
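
If an operation ever needs to deviate from the values supplied by
`functional_algorithms`, the keys described in the comment above can still be
given per entry, exactly as the removed entries above did; for example:

```python
dict(
    name="asin",
    mpmath_name="arcsin",
    size=13,                   # 13**2 samples
    extra_prec_multiplier=20,  # reference precision: mpmath.mp.prec * 20
    max_ulp_difference=3,      # allow up to 3 ULPs of mismatch
)
```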


@@ -139,16 +127,20 @@ def main():
opname = op["name"]
mpmath_opname = op.get("mpmath_name", opname)
size_re = size_im = op.get("size", default_size)
extra_prec_multiplier = op.get("extra_prec_multiplier",
default_extra_prec_multiplier)
max_ulp_difference = op.get("max_ulp_difference",
default_max_ulp_difference)

nmp = fa.utils.numpy_with_mpmath(
extra_prec_multiplier=extra_prec_multiplier,
flush_subnormals=flush_subnormals,
)
for dtype in [np.complex64, np.complex128, np.float32, np.float64]:
params = fa.utils.function_validation_parameters(opname, dtype)
max_ulp_difference = op.get(
"max_ulp_difference",
params.get("max_valid_ulp_count", default_max_ulp_difference))

nmp = fa.utils.numpy_with_mpmath(
extra_prec_multiplier = op.get(
"extra_prec_multiplier",
params.get("extra_prec_multiplier", default_extra_prec_multiplier)),
flush_subnormals=flush_subnormals,
)

fi = np.finfo(dtype)

float_dtype = to_float_dtype[dtype]
110 changes: 87 additions & 23 deletions stablehlo/tests/chlo/chlo_legalize_to_stablehlo.mlir
@@ -2888,37 +2888,101 @@ func.func @cosh_complex_f32(%x : tensor<complex<f32>>) -> tensor<complex<f32>> {

// -----

// CHECK-LABEL: @atanh_f32
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>
// CHECK-LABEL: func.func @atanh_f32(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<f32>) -> tensor<f32> {
// CHECK: %[[VAL_1:.*]] = stablehlo.abs %[[VAL_0]] : tensor<f32>
// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: %[[VAL_3:.*]] = stablehlo.compare GT, %[[VAL_1]], %[[VAL_2]] : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL_4:.*]] = stablehlo.constant dense<0x7FC00000> : tensor<f32>
// CHECK: %[[VAL_5:.*]] = stablehlo.log_plus_one %[[VAL_0]] : tensor<f32>
// CHECK: %[[VAL_6:.*]] = stablehlo.negate %[[VAL_0]] : tensor<f32>
// CHECK: %[[VAL_7:.*]] = stablehlo.log_plus_one %[[VAL_6]] : tensor<f32>
// CHECK: %[[VAL_8:.*]] = stablehlo.subtract %[[VAL_5]], %[[VAL_7]] : tensor<f32>
// CHECK: %[[VAL_9:.*]] = stablehlo.constant dense<5.000000e-01> : tensor<f32>
// CHECK: %[[VAL_10:.*]] = stablehlo.multiply %[[VAL_8]], %[[VAL_9]] : tensor<f32>
// CHECK: %[[VAL_11:.*]] = stablehlo.select %[[VAL_3]], %[[VAL_4]], %[[VAL_10]] : tensor<i1>, tensor<f32>
// CHECK: return %[[VAL_11]] : tensor<f32>
// CHECK: }
func.func @atanh_f32(%arg : tensor<f32>) -> tensor<f32> {
// CHECK-NEXT: %[[TMP_0:.*]] = stablehlo.abs %[[ARG]]
// CHECK-NEXT: %[[TMP_1:.*]] = stablehlo.constant dense<1.000000e+00>
// CHECK-NEXT: %[[TMP_2:.*]] = stablehlo.compare GT, %[[TMP_0]], %[[TMP_1]]
// CHECK-NEXT: %[[TMP_3:.*]] = stablehlo.constant dense<0x7FC00000>
// CHECK-NEXT: %[[TMP_4:.*]] = stablehlo.log_plus_one %[[ARG]]
// CHECK-NEXT: %[[TMP_5:.*]] = stablehlo.negate %[[ARG]]
// CHECK-NEXT: %[[TMP_6:.*]] = stablehlo.log_plus_one %[[TMP_5]]
// CHECK-NEXT: %[[TMP_7:.*]] = stablehlo.subtract %[[TMP_4]], %[[TMP_6]]
// CHECK-NEXT: %[[TMP_8:.*]] = stablehlo.constant dense<5.000000e-01>
// CHECK-NEXT: %[[TMP_9:.*]] = stablehlo.multiply %[[TMP_7]], %[[TMP_8]]
// CHECK-NEXT: %[[TMP_10:.*]] = stablehlo.select %[[TMP_2]], %[[TMP_3]], %[[TMP_9]]
// CHECK-NEXT: return %[[TMP_10]]
%result = "chlo.atanh"(%arg) : (tensor<f32>) -> tensor<f32>
func.return %result : tensor<f32>
}
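
The checked f32 lowering above amounts to the following (an editor's reading
of the CHECK lines, not generated code):

```python
import math

def atanh_f32_as_checked(x):
    # NaN outside [-1, 1]; otherwise 0.5 * (log1p(x) - log1p(-x)).
    if abs(x) > 1.0:
        return math.nan
    return 0.5 * (math.log1p(x) - math.log1p(-x))
```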

// -----

// CHECK-LABEL: @atanh_complex_f32
// CHECK-SAME: %[[ARG:.*]]: tensor<complex<f32>>
// CHECK-LABEL: func.func @atanh_complex_f32(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<complex<f32>>) -> tensor<complex<f32>> {
// CHECK: %[[VAL_1:.*]] = stablehlo.real %[[VAL_0]] : (tensor<complex<f32>>) -> tensor<f32>
// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VAL_3:.*]] = stablehlo.compare GE, %[[VAL_1]], %[[VAL_2]] : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL_4:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<-1.000000e+00> : tensor<f32>
// CHECK: %[[VAL_6:.*]] = stablehlo.select %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] : tensor<i1>, tensor<f32>
// CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK: %[[VAL_8:.*]] = stablehlo.abs %[[VAL_1]] : tensor<f32>
// CHECK: %[[VAL_9:.*]] = stablehlo.constant dense<3.40282347E+38> : tensor<f32>
// CHECK: %[[VAL_10:.*]] = stablehlo.constant dense<0x7F800000> : tensor<f32>
// CHECK: %[[VAL_11:.*]] = stablehlo.compare GT, %[[VAL_9]], %[[VAL_10]] : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL_12:.*]] = stablehlo.constant dense<9.00719925E+15> : tensor<f32>
// CHECK: %[[VAL_13:.*]] = stablehlo.constant dense<9.99999968E+37> : tensor<f32>
// CHECK: %[[VAL_14:.*]] = stablehlo.compare GT, %[[VAL_9]], %[[VAL_13]] : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL_15:.*]] = stablehlo.constant dense<0x4B800001> : tensor<f32>
// CHECK: %[[VAL_16:.*]] = stablehlo.constant dense<2.050000e+03> : tensor<f32>
// CHECK: %[[VAL_17:.*]] = stablehlo.select %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] : tensor<i1>, tensor<f32>
// CHECK: %[[VAL_18:.*]] = stablehlo.select %[[VAL_11]], %[[VAL_12]], %[[VAL_17]] : tensor<i1>, tensor<f32>
// CHECK: %[[VAL_19:.*]] = stablehlo.multiply %[[VAL_18]], %[[VAL_18]] : tensor<f32>
// CHECK: %[[VAL_20:.*]] = stablehlo.compare LT, %[[VAL_8]], %[[VAL_19]] : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL_21:.*]] = stablehlo.imag %[[VAL_0]] : (tensor<complex<f32>>) -> tensor<f32>
// CHECK: %[[VAL_22:.*]] = stablehlo.abs %[[VAL_21]] : tensor<f32>
// CHECK: %[[VAL_23:.*]] = stablehlo.compare LT, %[[VAL_22]], %[[VAL_19]] : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL_24:.*]] = stablehlo.and %[[VAL_20]], %[[VAL_23]] : tensor<i1>
// CHECK: %[[VAL_25:.*]] = stablehlo.subtract %[[VAL_4]], %[[VAL_8]] : tensor<f32>
// CHECK: %[[VAL_26:.*]] = stablehlo.multiply %[[VAL_25]], %[[VAL_25]] : tensor<f32>
// CHECK: %[[VAL_27:.*]] = stablehlo.multiply %[[VAL_21]], %[[VAL_21]] : tensor<f32>
// CHECK: %[[VAL_28:.*]] = stablehlo.add %[[VAL_26]], %[[VAL_27]] : tensor<f32>
// CHECK: %[[VAL_29:.*]] = stablehlo.divide %[[VAL_8]], %[[VAL_28]] : tensor<f32>
// CHECK: %[[VAL_30:.*]] = stablehlo.multiply %[[VAL_22]], %[[VAL_18]] : tensor<f32>
// CHECK: %[[VAL_31:.*]] = stablehlo.compare LT, %[[VAL_30]], %[[VAL_8]] : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL_32:.*]] = stablehlo.divide %[[VAL_4]], %[[VAL_8]] : tensor<f32>
// CHECK: %[[VAL_33:.*]] = stablehlo.constant dense<0x7F800000> : tensor<f32>
// CHECK: %[[VAL_34:.*]] = stablehlo.compare EQ, %[[VAL_1]], %[[VAL_33]] : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL_35:.*]] = stablehlo.constant dense<0xFF800000> : tensor<f32>
// CHECK: %[[VAL_36:.*]] = stablehlo.compare EQ, %[[VAL_1]], %[[VAL_35]] : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL_37:.*]] = stablehlo.or %[[VAL_34]], %[[VAL_36]] : tensor<i1>
// CHECK: %[[VAL_38:.*]] = stablehlo.compare EQ, %[[VAL_21]], %[[VAL_33]] : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL_39:.*]] = stablehlo.compare EQ, %[[VAL_21]], %[[VAL_35]] : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL_40:.*]] = stablehlo.or %[[VAL_38]], %[[VAL_39]] : tensor<i1>
// CHECK: %[[VAL_41:.*]] = stablehlo.or %[[VAL_37]], %[[VAL_40]] : tensor<i1>
// CHECK: %[[VAL_42:.*]] = stablehlo.divide %[[VAL_8]], %[[VAL_21]] : tensor<f32>
// CHECK: %[[VAL_43:.*]] = stablehlo.divide %[[VAL_21]], %[[VAL_8]] : tensor<f32>
// CHECK: %[[VAL_44:.*]] = stablehlo.add %[[VAL_42]], %[[VAL_43]] : tensor<f32>
// CHECK: %[[VAL_45:.*]] = stablehlo.divide %[[VAL_4]], %[[VAL_44]] : tensor<f32>
// CHECK: %[[VAL_46:.*]] = stablehlo.divide %[[VAL_45]], %[[VAL_21]] : tensor<f32>
// CHECK: %[[VAL_47:.*]] = stablehlo.select %[[VAL_41]], %[[VAL_2]], %[[VAL_46]] : tensor<i1>, tensor<f32>
// CHECK: %[[VAL_48:.*]] = stablehlo.select %[[VAL_31]], %[[VAL_32]], %[[VAL_47]] : tensor<i1>, tensor<f32>
// CHECK: %[[VAL_49:.*]] = stablehlo.select %[[VAL_24]], %[[VAL_29]], %[[VAL_48]] : tensor<i1>, tensor<f32>
// CHECK: %[[VAL_50:.*]] = stablehlo.multiply %[[VAL_7]], %[[VAL_49]] : tensor<f32>
// CHECK: %[[VAL_51:.*]] = stablehlo.log_plus_one %[[VAL_50]] : tensor<f32>
// CHECK: %[[VAL_52:.*]] = stablehlo.multiply %[[VAL_6]], %[[VAL_51]] : tensor<f32>
// CHECK: %[[VAL_53:.*]] = stablehlo.constant dense<2.500000e-01> : tensor<f32>
// CHECK: %[[VAL_54:.*]] = stablehlo.multiply %[[VAL_52]], %[[VAL_53]] : tensor<f32>
// CHECK: %[[VAL_55:.*]] = stablehlo.add %[[VAL_21]], %[[VAL_21]] : tensor<f32>
// CHECK: %[[VAL_56:.*]] = stablehlo.add %[[VAL_4]], %[[VAL_8]] : tensor<f32>
// CHECK: %[[VAL_57:.*]] = stablehlo.multiply %[[VAL_25]], %[[VAL_56]] : tensor<f32>
// CHECK: %[[VAL_58:.*]] = stablehlo.subtract %[[VAL_57]], %[[VAL_27]] : tensor<f32>
// CHECK: %[[VAL_59:.*]] = stablehlo.atan2 %[[VAL_55]], %[[VAL_58]] : tensor<f32>
// CHECK: %[[VAL_60:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VAL_61:.*]] = stablehlo.compare GE, %[[VAL_21]], %[[VAL_60]] : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: %[[VAL_62:.*]] = stablehlo.select %[[VAL_61]], %[[VAL_4]], %[[VAL_5]] : tensor<i1>, tensor<f32>
// CHECK: %[[VAL_63:.*]] = stablehlo.constant dense<3.14159274> : tensor<f32>
// CHECK: %[[VAL_64:.*]] = stablehlo.multiply %[[VAL_62]], %[[VAL_63]] : tensor<f32>
// CHECK: %[[VAL_65:.*]] = stablehlo.select %[[VAL_24]], %[[VAL_59]], %[[VAL_64]] : tensor<i1>, tensor<f32>
// CHECK: %[[VAL_66:.*]] = stablehlo.constant dense<5.000000e-01> : tensor<f32>
// CHECK: %[[VAL_67:.*]] = stablehlo.multiply %[[VAL_65]], %[[VAL_66]] : tensor<f32>
// CHECK: %[[VAL_68:.*]] = stablehlo.complex %[[VAL_54]], %[[VAL_67]] : tensor<complex<f32>>
// CHECK: return %[[VAL_68]] : tensor<complex<f32>>
// CHECK: }
func.func @atanh_complex_f32(%arg : tensor<complex<f32>>) -> tensor<complex<f32>> {
// CHECK-NEXT: %[[TMP_0:.*]] = stablehlo.log_plus_one %[[ARG]]
// CHECK-NEXT: %[[TMP_1:.*]] = stablehlo.negate %[[ARG]]
// CHECK-NEXT: %[[TMP_2:.*]] = stablehlo.log_plus_one %[[TMP_1]]
// CHECK-NEXT: %[[TMP_3:.*]] = stablehlo.subtract %[[TMP_0]], %[[TMP_2]]
// CHECK-NEXT: %[[TMP_4:.*]] = stablehlo.constant dense<(5.000000e-01,0.000000e+00)>
// CHECK-NEXT: %[[TMP_5:.*]] = stablehlo.multiply %[[TMP_3]], %[[TMP_4]]
// CHECK-NEXT: return %[[TMP_5]]
%result = "chlo.atanh"(%arg) : (tensor<complex<f32>>) -> tensor<complex<f32>>
func.return %result : tensor<complex<f32>>
}
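
For orientation, the regular (no-overflow) branch of the checked complex
lowering corresponds roughly to the sketch below. This is the editor's
reading of the CHECK lines above, not part of the generated code; it omits
the large-argument and infinity guards in the middle of the sequence and the
sign(y)*pi fallback used for the imaginary part outside the regular region.

```python
import math

def atanh_complex_regular(x, y):
    # Regular branch only: |x| and |y| well below the safe-maximum threshold.
    ax = abs(x)
    one_m_ax = 1.0 - ax
    denom = one_m_ax * one_m_ax + y * y
    sign_x = 1.0 if x >= 0.0 else -1.0
    real = 0.25 * sign_x * math.log1p(4.0 * ax / denom)
    imag = 0.5 * math.atan2(2.0 * y, one_m_ax * (1.0 + ax) - y * y)
    return complex(real, imag)
```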