Skip to content

Commit

Permalink
Fixed-point for QBDD (unitaryfund#1013)
Browse files Browse the repository at this point in the history
* Fixed-point QBDT (DOES NOT LINK)

* Fewer linker errors

* Links, but fails tests

* Remove macro for ENABLE_FIXED_POINT

* Fix SIGFPE

* QBDD tree.cpp should use floating-point
  • Loading branch information
WrathfulSpatula authored Jun 26, 2024
1 parent c93f14a commit b25c39b
Show file tree
Hide file tree
Showing 13 changed files with 402 additions and 539 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ endif (PACK_DEBIAN)
add_library (qrack STATIC
src/common/functions.cpp
src/common/parallel_for.cpp
src/common/fixed.cpp
src/qinterface/arithmetic.cpp
src/qinterface/gates.cpp
src/qinterface/logic.cpp
Expand Down Expand Up @@ -361,6 +362,7 @@ install (FILES
include/common/dispatchqueue.hpp
include/common/big_integer.hpp
include/common/half.hpp
include/common/fixed.hpp
include/common/qneuron_activation_function.hpp
include/common/pauli.hpp
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/qrack/common
Expand Down
1 change: 0 additions & 1 deletion cmake/FpMath.cmake
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
set(FPPOW "5" CACHE STRING "Log2 of float bits, for use in pairs as complex amplitudes (must be at least 2, equivalent to half precision)")
option(ENABLE_FIXED_POINT "Use fixed-point math instead of floating-point" OFF)

if (FPPOW LESS 4)
message(FATAL_ERROR "FPPOW must be at least 4, equivalent to \"half\" precision!")
Expand Down
200 changes: 31 additions & 169 deletions include/common/fixed.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -459,230 +459,92 @@ template <size_t I, size_t F> class fixed {
// if we have the same fractional portion, but differing integer portions, we trivially upgrade the smaller type
template <size_t I1, size_t I2, size_t F>
CONSTEXPR14 typename std::conditional<I1 >= I2, fixed<I1, F>, fixed<I2, F>>::type operator+(
fixed<I1, F> lhs, fixed<I2, F> rhs)
{

using T = typename std::conditional<I1 >= I2, fixed<I1, F>, fixed<I2, F>>::type;

const T l = T::from_base(lhs.to_raw());
const T r = T::from_base(rhs.to_raw());
return l + r;
}
fixed<I1, F> lhs, fixed<I2, F> rhs);

template <size_t I1, size_t I2, size_t F>
CONSTEXPR14 typename std::conditional<I1 >= I2, fixed<I1, F>, fixed<I2, F>>::type operator-(
fixed<I1, F> lhs, fixed<I2, F> rhs)
{

using T = typename std::conditional<I1 >= I2, fixed<I1, F>, fixed<I2, F>>::type;

const T l = T::from_base(lhs.to_raw());
const T r = T::from_base(rhs.to_raw());
return l - r;
}
fixed<I1, F> lhs, fixed<I2, F> rhs);

template <size_t I1, size_t I2, size_t F>
CONSTEXPR14 typename std::conditional<I1 >= I2, fixed<I1, F>, fixed<I2, F>>::type operator*(
fixed<I1, F> lhs, fixed<I2, F> rhs)
{

using T = typename std::conditional<I1 >= I2, fixed<I1, F>, fixed<I2, F>>::type;

const T l = T::from_base(lhs.to_raw());
const T r = T::from_base(rhs.to_raw());
return l * r;
}
fixed<I1, F> lhs, fixed<I2, F> rhs);

template <size_t I1, size_t I2, size_t F>
CONSTEXPR14 typename std::conditional<I1 >= I2, fixed<I1, F>, fixed<I2, F>>::type operator/(
fixed<I1, F> lhs, fixed<I2, F> rhs)
{
fixed<I1, F> lhs, fixed<I2, F> rhs);

using T = typename std::conditional<I1 >= I2, fixed<I1, F>, fixed<I2, F>>::type;

const T l = T::from_base(lhs.to_raw());
const T r = T::from_base(rhs.to_raw());
return l / r;
}

template <size_t I, size_t F> std::ostream& operator<<(std::ostream& os, fixed<I, F> f)
{
os << f.to_double();
return os;
}
template <size_t I, size_t F> std::ostream& operator<<(std::ostream& os, fixed<I, F> f);

// basic math operators
template <size_t I, size_t F> CONSTEXPR14 fixed<I, F> operator+(fixed<I, F> lhs, fixed<I, F> rhs)
template <size_t I, size_t F> CONSTEXPR14 fixed<I, F> inline operator+(fixed<I, F> lhs, fixed<I, F> rhs)
{
lhs += rhs;
return lhs;
}

template <size_t I, size_t F> CONSTEXPR14 fixed<I, F> operator-(fixed<I, F> lhs, fixed<I, F> rhs)
template <size_t I, size_t F> CONSTEXPR14 fixed<I, F> inline operator-(fixed<I, F> lhs, fixed<I, F> rhs)
{
lhs -= rhs;
return lhs;
}

template <size_t I, size_t F> CONSTEXPR14 fixed<I, F> operator*(fixed<I, F> lhs, fixed<I, F> rhs)
template <size_t I, size_t F> CONSTEXPR14 fixed<I, F> inline operator*(fixed<I, F> lhs, fixed<I, F> rhs)
{
lhs *= rhs;
return lhs;
}

template <size_t I, size_t F> CONSTEXPR14 fixed<I, F> operator/(fixed<I, F> lhs, fixed<I, F> rhs)
template <size_t I, size_t F> CONSTEXPR14 fixed<I, F> inline operator/(fixed<I, F> lhs, fixed<I, F> rhs)
{
lhs /= rhs;
return lhs;
}

template <size_t I, size_t F, class Number, class = typename std::enable_if<std::is_arithmetic<Number>::value>::type>
CONSTEXPR14 fixed<I, F> operator+(fixed<I, F> lhs, Number rhs)
{
lhs += fixed<I, F>(rhs);
return lhs;
}
template <size_t I, size_t F, class Number, class> CONSTEXPR14 fixed<I, F> operator+(fixed<I, F> lhs, Number rhs);

template <size_t I, size_t F, class Number, class = typename std::enable_if<std::is_arithmetic<Number>::value>::type>
CONSTEXPR14 fixed<I, F> operator-(fixed<I, F> lhs, Number rhs)
{
lhs -= fixed<I, F>(rhs);
return lhs;
}
template <size_t I, size_t F, class Number, class> CONSTEXPR14 fixed<I, F> operator-(fixed<I, F> lhs, Number rhs);

template <size_t I, size_t F, class Number, class = typename std::enable_if<std::is_arithmetic<Number>::value>::type>
CONSTEXPR14 fixed<I, F> operator*(fixed<I, F> lhs, Number rhs)
{
lhs *= fixed<I, F>(rhs);
return lhs;
}
template <size_t I, size_t F, class Number, class> CONSTEXPR14 fixed<I, F> operator*(fixed<I, F> lhs, Number rhs);

template <size_t I, size_t F, class Number, class = typename std::enable_if<std::is_arithmetic<Number>::value>::type>
CONSTEXPR14 fixed<I, F> operator/(fixed<I, F> lhs, Number rhs)
{
lhs /= fixed<I, F>(rhs);
return lhs;
}
template <size_t I, size_t F, class Number, class> CONSTEXPR14 fixed<I, F> operator/(fixed<I, F> lhs, Number rhs);

template <size_t I, size_t F, class Number, class = typename std::enable_if<std::is_arithmetic<Number>::value>::type>
CONSTEXPR14 fixed<I, F> operator+(Number lhs, fixed<I, F> rhs)
{
fixed<I, F> tmp(lhs);
tmp += rhs;
return tmp;
}
template <size_t I, size_t F, class Number, class> CONSTEXPR14 fixed<I, F> operator+(Number lhs, fixed<I, F> rhs);

template <size_t I, size_t F, class Number, class = typename std::enable_if<std::is_arithmetic<Number>::value>::type>
CONSTEXPR14 fixed<I, F> operator-(Number lhs, fixed<I, F> rhs)
{
fixed<I, F> tmp(lhs);
tmp -= rhs;
return tmp;
}
template <size_t I, size_t F, class Number, class> CONSTEXPR14 fixed<I, F> operator-(Number lhs, fixed<I, F> rhs);

template <size_t I, size_t F, class Number, class = typename std::enable_if<std::is_arithmetic<Number>::value>::type>
CONSTEXPR14 fixed<I, F> operator*(Number lhs, fixed<I, F> rhs)
{
fixed<I, F> tmp(lhs);
tmp *= rhs;
return tmp;
}
template <size_t I, size_t F, class Number, class> CONSTEXPR14 fixed<I, F> operator*(Number lhs, fixed<I, F> rhs);

template <size_t I, size_t F, class Number, class = typename std::enable_if<std::is_arithmetic<Number>::value>::type>
CONSTEXPR14 fixed<I, F> operator/(Number lhs, fixed<I, F> rhs)
{
fixed<I, F> tmp(lhs);
tmp /= rhs;
return tmp;
}
template <size_t I, size_t F, class Number, class> CONSTEXPR14 fixed<I, F> operator/(Number lhs, fixed<I, F> rhs);

// shift operators
template <size_t I, size_t F, class Integer, class = typename std::enable_if<std::is_integral<Integer>::value>::type>
CONSTEXPR14 fixed<I, F> operator<<(fixed<I, F> lhs, Integer rhs)
{
lhs <<= rhs;
return lhs;
}
template <size_t I, size_t F, class Integer, class> CONSTEXPR14 fixed<I, F> operator<<(fixed<I, F> lhs, Integer rhs);

template <size_t I, size_t F, class Integer, class = typename std::enable_if<std::is_integral<Integer>::value>::type>
CONSTEXPR14 fixed<I, F> operator>>(fixed<I, F> lhs, Integer rhs)
{
lhs >>= rhs;
return lhs;
}
template <size_t I, size_t F, class Integer, class> CONSTEXPR14 fixed<I, F> operator>>(fixed<I, F> lhs, Integer rhs);

// comparison operators
template <size_t I, size_t F, class Number, class = typename std::enable_if<std::is_arithmetic<Number>::value>::type>
constexpr bool operator>(fixed<I, F> lhs, Number rhs)
{
return lhs > fixed<I, F>(rhs);
}
template <size_t I, size_t F, class Number, class> constexpr bool operator>(fixed<I, F> lhs, Number rhs);

template <size_t I, size_t F, class Number, class = typename std::enable_if<std::is_arithmetic<Number>::value>::type>
constexpr bool operator<(fixed<I, F> lhs, Number rhs)
{
return lhs < fixed<I, F>(rhs);
}
template <size_t I, size_t F, class Number, class> constexpr bool operator<(fixed<I, F> lhs, Number rhs);

template <size_t I, size_t F, class Number, class = typename std::enable_if<std::is_arithmetic<Number>::value>::type>
constexpr bool operator>=(fixed<I, F> lhs, Number rhs)
{
return lhs >= fixed<I, F>(rhs);
}
template <size_t I, size_t F, class Number, class> constexpr bool operator>=(fixed<I, F> lhs, Number rhs);

template <size_t I, size_t F, class Number, class = typename std::enable_if<std::is_arithmetic<Number>::value>::type>
constexpr bool operator<=(fixed<I, F> lhs, Number rhs)
{
return lhs <= fixed<I, F>(rhs);
}
template <size_t I, size_t F, class Number, class> constexpr bool operator<=(fixed<I, F> lhs, Number rhs);

template <size_t I, size_t F, class Number, class = typename std::enable_if<std::is_arithmetic<Number>::value>::type>
constexpr bool operator==(fixed<I, F> lhs, Number rhs)
{
return lhs == fixed<I, F>(rhs);
}
template <size_t I, size_t F, class Number, class> constexpr bool operator==(fixed<I, F> lhs, Number rhs);

template <size_t I, size_t F, class Number, class = typename std::enable_if<std::is_arithmetic<Number>::value>::type>
constexpr bool operator!=(fixed<I, F> lhs, Number rhs)
{
return lhs != fixed<I, F>(rhs);
}
template <size_t I, size_t F, class Number, class> constexpr bool operator!=(fixed<I, F> lhs, Number rhs);

template <size_t I, size_t F, class Number, class = typename std::enable_if<std::is_arithmetic<Number>::value>::type>
constexpr bool operator>(Number lhs, fixed<I, F> rhs)
{
return fixed<I, F>(lhs) > rhs;
}
template <size_t I, size_t F, class Number, class> constexpr bool operator>(Number lhs, fixed<I, F> rhs);

template <size_t I, size_t F, class Number, class = typename std::enable_if<std::is_arithmetic<Number>::value>::type>
constexpr bool operator<(Number lhs, fixed<I, F> rhs)
{
return fixed<I, F>(lhs) < rhs;
}
template <size_t I, size_t F, class Number, class> constexpr bool operator<(Number lhs, fixed<I, F> rhs);

template <size_t I, size_t F, class Number, class = typename std::enable_if<std::is_arithmetic<Number>::value>::type>
constexpr bool operator>=(Number lhs, fixed<I, F> rhs)
{
return fixed<I, F>(lhs) >= rhs;
}
template <size_t I, size_t F, class Number, class> constexpr bool operator>=(Number lhs, fixed<I, F> rhs);

template <size_t I, size_t F, class Number, class = typename std::enable_if<std::is_arithmetic<Number>::value>::type>
constexpr bool operator<=(Number lhs, fixed<I, F> rhs)
{
return fixed<I, F>(lhs) <= rhs;
}
template <size_t I, size_t F, class Number, class> constexpr bool operator<=(Number lhs, fixed<I, F> rhs);

template <size_t I, size_t F, class Number, class = typename std::enable_if<std::is_arithmetic<Number>::value>::type>
constexpr bool operator==(Number lhs, fixed<I, F> rhs)
{
return fixed<I, F>(lhs) == rhs;
}
template <size_t I, size_t F, class Number, class> constexpr bool operator==(Number lhs, fixed<I, F> rhs);

template <size_t I, size_t F, class Number, class = typename std::enable_if<std::is_arithmetic<Number>::value>::type>
constexpr bool operator!=(Number lhs, fixed<I, F> rhs)
{
return fixed<I, F>(lhs) != rhs;
}
template <size_t I, size_t F, class Number, class> constexpr bool operator!=(Number lhs, fixed<I, F> rhs);

} // namespace numeric

Expand Down
1 change: 1 addition & 0 deletions include/common/qrack_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ bool isOverflowSub(
bitCapInt pushApartBits(const bitCapInt& perm, const std::vector<bitCapInt>& skipPowers);
bitCapInt intPow(bitCapInt base, bitCapInt power);
bitCapIntOcl intPowOcl(bitCapIntOcl base, bitCapIntOcl power);
complex complexFixedToFloating(complex_x f);

#if QBCAPPOW > 6
std::ostream& operator<<(std::ostream& os, bitCapInt b);
Expand Down
41 changes: 18 additions & 23 deletions include/common/qrack_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,25 +63,8 @@
#define bitCapInt BigInteger
#endif

#if ENABLE_FIXED_POINT
#include "fixed.hpp"
namespace Qrack {
// We want to be able to represent at least 2 * PI_R1.
// (This is <7.) 1 bit is sign.
// 3 bits represent 1, 2, and 4, for a maximum value of 7
// on the left side of the decimal point.
typedef numeric::Fixed<4U, (1U << FPPOW) - 4U> real1;
#if FPPOW < 6
typedef float real1_f;
typedef float real1_s;
#elif FPPOW < 7
typedef double real1_f;
typedef double real1_s;
#else
typedef boost::multiprecision::float128 real1_f;
typedef double real1_s;
#endif
#else

#if FPPOW < 5
#ifdef __arm__
namespace Qrack {
Expand Down Expand Up @@ -133,13 +116,25 @@ typedef boost::multiprecision::float128 real1_f;
typedef double real1_s;
#endif
#endif
#endif

typedef std::complex<real1> complex;
const bitCapInt ONE_BCI = 1U;
const bitCapInt ZERO_BCI = 0U;
constexpr bitLenInt bitsInCap = ((bitLenInt)1U) << ((bitLenInt)QBCAPPOW);

// We want to be able to represent at least 1
// (and no less, for maximum capacity).
// 1 bit is +/- sign.
// 1 bit is 0/1 on the left side of the decimal point.
typedef numeric::fixed<3U, (1U << FPPOW) - 3U> real1_x;
typedef std::complex<real1_x> complex_x;
constexpr real1_x ONE_R1_X = 1.0f;
constexpr real1_x ZERO_R1_X = 0.0f;
constexpr complex_x ONE_CMPLX_X = complex_x(ONE_R1_X, ZERO_R1_X);
constexpr complex_x ZERO_CMPLX_X = complex_x(ZERO_R1_X, ZERO_R1_X);
constexpr complex_x I_CMPLX_X = complex_x(ZERO_R1_X, ONE_R1_X);
const real1_x SQRT1_2_R1_X = (real1_x)M_SQRT1_2;

typedef std::shared_ptr<complex> BitOp;

// Called once per value between begin and end.
Expand Down Expand Up @@ -172,7 +167,7 @@ const real1 ZERO_R1 = (real1)0.0f;
constexpr real1_f ZERO_R1_F = 0.0f;
const real1 ONE_R1 = (real1)1.0f;
constexpr real1_f ONE_R1_F = 1.0f;
const real1 REAL1_DEFAULT_ARG = (real1)-999.0f;
const real1 REAL1_DEFAULT_ARG = (real1)-7.77f;
// Half the probability in any single permutation of 20 maximally superposed qubits
const real1 REAL1_EPSILON = (real1)0.000000477f;
const real1 PI_R1 = (real1)M_PI;
Expand All @@ -187,7 +182,7 @@ const real1 SQRT1_2_R1 = (real1)M_SQRT1_2;
constexpr real1 PI_R1 = (real1)M_PI;
constexpr real1 SQRT2_R1 = (real1)M_SQRT2;
constexpr real1 SQRT1_2_R1 = (real1)M_SQRT1_2;
#define REAL1_DEFAULT_ARG -999.0f
#define REAL1_DEFAULT_ARG -7.77f
// Half the probability in any single permutation of 48 maximally superposed qubits
#define REAL1_EPSILON 1.7763568394002505e-15f
#elif FPPOW < 7
Expand All @@ -199,7 +194,7 @@ constexpr real1 SQRT1_2_R1 = (real1)M_SQRT1_2;
#define PI_R1 M_PI
#define SQRT2_R1 M_SQRT2
#define SQRT1_2_R1 M_SQRT1_2
#define REAL1_DEFAULT_ARG -999.0
#define REAL1_DEFAULT_ARG -7.77
// Half the probability in any single permutation of 96 maximally superposed qubits
#define REAL1_EPSILON 6.310887241768095e-30
#else
Expand All @@ -211,7 +206,7 @@ constexpr real1 ONE_R1 = (real1)1.0;
constexpr real1_f PI_R1 = (real1_f)M_PI;
constexpr real1_f SQRT2_R1 = (real1_f)M_SQRT2;
constexpr real1_f SQRT1_2_R1 = (real1_f)M_SQRT1_2;
#define REAL1_DEFAULT_ARG -999.0
#define REAL1_DEFAULT_ARG -7.77
// Half the probability in any single permutation of 192 maximally superposed qubits
#define REAL1_EPSILON 7.965459555662261e-59
#endif
Expand Down
4 changes: 2 additions & 2 deletions include/qbdt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class QBdt : public QParity, public QInterface {

_par_for(maxQPower, [&](const bitCapInt& i, const unsigned& cpu) {
QBdtNodeInterfacePtr leaf = root;
complex scale = leaf->scale;
complex_x scale = leaf->scale;
for (bitLenInt j = 0U; j < qubitCount; ++j) {
leaf = leaf->branches[SelectBit(i, j)];
if (!leaf) {
Expand All @@ -105,7 +105,7 @@ class QBdt : public QParity, public QInterface {
scale *= leaf->scale;
}

getLambda((bitCapIntOcl)i, scale);
getLambda((bitCapIntOcl)i, complex((real1)(real(scale).to_double()), (real1)(imag(scale).to_double())));
});
}
template <typename Fn> void SetTraversal(Fn setLambda)
Expand Down
Loading

0 comments on commit b25c39b

Please sign in to comment.