diff --git a/notebooks/binom_factor_table.ipynb b/notebooks/binom_factor_table.ipynb index 6ee1603..4e095dc 100644 --- a/notebooks/binom_factor_table.ipynb +++ b/notebooks/binom_factor_table.ipynb @@ -47,7 +47,7 @@ } ], "source": [ - "LMAX = 4\n", + "LMAX = 8\n", "kilobytes = LMAX ** 5 * 4 / 1024\n", "print(f\"{kilobytes=}\")" ] @@ -75,65 +75,79 @@ "name": "stdout", "output_type": "stream", "text": [ - "(0, 0, 0) 1\n", - "(0, 0, 1) 0\n", - "(0, 0, 2) 0\n", - "(0, 0, 3) 0\n", - "(0, 0, 4) 0\n", - "(0, 1, 0) b\n", - "(0, 1, 1) 1\n", - "(0, 1, 2) 0\n", - "(0, 1, 3) 0\n", - "(0, 1, 4) 0\n", - "(0, 2, 0) b**2\n", - "(0, 2, 1) 2*b\n", - "(0, 2, 2) 1\n", - "(0, 2, 3) 0\n", - "(0, 2, 4) 0\n", - "(1, 0, 0) a\n", - "(1, 0, 1) 1\n", - "(1, 0, 2) 0\n", - "(1, 0, 3) 0\n", - "(1, 0, 4) 0\n", - "(1, 1, 0) a*b\n", - "(1, 1, 1) a + b\n", - "(1, 1, 2) 1\n", - "(1, 1, 3) 0\n", - "(1, 1, 4) 0\n", - "(1, 2, 0) a*b**2\n", - "(1, 2, 1) 2*a*b + b**2\n", - "(1, 2, 2) a + 2*b\n", - "(1, 2, 3) 1\n", - "(1, 2, 4) 0\n", - "(2, 0, 0) a**2\n", - "(2, 0, 1) 2*a\n", - "(2, 0, 2) 1\n", - "(2, 0, 3) 0\n", - "(2, 0, 4) 0\n", - "(2, 1, 0) a**2*b\n", - "(2, 1, 1) a**2 + 2*a*b\n", - "(2, 1, 2) 2*a + b\n", - "(2, 1, 3) 1\n", - "(2, 1, 4) 0\n", - "(2, 2, 0) a**2*b**2\n", - "(2, 2, 1) 2*a**2*b + 2*a*b**2\n", - "(2, 2, 2) a**2 + 4*a*b + b**2\n", - "(2, 2, 3) 2*a + 2*b\n", - "(2, 2, 4) 1\n" + "(0, 0, 0) Poly(1, a, b, domain='ZZ')\n", + "(0, 0, 2) Poly(0, a, b, domain='ZZ')\n", + "(0, 0, 4) Poly(0, a, b, domain='ZZ')\n", + "(0, 0, 6) Poly(0, a, b, domain='ZZ')\n", + "(0, 2, 0) Poly(b**2, a, b, domain='ZZ')\n", + "(0, 2, 2) Poly(1, a, b, domain='ZZ')\n", + "(0, 2, 4) Poly(0, a, b, domain='ZZ')\n", + "(0, 2, 6) Poly(0, a, b, domain='ZZ')\n", + "(0, 4, 0) Poly(b**4, a, b, domain='ZZ')\n", + "(0, 4, 2) Poly(6*b**2, a, b, domain='ZZ')\n", + "(0, 4, 4) Poly(1, a, b, domain='ZZ')\n", + "(0, 4, 6) Poly(0, a, b, domain='ZZ')\n", + "(0, 6, 0) Poly(b**6, a, b, domain='ZZ')\n", + "(0, 6, 2) Poly(15*b**4, a, b, domain='ZZ')\n", + "(0, 6, 4) Poly(15*b**2, a, b, domain='ZZ')\n", + "(0, 6, 6) Poly(1, a, b, domain='ZZ')\n", + "(3, 0, 0) Poly(a**3, a, b, domain='ZZ')\n", + "(3, 0, 2) Poly(3*a, a, b, domain='ZZ')\n", + "(3, 0, 4) Poly(0, a, b, domain='ZZ')\n", + "(3, 0, 6) Poly(0, a, b, domain='ZZ')\n", + "(3, 2, 0) Poly(a**3*b**2, a, b, domain='ZZ')\n", + "(3, 2, 2) Poly(a**3 + 6*a**2*b + 3*a*b**2, a, b, domain='ZZ')\n", + "(3, 2, 4) Poly(3*a + 2*b, a, b, domain='ZZ')\n", + "(3, 2, 6) Poly(0, a, b, domain='ZZ')\n", + "(3, 4, 0) Poly(a**3*b**4, a, b, domain='ZZ')\n", + "(3, 4, 2) Poly(6*a**3*b**2 + 12*a**2*b**3 + 3*a*b**4, a, b, domain='ZZ')\n", + "(3, 4, 4) Poly(a**3 + 12*a**2*b + 18*a*b**2 + 4*b**3, a, b, domain='ZZ')\n", + "(3, 4, 6) Poly(3*a + 4*b, a, b, domain='ZZ')\n", + "(3, 6, 0) Poly(a**3*b**6, a, b, domain='ZZ')\n", + "(3, 6, 2) Poly(15*a**3*b**4 + 18*a**2*b**5 + 3*a*b**6, a, b, domain='ZZ')\n", + "(3, 6, 4) Poly(15*a**3*b**2 + 60*a**2*b**3 + 45*a*b**4 + 6*b**5, a, b, domain='ZZ')\n", + "(3, 6, 6) Poly(a**3 + 18*a**2*b + 45*a*b**2 + 20*b**3, a, b, domain='ZZ')\n", + "(6, 0, 0) Poly(a**6, a, b, domain='ZZ')\n", + "(6, 0, 2) Poly(15*a**4, a, b, domain='ZZ')\n", + "(6, 0, 4) Poly(15*a**2, a, b, domain='ZZ')\n", + "(6, 0, 6) Poly(1, a, b, domain='ZZ')\n", + "(6, 2, 0) Poly(a**6*b**2, a, b, domain='ZZ')\n", + "(6, 2, 2) Poly(a**6 + 12*a**5*b + 15*a**4*b**2, a, b, domain='ZZ')\n", + "(6, 2, 4) Poly(15*a**4 + 40*a**3*b + 15*a**2*b**2, a, b, domain='ZZ')\n", + "(6, 2, 6) Poly(15*a**2 + 12*a*b + b**2, a, b, domain='ZZ')\n", + "(6, 4, 0) Poly(a**6*b**4, a, b, domain='ZZ')\n", + "(6, 4, 2) Poly(6*a**6*b**2 + 24*a**5*b**3 + 15*a**4*b**4, a, b, domain='ZZ')\n", + "(6, 4, 4) Poly(a**6 + 24*a**5*b + 90*a**4*b**2 + 80*a**3*b**3 + 15*a**2*b**4, a, b, domain='ZZ')\n", + "(6, 4, 6) Poly(15*a**4 + 80*a**3*b + 90*a**2*b**2 + 24*a*b**3 + b**4, a, b, domain='ZZ')\n", + "(6, 6, 0) Poly(a**6*b**6, a, b, domain='ZZ')\n", + "(6, 6, 2) Poly(15*a**6*b**4 + 36*a**5*b**5 + 15*a**4*b**6, a, b, domain='ZZ')\n", + "(6, 6, 4) Poly(15*a**6*b**2 + 120*a**5*b**3 + 225*a**4*b**4 + 120*a**3*b**5 + 15*a**2*b**6, a, b, domain='ZZ')\n", + "(6, 6, 6) Poly(a**6 + 36*a**5*b + 225*a**4*b**2 + 400*a**3*b**3 + 225*a**2*b**4 + 36*a*b**5 + b**6, a, b, domain='ZZ')\n" ] }, { "data": { "text/latex": [ - "$\\displaystyle 1$" + "$\\displaystyle \\operatorname{Poly}{\\left( a^{6} + 36 a^{5}b + 225 a^{4}b^{2} + 400 a^{3}b^{3} + 225 a^{2}b^{4} + 36 ab^{5} + b^{6}, a, b, domain=\\mathbb{Z} \\right)}$" ], "text/plain": [ - "1" + "Poly(a**6 + 36*a**5*b + 225*a**4*b**2 + 400*a**3*b**3 + 225*a**2*b**4 + 36*a*b**5 + b**6, a, b, domain='ZZ')" ] }, - "execution_count": 3, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$\\displaystyle 400$" + ], + "text/plain": [ + "400" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -150,15 +164,18 @@ " out += val\n", " return out\n", "\n", + "def binom_factor_sym(i: int, j: int, a: Symbol, b: Symbol, s: int):\n", + " return Poly(binom_factor(i, j, a, b, s), a, b)\n", + "\n", "\n", "a, b = symbols(\"a b\", real=True)\n", - "for i in range(3):\n", - " for j in range(3):\n", - " for s in range(5):\n", - " bf = binom_factor(i, j, a, b, s)\n", + "for i in range(0,7,3):\n", + " for j in range(0,7,2):\n", + " for s in range(0,7,2):\n", + " bf = binom_factor_sym(i, j, a, b, s)\n", " print((i, j, s), bf)\n", "\n", - "bf" + "display(bf, bf.coeff_monomial(a**3*b**3))" ] }, { @@ -169,52 +186,15 @@ { "data": { "text/latex": [ - "$\\displaystyle \\left\\{ 1 : 4, \\ a : 1, \\ a b : 3\\right\\}$" - ], - "text/plain": [ - "{1: 4, a: 1, a*b: 3}" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/latex": [ - "$\\displaystyle b^{2}$" - ], - "text/plain": [ - "b**2" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "{b**2: 1}" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/latex": [ - "$\\displaystyle 1$" + "$\\displaystyle \\left[\\begin{matrix}1 & b & b^{2} & b^{3} & b^{4}\\\\a & a b & a b^{2} & a b^{3} & a b^{4}\\\\a^{2} & a^{2} b & a^{2} b^{2} & a^{2} b^{3} & a^{2} b^{4}\\\\a^{3} & a^{3} b & a^{3} b^{2} & a^{3} b^{3} & a^{3} b^{4}\\\\a^{4} & a^{4} b & a^{4} b^{2} & a^{4} b^{3} & a^{4} b^{4}\\end{matrix}\\right]$" ], "text/plain": [ - "1" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "{1: 1.00000000000000}" + "Matrix([\n", + "[ 1, b, b**2, b**3, b**4],\n", + "[ a, a*b, a*b**2, a*b**3, a*b**4],\n", + "[a**2, a**2*b, a**2*b**2, a**2*b**3, a**2*b**4],\n", + "[a**3, a**3*b, a**3*b**2, a**3*b**3, a**3*b**4],\n", + "[a**4, a**4*b, a**4*b**2, a**4*b**3, a**4*b**4]])" ] }, "metadata": {}, @@ -223,10 +203,10 @@ { "data": { "text/latex": [ - "$\\displaystyle 2 a b + b^{2}$" + "$\\displaystyle \\left[\\begin{array}{ccccccccccccccccccccccccc}1 & b & b^{2} & b^{3} & b^{4} & a & a b & a b^{2} & a b^{3} & a b^{4} & a^{2} & a^{2} b & a^{2} b^{2} & a^{2} b^{3} & a^{2} b^{4} & a^{3} & a^{3} b & a^{3} b^{2} & a^{3} b^{3} & a^{3} b^{4} & a^{4} & a^{4} b & a^{4} b^{2} & a^{4} b^{3} & a^{4} b^{4}\\end{array}\\right]$" ], "text/plain": [ - "2*a*b + b**2" + "Matrix([[1, b, b**2, b**3, b**4, a, a*b, a*b**2, a*b**3, a*b**4, a**2, a**2*b, a**2*b**2, a**2*b**3, a**2*b**4, a**3, a**3*b, a**3*b**2, a**3*b**3, a**3*b**4, a**4, a**4*b, a**4*b**2, a**4*b**3, a**4*b**4]])" ] }, "metadata": {}, @@ -234,61 +214,14 @@ } ], "source": [ - "# https://stackoverflow.com/questions/74731353/is-there-any-all-coeffs-for-multivariable-polynomials-in-sympy\n", - "def all_coeffs(expr, *free):\n", - " if isinstance(expr, (int, Number)):\n", - " return {S(1): N(expr)}\n", - "\n", - " x = IndexedBase(\"x\")\n", - " expr = expr.expand()\n", - " free = list(free) or list(expr.free_symbols)\n", - " pows = [p.as_base_exp() for p in expr.atoms(Pow, Symbol)]\n", - " P = {}\n", - " for p, e in pows:\n", - " if p not in free:\n", - " continue\n", - " elif p not in P:\n", - " P[p] = e\n", - " elif e > P[p]:\n", - " P[p] = e\n", - " reps = dict([(f, x[i]) for i, f in enumerate(free)])\n", - " xzero = dict([(v, 0) for k, v in reps.items()])\n", - " e = expr.xreplace(reps)\n", - " reps = {v: k for k, v in reps.items()}\n", - " ans = dict(\n", - " [\n", - " (\n", - " m.xreplace(reps),\n", - " e.coeff(m).xreplace(xzero) if m != 1 else e.xreplace(xzero),\n", - " )\n", - " for m in monomials(*[P[f] for f in free])\n", - " ]\n", - " )\n", - " return {m: w for m, w in ans.items() if w != 0}\n", - "\n", - "\n", - "def monomials(*o):\n", - " x = IndexedBase(\"x\")\n", - " f = []\n", - " for i, o in enumerate(o):\n", - " f.append(Poly([1] * (o + 1), x[i]).as_expr())\n", - " return Mul(*f).expand().args\n", - "\n", + "monomials_a = Matrix([a ** i for i in range(LMAX)])\n", + "monomials_b = Matrix([b ** i for i in range(LMAX)])\n", "\n", - "assert all_coeffs(4 + a + 3 * b * a) == {1: 4, a: 1, a * b: 3}\n", - "display(S(all_coeffs(4 + a + 3 * b * a)))\n", + "all_monomials = monomials_a * monomials_b.transpose()\n", + "display(all_monomials)\n", + "all_monomials = all_monomials.reshape(1,LMAX**2)\n", "\n", - "bf = binom_factor(0, 2, a, b, 0)\n", - "display(bf, all_coeffs(bf))\n", - "assert all_coeffs(b * b) == {b * b: 1}\n", - "\n", - "bf = binom_factor(0, 0, a, b, 0)\n", - "assert bf == 1\n", - "display(bf, all_coeffs(bf))\n", - "assert all_coeffs(bf) == {1: 1}\n", - "\n", - "bf = binom_factor(1, 2, a, b, 1)\n", - "display(bf)\n" + "display(all_monomials)\n" ] }, { @@ -298,272 +231,103 @@ "outputs": [ { "data": { - "text/latex": [ - "$\\displaystyle \\left( \\left( 1, \\ 1, \\ 0\\right), \\ a b, \\ \\left\\{ a b : 1\\right\\}\\right)$" - ], - "text/plain": [ - "((1, 1, 0), a*b, {a*b: 1})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/latex": [ - "$\\displaystyle \\left( \\left( 1, \\ 1, \\ 1\\right), \\ a + b, \\ \\left\\{ a : 1, \\ b : 1\\right\\}\\right)$" - ], - "text/plain": [ - "((1, 1, 1), a + b, {a: 1, b: 1})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/latex": [ - "$\\displaystyle \\left( \\left( 1, \\ 1, \\ 2\\right), \\ 1, \\ \\left\\{ 1 : 1.0\\right\\}\\right)$" - ], - "text/plain": [ - "((1, 1, 2), 1, {1: 1.0})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/latex": [ - "$\\displaystyle \\left( \\left( 1, \\ 3, \\ 0\\right), \\ a b^{3}, \\ \\left\\{ a b^{3} : 1\\right\\}\\right)$" - ], - "text/plain": [ - "((1, 3, 0), a*b**3, {a*b**3: 1})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/latex": [ - "$\\displaystyle \\left( \\left( 1, \\ 3, \\ 1\\right), \\ 3 a b^{2} + b^{3}, \\ \\left\\{ b^{3} : 1, \\ a b^{2} : 3\\right\\}\\right)$" - ], - "text/plain": [ - "((1, 3, 1), 3*a*b**2 + b**3, {b**3: 1, a*b**2: 3})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/latex": [ - "$\\displaystyle \\left( \\left( 1, \\ 3, \\ 2\\right), \\ 3 a b + 3 b^{2}, \\ \\left\\{ b^{2} : 3, \\ a b : 3\\right\\}\\right)$" - ], - "text/plain": [ - "((1, 3, 2), 3*a*b + 3*b**2, {b**2: 3, a*b: 3})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/latex": [ - "$\\displaystyle \\left( \\left( 1, \\ 3, \\ 3\\right), \\ a + 3 b, \\ \\left\\{ a : 1, \\ b : 3\\right\\}\\right)$" - ], - "text/plain": [ - "((1, 3, 3), a + 3*b, {a: 1, b: 3})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/latex": [ - "$\\displaystyle \\left( \\left( 1, \\ 3, \\ 4\\right), \\ 1, \\ \\left\\{ 1 : 1.0\\right\\}\\right)$" - ], - "text/plain": [ - "((1, 3, 4), 1, {1: 1.0})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/latex": [ - "$\\displaystyle \\left( \\left( 3, \\ 1, \\ 0\\right), \\ a^{3} b, \\ \\left\\{ a^{3} b : 1\\right\\}\\right)$" - ], - "text/plain": [ - "((3, 1, 0), a**3*b, {a**3*b: 1})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/latex": [ - "$\\displaystyle \\left( \\left( 3, \\ 1, \\ 1\\right), \\ a^{3} + 3 a^{2} b, \\ \\left\\{ a^{3} : 1, \\ a^{2} b : 3\\right\\}\\right)$" - ], - "text/plain": [ - "((3, 1, 1), a**3 + 3*a**2*b, {a**3: 1, a**2*b: 3})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/latex": [ - "$\\displaystyle \\left( \\left( 3, \\ 1, \\ 2\\right), \\ 3 a^{2} + 3 a b, \\ \\left\\{ a^{2} : 3, \\ a b : 3\\right\\}\\right)$" - ], - "text/plain": [ - "((3, 1, 2), 3*a**2 + 3*a*b, {a**2: 3, a*b: 3})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/latex": [ - "$\\displaystyle \\left( \\left( 3, \\ 1, \\ 3\\right), \\ 3 a + b, \\ \\left\\{ a : 3, \\ b : 1\\right\\}\\right)$" - ], - "text/plain": [ - "((3, 1, 3), 3*a + b, {a: 3, b: 1})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/latex": [ - "$\\displaystyle \\left( \\left( 3, \\ 1, \\ 4\\right), \\ 1, \\ \\left\\{ 1 : 1.0\\right\\}\\right)$" - ], - "text/plain": [ - "((3, 1, 4), 1, {1: 1.0})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/latex": [ - "$\\displaystyle \\left( \\left( 3, \\ 3, \\ 0\\right), \\ a^{3} b^{3}, \\ \\left\\{ a^{3} b^{3} : 1\\right\\}\\right)$" - ], - "text/plain": [ - "((3, 3, 0), a**3*b**3, {a**3*b**3: 1})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/latex": [ - "$\\displaystyle \\left( \\left( 3, \\ 3, \\ 1\\right), \\ 3 a^{3} b^{2} + 3 a^{2} b^{3}, \\ \\left\\{ a^{2} b^{3} : 3, \\ a^{3} b^{2} : 3\\right\\}\\right)$" - ], - "text/plain": [ - "((3, 3, 1), 3*a**3*b**2 + 3*a**2*b**3, {a**2*b**3: 3, a**3*b**2: 3})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/latex": [ - "$\\displaystyle \\left( \\left( 3, \\ 3, \\ 2\\right), \\ 3 a^{3} b + 9 a^{2} b^{2} + 3 a b^{3}, \\ \\left\\{ a b^{3} : 3, \\ a^{2} b^{2} : 9, \\ a^{3} b : 3\\right\\}\\right)$" - ], - "text/plain": [ - "((3, 3, 2), 3*a**3*b + 9*a**2*b**2 + 3*a*b**3, {a*b**3: 3, a**2*b**2: 9, a**3*b: 3})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/latex": [ - "$\\displaystyle \\left( \\left( 3, \\ 3, \\ 3\\right), \\ a^{3} + 9 a^{2} b + 9 a b^{2} + b^{3}, \\ \\left\\{ a^{3} : 1, \\ b^{3} : 1, \\ a b^{2} : 9, \\ a^{2} b : 9\\right\\}\\right)$" - ], - "text/plain": [ - "((3, 3, 3), a**3 + 9*a**2*b + 9*a*b**2 + b**3, {a**3: 1, b**3: 1, a*b**2: 9, a**2*b: 9})" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/latex": [ - "$\\displaystyle \\left( \\left( 3, \\ 3, \\ 4\\right), \\ 3 a^{2} + 9 a b + 3 b^{2}, \\ \\left\\{ a^{2} : 3, \\ b^{2} : 3, \\ a b : 9\\right\\}\\right)$" - ], "text/plain": [ - "((3, 3, 4), 3*a**2 + 9*a*b + 3*b**2, {a**2: 3, b**2: 3, a*b: 9})" + "array([[0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 225],\n", + " [0, 0, 0, 400, 0],\n", + " [0, 0, 225, 0, 0]], dtype=object)" ] }, + "execution_count": 5, "metadata": {}, - "output_type": "display_data" - }, + "output_type": "execute_result" + } + ], + "source": [ + "def get_coeffs(p):\n", + " return tuple(p.coeff_monomial(m) for m in all_monomials)\n", + "\n", + "np.array(get_coeffs(binom_factor_sym(6,6,a,b,6))).reshape(LMAX,LMAX)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "len(all_monomials)=25\n" + "(0, 1, 2) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(0, a, b, domain='ZZ')\n", + "(0, 1, 4) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(0, a, b, domain='ZZ')\n", + "(0, 3, 3) (1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(1, a, b, domain='ZZ')\n", + "(1, 0, 4) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(0, a, b, domain='ZZ')\n", + "(2, 1, 3) (1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(1, a, b, domain='ZZ')\n", + "(3, 0, 4) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(0, a, b, domain='ZZ')\n", + "(3, 1, 1) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0) Poly(a**3 + 3*a**2*b, a, b, domain='ZZ')\n", + "(3, 4, 0) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0) Poly(a**3*b**4, a, b, domain='ZZ')\n", + "(4, 2, 0) (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0) Poly(a**4*b**2, a, b, domain='ZZ')\n" ] - }, - { - "data": { - "text/latex": [ - "$\\displaystyle \\left( 1, \\ a^{2} b, \\ a^{4}, \\ a^{2} b^{4}, \\ b^{4}, \\ a^{3} b^{2}, \\ a^{4} b^{2}, \\ a b^{4}, \\ a^{3} b, \\ a^{3} b^{3}, \\ a^{2}, \\ a b, \\ a^{4} b^{3}, \\ b^{2}, \\ a^{4} b, \\ a, \\ a^{3} b^{4}, \\ a^{4} b^{4}, \\ a^{2} b^{2}, \\ a b^{2}, \\ b, \\ b^{3}, \\ a^{3}, \\ a^{2} b^{3}, \\ a b^{3}\\right)$" - ], - "text/plain": [ - "(1, a**2*b, a**4, a**2*b**4, b**4, a**3*b**2, a**4*b**2, a*b**4, a**3*b, a**3*b**3, a**2, a*b, a**4*b**3, b**2, a**4*b, a, a**3*b**4, a**4*b**4, a**2*b**2, a*b**2, b, b**3, a**3, a**2*b**3, a*b**3)" - ] - }, - "metadata": {}, - "output_type": "display_data" } ], "source": [ - "all_monomials = set()\n", - "for i in range(LMAX):\n", - " for j in range(LMAX):\n", - " for s in range(LMAX):\n", - " bf = binom_factor(i, j, a, b, s)\n", - " if bf:\n", - " coefs = all_coeffs(bf)\n", - " if LMAX < 5 or i%2 and j%2:\n", - " display(S(((i, j, s), bf, coefs)))\n", - " all_monomials = all_monomials.union(set(coefs.keys()))\n", - "all_monomials = tuple(all_monomials)\n", - "n = len(all_monomials)\n", - "print(f\"{len(all_monomials)=}\")\n", - "display(S(all_monomials))\n", - "# monom_map = {k:i for i,k in enumerate(all_monomials)}\n", - "\n", - "weights = np.zeros((LMAX, LMAX, LMAX, n))\n", + "weights = np.zeros((LMAX, LMAX, LMAX, LMAX*LMAX))\n", "for i in range(LMAX):\n", " for j in range(LMAX):\n", " for s in range(LMAX):\n", - " bf = binom_factor(i, j, a, b, s)\n", - " coefs = all_coeffs(bf)\n", - " val = [coefs.get(monom, 0) for monom in all_monomials]\n", - " if LMAX < 5:\n", + " bf = binom_factor_sym(i, j, a, b, s)\n", + " val = get_coeffs(bf)\n", + " if np.random.rand()**LMAX > .7:\n", " print((i, j, s), val, bf)\n", " weights[i, j, s, :] = val" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[ 1. 2.33 5.4289 12.649336 29.472952 1.1\n", + " 2.563 5.97179 13.914269 32.420246 1.21 2.8193\n", + " 6.568969 15.3056965 35.662273 1.3310001 3.1012301 7.225866\n", + " 16.836267 39.2285 1.4641001 3.411353 7.948453 18.519894\n", + " 43.151352 ]\n", + "[ 1. 2.32999992 5.42889977 12.64933576 29.47295135 1.10000002\n", + " 2.56299996 5.97178984 13.91426963 32.42024719 1.21000004 2.81929994\n", + " 6.56896877 15.30569675 35.66227226 1.33100009 3.1012301 7.22586606\n", + " 16.83626699 39.2285008 1.46410013 3.41135318 7.94845284 18.51989409\n", + " 43.15135181]\n" + ] + } + ], + "source": [ + "import jax.numpy as jnp\n", + "def get_monomials(a,b):\n", + " a_pows = a ** jnp.arange(LMAX)\n", + " b_pows = b ** jnp.arange(LMAX)\n", + " ans = a_pows.reshape(LMAX,1) @ b_pows.reshape(1,LMAX)\n", + " return ans.reshape(LMAX*LMAX)\n", + "\n", + "f = lambda x: np.array(x, dtype=np.float32)\n", + "fa,fb = f(1.1),f(2.33)\n", + "got = get_monomials(fa,fb)\n", + "print(got)\n", + "\n", + "expect = lambdify((a,b), all_monomials, \"numpy\")(fa,fb).reshape(LMAX*LMAX)\n", + "print(expect)\n", + "\n", + "np.testing.assert_allclose(got,expect)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -578,8 +342,10 @@ "inds = np.nonzero(weights)\n", "print(f'nnz={len(inds[0])}')\n", "\n", + "import inspect\n", + "\n", "with np.printoptions(threshold=np.inf, formatter={'float':lambda x:f'{x:.10g}'}):\n", - " with open(\"../pyscf_ipu/experimental/binom_factor_table.py\", \"w\") as f:\n", + " with open(\"../pyscf_ipu/experimental/binom_factor_table.py\", \"w\") as file:\n", " print(\n", " f\"\"\"# Copyright (c) 2023 Graphcore Ltd. All rights reserved.\n", "# AUTOGENERATED from notebooks/binom_factor_table.ipynb\n", @@ -588,22 +354,23 @@ "# flake8: noqa\n", "# isort: skip_file\n", "\n", - "from numpy import array,zeros\n", + "import jax.numpy as jnp\n", + "import numpy as np\n", + "array = np.array\n", "\n", "LMAX = {LMAX}\n", - "def get_monomials(a,b):\n", - " return {all_monomials}\n", + "{inspect.getsource(get_monomials)}\n", "\n", "def build_binom_factor_table(sparse=False):\n", " inds,values = {repr((inds, weights[inds]))}\n", " if sparse:\n", " return inds,values\n", " else:\n", - " W = zeros((LMAX,LMAX,LMAX,LMAX*LMAX))\n", + " W = np.zeros((LMAX,LMAX,LMAX,LMAX*LMAX))\n", " W[inds] = values\n", " return W\n", "\"\"\",\n", - " file=f,\n", + " file=file,\n", " end=''\n", " )" ] @@ -619,36 +386,54 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "monomials=(1, 2.662000000000001, 1.4641000000000004, 28.344976000000013, 23.425600000000006, 6.442040000000003, 7.086244000000003, 25.76816000000001, 2.9282000000000012, 14.172488000000008, 1.2100000000000002, 2.4200000000000004, 15.58973680000001, 4.840000000000001, 3.221020000000001, 1.1, 31.17947360000002, 34.29742096000002, 5.856400000000002, 5.324000000000002, 2.2, 10.648000000000003, 1.3310000000000004, 12.884080000000006, 11.712800000000005)\n" + "monomials=DeviceArray([ 1. , 2.2 , 4.84 , 10.648001 , 23.425602 ,\n", + " 1.1 , 2.42 , 5.3240004, 11.712801 , 25.768162 ,\n", + " 1.21 , 2.6620002, 5.8564005, 12.884081 , 28.344978 ,\n", + " 1.3310001, 2.9282002, 6.4420404, 14.17249 , 31.179478 ,\n", + " 1.4641001, 3.2210202, 7.0862446, 15.589739 , 34.29743 ], dtype=float32)\n", + "(0, 3, 3) 1.0 [1.]\n", + "(0, 4, 2) 29.040000915527344 [29.04]\n", + "(1, 0, 2) 0 [0.]\n", + "(1, 0, 3) 0 [0.]\n", + "(2, 0, 1) 2.200000047683716 [2.2]\n", + "(2, 1, 4) 0 [0.]\n", + "(2, 3, 0) 12.88408124395375 [12.884081]\n", + "(2, 3, 2) 50.578002816677134 [50.578003]\n", + "(2, 4, 0) 28.344979351059116 [28.344978]\n", + "(3, 1, 2) 10.890000429153446 [10.89]\n", + "(3, 1, 3) 5.5000001192092896 [5.5]\n", + "(4, 2, 2) 60.02810437345515 [60.028107]\n", + "(4, 2, 4) 31.460001220703134 [31.460001]\n", + "(4, 3, 2) 164.27203597465098 [164.27203]\n", + "(4, 3, 4) 127.77600698661814 [127.776]\n" ] } ], "source": [ "from pyscf_ipu.experimental import binom_factor_table\n", "\n", - "aval, bval = 1.1, 2.2\n", + "aval, bval = f(1.1), f(2.2)\n", "\n", "monomials = binom_factor_table.get_monomials(aval, bval)\n", "print(f\"{monomials=}\")\n", "\n", - "W = binom_factor_table.build_binom_factor_table()\n", - "\n", - "table_ab = W @ monomials\n", + "W = jnp.array(binom_factor_table.build_binom_factor_table(), dtype=jnp.float32)\n", + "table_ab = W @ monomials.reshape(LMAX*LMAX,1)\n", "\n", "for i in range(LMAX):\n", " for j in range(LMAX):\n", " for s in range(LMAX):\n", " bf = binom_factor(i, j, aval, bval, s)\n", - " if LMAX < 5:\n", + " if np.random.rand() ** LMAX > 0.6:\n", " print((i, j, s), bf, table_ab[i, j, s])\n", - " np.testing.assert_allclose(bf, table_ab[i, j, s])" + " np.testing.assert_allclose(bf, table_ab[i, j, s], rtol=1e-6)" ] }, { diff --git a/pyscf_ipu/experimental/binom_factor_table.py b/pyscf_ipu/experimental/binom_factor_table.py index cf59ce9..53fa3a0 100644 --- a/pyscf_ipu/experimental/binom_factor_table.py +++ b/pyscf_ipu/experimental/binom_factor_table.py @@ -5,60 +5,292 @@ # flake8: noqa # isort: skip_file -from numpy import array,zeros +import jax.numpy as jnp +import numpy as np +array = np.array -LMAX = 5 +LMAX = 8 def get_monomials(a,b): - return (1, a**2*b, a**4, a**2*b**4, b**4, a**3*b**2, a**4*b**2, a*b**4, a**3*b, a**3*b**3, a**2, a*b, a**4*b**3, b**2, a**4*b, a, a**3*b**4, a**4*b**4, a**2*b**2, a*b**2, b, b**3, a**3, a**2*b**3, a*b**3) + a_pows = a ** jnp.arange(LMAX) + b_pows = b ** jnp.arange(LMAX) + ans = a_pows.reshape(LMAX,1) @ b_pows.reshape(1,LMAX) + return ans.reshape(LMAX*LMAX) + def build_binom_factor_table(sparse=False): - inds,values = ((array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, + inds,values = ((array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, - 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7]), array([0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 6, + 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 0, 0, 1, 1, 1, 1, 2, 2, + 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 0, 0, 0, 0, 1, 1, 1, 1, 1, + 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 0, 0, 0, 0, 0, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, - 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]), array([0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 0, 0, 1, 1, 1, 1, 2, - 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, - 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, - 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, - 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, - 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, - 4, 4, 4, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, - 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, - 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]), array([0, 0, 1, 0, 1, 2, 0, 1, 2, 3, 0, 1, 2, 3, 4, 0, 1, 0, 1, 1, 2, 0, - 1, 1, 2, 2, 3, 0, 1, 1, 2, 2, 3, 3, 4, 0, 1, 1, 2, 2, 3, 3, 4, 4, - 0, 1, 2, 0, 1, 1, 2, 2, 3, 0, 1, 1, 2, 2, 2, 3, 3, 4, 0, 1, 1, 2, - 2, 2, 3, 3, 3, 4, 4, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 0, 1, 2, - 3, 0, 1, 1, 2, 2, 3, 3, 4, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 0, 1, - 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, - 4, 4, 4, 0, 1, 2, 3, 4, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 1, 1, 2, 2, - 2, 3, 3, 3, 4, 4, 4, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 0, - 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4]), array([ 0, 20, 0, 13, 20, 0, 21, 13, 20, 0, 4, 21, 13, 20, 0, 15, 0, - 11, 15, 20, 0, 19, 11, 13, 15, 20, 0, 24, 19, 21, 11, 13, 15, 20, - 0, 7, 4, 24, 19, 21, 11, 13, 15, 20, 10, 15, 0, 1, 10, 11, 15, - 20, 0, 18, 1, 19, 10, 11, 13, 15, 20, 0, 23, 18, 24, 1, 19, 21, - 10, 11, 13, 15, 20, 3, 7, 23, 4, 18, 24, 1, 19, 21, 10, 11, 13, - 22, 10, 15, 0, 8, 1, 22, 10, 11, 15, 20, 0, 5, 8, 18, 1, 19, - 22, 10, 11, 13, 15, 20, 9, 5, 23, 8, 18, 24, 1, 19, 21, 22, 10, - 11, 13, 16, 3, 9, 5, 7, 23, 4, 8, 18, 24, 1, 19, 21, 22, 2, - 22, 10, 15, 0, 14, 2, 8, 1, 22, 10, 11, 15, 20, 6, 5, 14, 2, - 8, 18, 1, 19, 22, 10, 11, 13, 12, 6, 9, 5, 14, 23, 2, 8, 18, - 24, 1, 19, 21, 22, 17, 12, 16, 3, 6, 9, 5, 7, 14, 23, 2, 4, - 8, 18, 24])), array([1, 1, 1, 1, 2, 1, 1, 3, 3, 1, 1, 4, 6, 4, 1, 1, 1, 1, 1, 1, 1, 1, - 2, 1, 1, 2, 1, 1, 3, 1, 3, 3, 1, 3, 1, 1, 1, 4, 6, 4, 4, 6, 1, 4, - 1, 2, 1, 1, 1, 2, 2, 1, 1, 1, 2, 2, 1, 4, 1, 2, 2, 1, 1, 3, 2, 3, - 6, 1, 1, 6, 3, 2, 3, 1, 2, 4, 1, 6, 8, 4, 12, 4, 1, 8, 6, 1, 3, 3, - 1, 1, 3, 1, 3, 3, 3, 1, 1, 1, 2, 3, 6, 3, 1, 3, 6, 1, 3, 2, 1, 3, - 3, 3, 9, 3, 9, 9, 1, 1, 3, 9, 3, 1, 3, 4, 6, 3, 12, 1, 4, 18, 12, - 12, 18, 4, 1, 1, 4, 6, 4, 1, 1, 1, 4, 6, 4, 6, 4, 4, 1, 1, 4, 2, 1, - 8, 6, 12, 4, 4, 6, 8, 1, 1, 3, 4, 12, 3, 6, 1, 12, 18, 4, 18, 12, - 1, 4, 1, 4, 4, 6, 6, 16, 24, 4, 4, 24, 1, 1, 16, 36, 16])) + 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 0, + 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7]), array([0, 0, 1, 0, 1, 2, 0, 1, 2, 3, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 0, + 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 0, 1, 1, 2, 0, 1, + 1, 2, 2, 3, 0, 1, 1, 2, 2, 3, 3, 4, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, + 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, + 5, 6, 6, 7, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 0, 1, 2, + 0, 1, 1, 2, 2, 3, 0, 1, 1, 2, 2, 2, 3, 3, 4, 0, 1, 1, 2, 2, 2, 3, + 3, 3, 4, 4, 5, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 6, 0, 1, + 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 7, 0, 1, 1, 2, 2, 2, + 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 0, 1, 1, 2, 2, 2, 3, 3, + 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 0, 1, 2, 3, 0, 1, 1, 2, 2, + 3, 3, 4, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 5, 0, 1, 1, 2, 2, 2, 3, + 3, 3, 3, 4, 4, 4, 5, 5, 6, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, + 4, 5, 5, 5, 6, 6, 7, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, + 5, 5, 5, 6, 6, 6, 7, 7, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, + 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, + 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 0, 1, 2, 3, 4, 0, 1, + 1, 2, 2, 3, 3, 4, 4, 5, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, + 6, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 6, 6, 7, 0, + 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 7, + 7, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 6, + 6, 6, 6, 7, 7, 7, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, + 5, 5, 5, 5, 6, 6, 6, 6, 6, 7, 7, 7, 7, 0, 1, 1, 2, 2, 2, 3, 3, 3, + 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 0, + 1, 2, 3, 4, 5, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 0, 1, 1, 2, 2, + 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 7, 0, 1, 1, 2, 2, 2, 3, 3, 3, + 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 7, 7, 0, 1, 1, 2, 2, 2, 3, 3, + 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 0, 1, 1, + 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, + 6, 7, 7, 7, 7, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, + 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 0, 1, 1, 2, 2, 2, 3, + 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 7, 7, + 7, 7, 7, 7, 0, 1, 2, 3, 4, 5, 6, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, + 6, 6, 7, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, + 7, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, + 6, 7, 7, 7, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, + 5, 5, 6, 6, 6, 6, 6, 7, 7, 7, 7, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, + 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 0, + 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, + 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, + 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, + 7, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, + 6, 7, 7, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, + 7, 7, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, + 6, 6, 7, 7, 7, 7, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, + 5, 5, 5, 5, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 0, 1, 1, 2, 2, 2, 3, 3, + 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 7, 7, 7, + 7, 7, 7, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, + 5, 5, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 0, 1, 1, 2, 2, 2, + 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, + 7, 7, 7, 7, 7, 7, 7, 7]), array([ 0, 1, 0, 2, 1, 0, 3, 2, 1, 0, 4, 3, 2, 1, 0, 5, 4, + 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 7, 6, 5, 4, 3, 2, + 1, 0, 8, 0, 9, 1, 8, 0, 10, 2, 9, 1, 8, 0, 11, 3, 10, + 2, 9, 1, 8, 0, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0, 13, 5, + 12, 4, 11, 3, 10, 2, 9, 1, 8, 0, 14, 6, 13, 5, 12, 4, 11, + 3, 10, 2, 9, 1, 8, 0, 15, 7, 14, 6, 13, 5, 12, 4, 11, 3, + 10, 2, 9, 1, 8, 16, 8, 0, 17, 9, 16, 1, 8, 0, 18, 10, 17, + 2, 9, 16, 1, 8, 0, 19, 11, 18, 3, 10, 17, 2, 9, 16, 1, 8, + 0, 20, 12, 19, 4, 11, 18, 3, 10, 17, 2, 9, 16, 1, 8, 0, 21, + 13, 20, 5, 12, 19, 4, 11, 18, 3, 10, 17, 2, 9, 16, 1, 8, 0, + 22, 14, 21, 6, 13, 20, 5, 12, 19, 4, 11, 18, 3, 10, 17, 2, 9, + 16, 1, 8, 23, 15, 22, 7, 14, 21, 6, 13, 20, 5, 12, 19, 4, 11, + 18, 3, 10, 17, 2, 9, 16, 24, 16, 8, 0, 25, 17, 24, 9, 16, 1, + 8, 0, 26, 18, 25, 10, 17, 24, 2, 9, 16, 1, 8, 0, 27, 19, 26, + 11, 18, 25, 3, 10, 17, 24, 2, 9, 16, 1, 8, 0, 28, 20, 27, 12, + 19, 26, 4, 11, 18, 25, 3, 10, 17, 24, 2, 9, 16, 1, 8, 0, 29, + 21, 28, 13, 20, 27, 5, 12, 19, 26, 4, 11, 18, 25, 3, 10, 17, 24, + 2, 9, 16, 1, 8, 30, 22, 29, 14, 21, 28, 6, 13, 20, 27, 5, 12, + 19, 26, 4, 11, 18, 25, 3, 10, 17, 24, 2, 9, 16, 31, 23, 30, 15, + 22, 29, 7, 14, 21, 28, 6, 13, 20, 27, 5, 12, 19, 26, 4, 11, 18, + 25, 3, 10, 17, 24, 32, 24, 16, 8, 0, 33, 25, 32, 17, 24, 9, 16, + 1, 8, 0, 34, 26, 33, 18, 25, 32, 10, 17, 24, 2, 9, 16, 1, 8, + 0, 35, 27, 34, 19, 26, 33, 11, 18, 25, 32, 3, 10, 17, 24, 2, 9, + 16, 1, 8, 0, 36, 28, 35, 20, 27, 34, 12, 19, 26, 33, 4, 11, 18, + 25, 32, 3, 10, 17, 24, 2, 9, 16, 1, 8, 37, 29, 36, 21, 28, 35, + 13, 20, 27, 34, 5, 12, 19, 26, 33, 4, 11, 18, 25, 32, 3, 10, 17, + 24, 2, 9, 16, 38, 30, 37, 22, 29, 36, 14, 21, 28, 35, 6, 13, 20, + 27, 34, 5, 12, 19, 26, 33, 4, 11, 18, 25, 32, 3, 10, 17, 24, 39, + 31, 38, 23, 30, 37, 15, 22, 29, 36, 7, 14, 21, 28, 35, 6, 13, 20, + 27, 34, 5, 12, 19, 26, 33, 4, 11, 18, 25, 32, 40, 32, 24, 16, 8, + 0, 41, 33, 40, 25, 32, 17, 24, 9, 16, 1, 8, 0, 42, 34, 41, 26, + 33, 40, 18, 25, 32, 10, 17, 24, 2, 9, 16, 1, 8, 0, 43, 35, 42, + 27, 34, 41, 19, 26, 33, 40, 11, 18, 25, 32, 3, 10, 17, 24, 2, 9, + 16, 1, 8, 44, 36, 43, 28, 35, 42, 20, 27, 34, 41, 12, 19, 26, 33, + 40, 4, 11, 18, 25, 32, 3, 10, 17, 24, 2, 9, 16, 45, 37, 44, 29, + 36, 43, 21, 28, 35, 42, 13, 20, 27, 34, 41, 5, 12, 19, 26, 33, 40, + 4, 11, 18, 25, 32, 3, 10, 17, 24, 46, 38, 45, 30, 37, 44, 22, 29, + 36, 43, 14, 21, 28, 35, 42, 6, 13, 20, 27, 34, 41, 5, 12, 19, 26, + 33, 40, 4, 11, 18, 25, 32, 47, 39, 46, 31, 38, 45, 23, 30, 37, 44, + 15, 22, 29, 36, 43, 7, 14, 21, 28, 35, 42, 6, 13, 20, 27, 34, 41, + 5, 12, 19, 26, 33, 40, 48, 40, 32, 24, 16, 8, 0, 49, 41, 48, 33, + 40, 25, 32, 17, 24, 9, 16, 1, 8, 0, 50, 42, 49, 34, 41, 48, 26, + 33, 40, 18, 25, 32, 10, 17, 24, 2, 9, 16, 1, 8, 51, 43, 50, 35, + 42, 49, 27, 34, 41, 48, 19, 26, 33, 40, 11, 18, 25, 32, 3, 10, 17, + 24, 2, 9, 16, 52, 44, 51, 36, 43, 50, 28, 35, 42, 49, 20, 27, 34, + 41, 48, 12, 19, 26, 33, 40, 4, 11, 18, 25, 32, 3, 10, 17, 24, 53, + 45, 52, 37, 44, 51, 29, 36, 43, 50, 21, 28, 35, 42, 49, 13, 20, 27, + 34, 41, 48, 5, 12, 19, 26, 33, 40, 4, 11, 18, 25, 32, 54, 46, 53, + 38, 45, 52, 30, 37, 44, 51, 22, 29, 36, 43, 50, 14, 21, 28, 35, 42, + 49, 6, 13, 20, 27, 34, 41, 48, 5, 12, 19, 26, 33, 40, 55, 47, 54, + 39, 46, 53, 31, 38, 45, 52, 23, 30, 37, 44, 51, 15, 22, 29, 36, 43, + 50, 7, 14, 21, 28, 35, 42, 49, 6, 13, 20, 27, 34, 41, 48, 56, 48, + 40, 32, 24, 16, 8, 0, 57, 49, 56, 41, 48, 33, 40, 25, 32, 17, 24, + 9, 16, 1, 8, 58, 50, 57, 42, 49, 56, 34, 41, 48, 26, 33, 40, 18, + 25, 32, 10, 17, 24, 2, 9, 16, 59, 51, 58, 43, 50, 57, 35, 42, 49, + 56, 27, 34, 41, 48, 19, 26, 33, 40, 11, 18, 25, 32, 3, 10, 17, 24, + 60, 52, 59, 44, 51, 58, 36, 43, 50, 57, 28, 35, 42, 49, 56, 20, 27, + 34, 41, 48, 12, 19, 26, 33, 40, 4, 11, 18, 25, 32, 61, 53, 60, 45, + 52, 59, 37, 44, 51, 58, 29, 36, 43, 50, 57, 21, 28, 35, 42, 49, 56, + 13, 20, 27, 34, 41, 48, 5, 12, 19, 26, 33, 40, 62, 54, 61, 46, 53, + 60, 38, 45, 52, 59, 30, 37, 44, 51, 58, 22, 29, 36, 43, 50, 57, 14, + 21, 28, 35, 42, 49, 56, 6, 13, 20, 27, 34, 41, 48, 63, 55, 62, 47, + 54, 61, 39, 46, 53, 60, 31, 38, 45, 52, 59, 23, 30, 37, 44, 51, 58, + 15, 22, 29, 36, 43, 50, 57, 7, 14, 21, 28, 35, 42, 49, 56])), array([1, 1, 1, 1, 2, 1, 1, 3, 3, 1, 1, 4, 6, 4, 1, 1, 5, 10, 10, 5, 1, 1, + 6, 15, 20, 15, 6, 1, 1, 7, 21, 35, 35, 21, 7, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 2, 2, 1, 1, 1, 1, 3, 3, 3, 3, 1, 1, 1, 1, 4, 4, 6, 6, 4, 4, + 1, 1, 1, 1, 5, 5, 10, 10, 10, 10, 5, 5, 1, 1, 1, 1, 6, 6, 15, 15, + 20, 20, 15, 15, 6, 6, 1, 1, 1, 1, 7, 7, 21, 21, 35, 35, 35, 35, 21, + 21, 7, 7, 1, 1, 2, 1, 1, 2, 1, 1, 2, 1, 1, 2, 2, 1, 4, 1, 2, 2, 1, + 1, 2, 3, 1, 6, 3, 3, 6, 1, 3, 2, 1, 1, 2, 4, 1, 8, 6, 4, 12, 4, 6, + 8, 1, 4, 2, 1, 1, 2, 5, 1, 10, 10, 5, 20, 10, 10, 20, 5, 10, 10, 1, + 5, 2, 1, 1, 2, 6, 1, 12, 15, 6, 30, 20, 15, 40, 15, 20, 30, 6, 15, + 12, 1, 6, 2, 1, 2, 7, 1, 14, 21, 7, 42, 35, 21, 70, 35, 35, 70, 21, + 35, 42, 7, 21, 14, 1, 1, 3, 3, 1, 1, 3, 1, 3, 3, 1, 3, 1, 1, 3, 2, + 3, 6, 1, 1, 6, 3, 2, 3, 1, 1, 3, 3, 3, 9, 3, 1, 9, 9, 1, 3, 9, 3, + 3, 3, 1, 1, 3, 4, 3, 12, 6, 1, 12, 18, 4, 4, 18, 12, 1, 6, 12, 3, + 4, 3, 1, 1, 3, 5, 3, 15, 10, 1, 15, 30, 10, 5, 30, 30, 5, 10, 30, + 15, 1, 10, 15, 3, 5, 3, 1, 3, 6, 3, 18, 15, 1, 18, 45, 20, 6, 45, + 60, 15, 15, 60, 45, 6, 20, 45, 18, 1, 15, 18, 3, 1, 3, 7, 3, 21, + 21, 1, 21, 63, 35, 7, 63, 105, 35, 21, 105, 105, 21, 35, 105, 63, + 7, 35, 63, 21, 1, 1, 4, 6, 4, 1, 1, 4, 1, 6, 4, 4, 6, 1, 4, 1, 1, + 4, 2, 6, 8, 1, 4, 12, 4, 1, 8, 6, 2, 4, 1, 1, 4, 3, 6, 12, 3, 4, + 18, 12, 1, 1, 12, 18, 4, 3, 12, 6, 3, 4, 1, 1, 4, 4, 6, 16, 6, 4, + 24, 24, 4, 1, 16, 36, 16, 1, 4, 24, 24, 4, 6, 16, 6, 4, 4, 1, 4, 5, + 6, 20, 10, 4, 30, 40, 10, 1, 20, 60, 40, 5, 5, 40, 60, 20, 1, 10, + 40, 30, 4, 10, 20, 6, 1, 4, 6, 6, 24, 15, 4, 36, 60, 20, 1, 24, 90, + 80, 15, 6, 60, 120, 60, 6, 15, 80, 90, 24, 1, 20, 60, 36, 4, 1, 4, + 7, 6, 28, 21, 4, 42, 84, 35, 1, 28, 126, 140, 35, 7, 84, 210, 140, + 21, 21, 140, 210, 84, 7, 35, 140, 126, 28, 1, 1, 5, 10, 10, 5, 1, + 1, 5, 1, 10, 5, 10, 10, 5, 10, 1, 5, 1, 1, 5, 2, 10, 10, 1, 10, 20, + 5, 5, 20, 10, 1, 10, 10, 2, 5, 1, 1, 5, 3, 10, 15, 3, 10, 30, 15, + 1, 5, 30, 30, 5, 1, 15, 30, 10, 3, 15, 10, 3, 5, 1, 5, 4, 10, 20, + 6, 10, 40, 30, 4, 5, 40, 60, 20, 1, 1, 20, 60, 40, 5, 4, 30, 40, + 10, 6, 20, 10, 1, 5, 5, 10, 25, 10, 10, 50, 50, 10, 5, 50, 100, 50, + 5, 1, 25, 100, 100, 25, 1, 5, 50, 100, 50, 5, 10, 50, 50, 10, 1, 5, + 6, 10, 30, 15, 10, 60, 75, 20, 5, 60, 150, 100, 15, 1, 30, 150, + 200, 75, 6, 6, 75, 200, 150, 30, 1, 15, 100, 150, 60, 5, 1, 5, 7, + 10, 35, 21, 10, 70, 105, 35, 5, 70, 210, 175, 35, 1, 35, 210, 350, + 175, 21, 7, 105, 350, 350, 105, 7, 21, 175, 350, 210, 35, 1, 1, 6, + 15, 20, 15, 6, 1, 1, 6, 1, 15, 6, 20, 15, 15, 20, 6, 15, 1, 6, 1, + 1, 6, 2, 15, 12, 1, 20, 30, 6, 15, 40, 15, 6, 30, 20, 1, 12, 15, 2, + 6, 1, 6, 3, 15, 18, 3, 20, 45, 18, 1, 15, 60, 45, 6, 6, 45, 60, 15, + 1, 18, 45, 20, 3, 18, 15, 1, 6, 4, 15, 24, 6, 20, 60, 36, 4, 15, + 80, 90, 24, 1, 6, 60, 120, 60, 6, 1, 24, 90, 80, 15, 4, 36, 60, 20, + 1, 6, 5, 15, 30, 10, 20, 75, 60, 10, 15, 100, 150, 60, 5, 6, 75, + 200, 150, 30, 1, 1, 30, 150, 200, 75, 6, 5, 60, 150, 100, 15, 1, 6, + 6, 15, 36, 15, 20, 90, 90, 20, 15, 120, 225, 120, 15, 6, 90, 300, + 300, 90, 6, 1, 36, 225, 400, 225, 36, 1, 6, 90, 300, 300, 90, 6, 1, + 6, 7, 15, 42, 21, 20, 105, 126, 35, 15, 140, 315, 210, 35, 6, 105, + 420, 525, 210, 21, 1, 42, 315, 700, 525, 126, 7, 7, 126, 525, 700, + 315, 42, 1, 1, 7, 21, 35, 35, 21, 7, 1, 1, 7, 1, 21, 7, 35, 21, 35, + 35, 21, 35, 7, 21, 1, 7, 1, 7, 2, 21, 14, 1, 35, 42, 7, 35, 70, 21, + 21, 70, 35, 7, 42, 35, 1, 14, 21, 1, 7, 3, 21, 21, 3, 35, 63, 21, + 1, 35, 105, 63, 7, 21, 105, 105, 21, 7, 63, 105, 35, 1, 21, 63, 35, + 1, 7, 4, 21, 28, 6, 35, 84, 42, 4, 35, 140, 126, 28, 1, 21, 140, + 210, 84, 7, 7, 84, 210, 140, 21, 1, 28, 126, 140, 35, 1, 7, 5, 21, + 35, 10, 35, 105, 70, 10, 35, 175, 210, 70, 5, 21, 175, 350, 210, + 35, 1, 7, 105, 350, 350, 105, 7, 1, 35, 210, 350, 175, 21, 1, 7, 6, + 21, 42, 15, 35, 126, 105, 20, 35, 210, 315, 140, 15, 21, 210, 525, + 420, 105, 6, 7, 126, 525, 700, 315, 42, 1, 1, 42, 315, 700, 525, + 126, 7, 1, 7, 7, 21, 49, 21, 35, 147, 147, 35, 35, 245, 441, 245, + 35, 21, 245, 735, 735, 245, 21, 7, 147, 735, 1225, 735, 147, 7, 1, + 49, 441, 1225, 1225, 441, 49, 1])) if sparse: return inds,values else: - W = zeros((LMAX,LMAX,LMAX,LMAX*LMAX)) + W = np.zeros((LMAX,LMAX,LMAX,LMAX*LMAX)) W[inds] = values return W diff --git a/pyscf_ipu/experimental/integrals.py b/pyscf_ipu/experimental/integrals.py index 26ed073..74bb38e 100644 --- a/pyscf_ipu/experimental/integrals.py +++ b/pyscf_ipu/experimental/integrals.py @@ -12,7 +12,13 @@ from .basis import Basis from .orbital import batch_orbitals from .primitive import Primitive, product -from .special import binom, binom_factor, factorial, factorial2, gammanu +from .special import ( + binom, + binom_factor__via_segment_sum, + factorial, + factorial2, + gammanu, +) from .types import Float3, FloatN, FloatNx3, FloatNxN from .units import LMAX @@ -112,7 +118,9 @@ def build_gindex(): return i, r, u -def _nuclear_primitives(a: Primitive, b: Primitive, c: Float3): +def _nuclear_primitives( + a: Primitive, b: Primitive, c: Float3, binom_factor=binom_factor__via_segment_sum +): p = product(a, b) pa = p.center - a.center pb = p.center - b.center @@ -155,7 +163,7 @@ def g_term(l1, l2, pa, pb, cp): overlap_primitives = jit(_overlap_primitives) kinetic_primitives = jit(_kinetic_primitives) -nuclear_primitives = jit(_nuclear_primitives) +nuclear_primitives = jit(_nuclear_primitives, static_argnames="binom_factor") vmap_overlap_primitives = jit(vmap(_overlap_primitives)) vmap_kinetic_primitives = jit(vmap(_kinetic_primitives)) @@ -183,7 +191,13 @@ def build_cindex(): return i1, i2, r1, r2, u -def _eri_primitives(a: Primitive, b: Primitive, c: Primitive, d: Primitive) -> float: +def _eri_primitives( + a: Primitive, + b: Primitive, + c: Primitive, + d: Primitive, + binom_factor=binom_factor__via_segment_sum, +) -> float: p = product(a, b) q = product(c, d) pa = p.center - a.center @@ -244,7 +258,7 @@ def c_term(la, lb, lc, ld, pa, pb, qc, qd, qp): ) -eri_primitives = jit(_eri_primitives) +eri_primitives = jit(_eri_primitives, static_argnames="binom_factor") vmap_eri_primitives = jit(vmap(_eri_primitives)) @@ -260,7 +274,7 @@ def gen_ijkl(n: int): yield idx, jdx, kdx, ldx -def eri_basis_sparse(b: Basis): +def eri_basis_sparse(b: Basis, binom_factor=binom_factor__via_segment_sum): indices = [] batch = [] offset = np.cumsum([o.num_primitives for o in b.orbitals]) @@ -282,8 +296,8 @@ def eri_basis_sparse(b: Basis): return segment_sum(eris, batch, num_segments=count + 1) -def eri_basis(b: Basis): - unique_eris = eri_basis_sparse(b) +def eri_basis(b: Basis, binom_factor=binom_factor__via_segment_sum): + unique_eris = eri_basis_sparse(b, binom_factor) ii, jj, kk, ll = jnp.array(list(gen_ijkl(b.num_orbitals)), dtype=jnp.int32).T # Apply 8x permutation symmetry to build dense ERI from sparse ERI. diff --git a/test/test_integrals.py b/test/test_integrals.py index 00b359a..ed54859 100644 --- a/test/test_integrals.py +++ b/test/test_integrals.py @@ -4,6 +4,7 @@ import pytest from numpy.testing import assert_allclose +import pyscf_ipu.experimental as pyscf_experimental from pyscf_ipu.experimental.basis import basisset from pyscf_ipu.experimental.integrals import ( eri_basis, @@ -78,11 +79,49 @@ def test_water_kinetic(basis_name): assert_allclose(actual, expect, atol=1e-4) -def test_nuclear(): +def check_recompile(recompile, function): + # Force recompile + if recompile == "recompile": + # TBH, this is a bit of a red herring - it will force recompilation, + # but the whole switch is only really useful if the False case + # runs after the true case in the same process + # i.e. timing from + # pytest -k test_nuclear[lookup-recompile] --durations=5 + # will be the same as + # pytest -k test_nuclear[lookup-cached] --durations=5 + # While + # pytest -k test_nuclear[lookup- --durations=5 + # will show both times, and cached will be lower + function._clear_cache() + + +@pytest.mark.parametrize("recompile", ["recompile", "cached"]) +@pytest.mark.parametrize("binom_factor_str", ["segment_sum", "lookup"]) +def test_nuclear(binom_factor_str, recompile): # PyQuante test case for nuclear attraction integral p = Primitive() c = jnp.zeros(3) - assert_allclose(nuclear_primitives(p, p, c), -1.595769, atol=1e-5) + + # Choose the implementation of binom_factor + if binom_factor_str == "segment_sum": + binom_factor = pyscf_experimental.special.binom_factor__via_segment_sum + elif binom_factor_str == "lookup": + binom_factor = pyscf_experimental.special.binom_factor__via_lookup + else: + assert False + + check_recompile(recompile, nuclear_primitives) + assert_allclose(nuclear_primitives(p, p, c, binom_factor), -1.595769, atol=1e-5) + + # if recompile == 'recompile': + # from jaxutils.jaxpr_to_expr import show_jaxpr + # show_jaxpr( + # nuclear_primitives, + # (p, p, c, binom_factor), + # file=f"tmp/nuclear_primitives_jaxpr__binom_factor__via_{binom_factor_str}.py", + # optimize=False, + # static_argnums=3, + # ) # Reproduce the nuclear attraction matrix for H2 using STO-3G basis set # See equation 3.231 and 3.232 of Szabo and Ostlund @@ -141,13 +180,24 @@ def is_mem_limited(): return total_mem_gib < 10 -@pytest.mark.parametrize("sparse", [True, False]) +@pytest.mark.parametrize("recompile", ["recompile", "cached"]) +@pytest.mark.parametrize("binom_factor_str", ["segment_sum", "lookup"]) +@pytest.mark.parametrize("sparsity", ["sparse", "dense"]) @pytest.mark.skipif(is_mem_limited(), reason="Not enough host memory!") -def test_water_eri(sparse): +def test_water_eri(recompile, binom_factor_str, sparsity): + sparse = sparsity == "sparse" + check_recompile(recompile, eri_primitives) + binom_factor = eval( + "pyscf_experimental.special.binom_factor__via_" + binom_factor_str + ) + basis_name = "sto-3g" h2o = molecule("water") basis = basisset(h2o, basis_name) - actual = eri_basis_sparse(basis) if sparse else eri_basis(basis) + if sparse: + actual = eri_basis_sparse(basis, binom_factor) + else: + actual = eri_basis(basis, binom_factor) aosym = "s8" if sparse else "s1" expect = to_pyscf(h2o, basis_name=basis_name).intor("int2e_cart", aosym=aosym) assert_allclose(actual, expect, atol=1e-4)