diff --git a/notebooks/binom_factor_table.ipynb b/notebooks/binom_factor_table.ipynb new file mode 100644 index 0000000..ee2258e --- /dev/null +++ b/notebooks/binom_factor_table.ipynb @@ -0,0 +1,532 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from jax.scipy.special import betaln, gammainc, gammaln\n", + "import numpy as np\n", + "from sympy import *" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Table lookup for \"binom factor\"\n", + "\n", + "The function called \"binom factor\" is one of those with a variable-sized loop, \n", + "which is tricky to vectorize. Its definition is\n", + "$$\n", + "B(i,j,a,b,s) = \\sum_{n=s-i}^j \\binom{i}{s-n} \\binom{j}{n} a^{i-(s-n)} b^{j-n}\n", + "$$\n", + "\n", + "Its arguments are three integers $(i,j,s)$ and two floats. It can't be a lookup table because of the floats, and it can't be a simple table of functions because JAX would not be impressed.\n", + "\n", + "But... For a given $i,j,s$ we know it will be a polynomial in $(a,b)$:\n", + "$$\n", + "B(i,j,a,b,s) = \\sum_{p=1}^{max} \\sum_{q=1}^{max} w^{i,j,s}_{p,q} a^p b^q\n", + "$$\n", + "\n", + "This notebook computes those weights. If max=5, that is a 5x5x5x25 array." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "kilobytes=12.20703125\n" + ] + } + ], + "source": [ + "kilobytes = 5*5*5*25 * 4 / 1024\n", + "print(f'{kilobytes=}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In fact, they are all small integers and it is sparse, so it could be less, but certainly not a huge memory burden." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(0, 0, 0) 1\n", + "(0, 0, 1) 0\n", + "(0, 0, 2) 0\n", + "(0, 0, 3) 0\n", + "(0, 1, 0) b\n", + "(0, 1, 1) 1\n", + "(0, 1, 2) 0\n", + "(0, 1, 3) 0\n", + "(0, 2, 0) b**2\n", + "(0, 2, 1) 2*b\n", + "(0, 2, 2) 1\n", + "(0, 2, 3) 0\n", + "(0, 3, 0) b**3\n", + "(0, 3, 1) 3*b**2\n", + "(0, 3, 2) 3*b\n", + "(0, 3, 3) 1\n", + "(1, 0, 0) a\n", + "(1, 0, 1) 1\n", + "(1, 0, 2) 0\n", + "(1, 0, 3) 0\n", + "(1, 1, 0) a*b\n", + "(1, 1, 1) a + b\n", + "(1, 1, 2) 1\n", + "(1, 1, 3) 0\n", + "(1, 2, 0) a*b**2\n", + "(1, 2, 1) 2*a*b + b**2\n", + "(1, 2, 2) a + 2*b\n", + "(1, 2, 3) 1\n", + "(1, 3, 0) a*b**3\n", + "(1, 3, 1) 3*a*b**2 + b**3\n", + "(1, 3, 2) 3*a*b + 3*b**2\n", + "(1, 3, 3) a + 3*b\n", + "(2, 0, 0) a**2\n", + "(2, 0, 1) 2*a\n", + "(2, 0, 2) 1\n", + "(2, 0, 3) 0\n", + "(2, 1, 0) a**2*b\n", + "(2, 1, 1) a**2 + 2*a*b\n", + "(2, 1, 2) 2*a + b\n", + "(2, 1, 3) 1\n", + "(2, 2, 0) a**2*b**2\n", + "(2, 2, 1) 2*a**2*b + 2*a*b**2\n", + "(2, 2, 2) a**2 + 4*a*b + b**2\n", + "(2, 2, 3) 2*a + 2*b\n", + "(2, 3, 0) a**2*b**3\n", + "(2, 3, 1) 3*a**2*b**2 + 2*a*b**3\n", + "(2, 3, 2) 3*a**2*b + 6*a*b**2 + b**3\n", + "(2, 3, 3) a**2 + 6*a*b + 3*b**2\n", + "(3, 0, 0) a**3\n", + "(3, 0, 1) 3*a**2\n", + "(3, 0, 2) 3*a\n", + "(3, 0, 3) 1\n", + "(3, 1, 0) a**3*b\n", + "(3, 1, 1) a**3 + 3*a**2*b\n", + "(3, 1, 2) 3*a**2 + 3*a*b\n", + "(3, 1, 3) 3*a + b\n", + "(3, 2, 0) a**3*b**2\n", + "(3, 2, 1) 2*a**3*b + 3*a**2*b**2\n", + "(3, 2, 2) a**3 + 6*a**2*b + 3*a*b**2\n", + "(3, 2, 3) 3*a**2 + 6*a*b + b**2\n", + "(3, 3, 0) a**3*b**3\n", + "(3, 3, 1) 3*a**3*b**2 + 3*a**2*b**3\n", + "(3, 3, 2) 3*a**3*b + 9*a**2*b**2 + 3*a*b**3\n", + "(3, 3, 3) a**3 + 9*a**2*b + 9*a*b**2 + b**3\n" + ] + }, + { + "data": { + "text/latex": [ + "$\\displaystyle a^{3} + 9 a^{2} b + 9 a b^{2} + b^{3}$" + ], + "text/plain": [ + "a**3 + 9*a**2*b + 9*a*b**2 + b**3" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "def binom(x, y):\n", + " approx = 1.0 / ((x + 1) * np.exp(betaln(x - y + 1, y + 1)))\n", + " return int(np.rint(approx))\n", + "\n", + "def binom_factor(i: int, j: int, a: float, b: float, s: int):\n", + " out = 0\n", + " for t in range(max(s - i, 0), j + 1):\n", + " assert ((s - i) <= t) & (t <= j)\n", + " val = binom(i, s - t) * binom(j, t) * a ** (i - (s - t)) * b ** (j - t)\n", + " out += val\n", + " return out\n", + "\n", + "a,b = symbols(\"a b\", real=True)\n", + "LMAX = 4\n", + "for i in range(LMAX):\n", + " for j in range(LMAX):\n", + " for s in range(LMAX):\n", + " bf = binom_factor(i,j,a,b,s)\n", + " print((i,j,s),bf)\n", + "\n", + "bf" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{1: 0,\n", + " b**2: 0,\n", + " b**3: 1,\n", + " a**2: 0,\n", + " a**3: 1,\n", + " a**2*b**2: 0,\n", + " a**3*b**2: 0,\n", + " a*b**2: 9,\n", + " a**2*b**3: 0,\n", + " a**3*b**3: 0,\n", + " a*b**3: 0,\n", + " a**2*b: 9,\n", + " a**3*b: 0,\n", + " a*b: 0,\n", + " b: 0,\n", + " a: 0}" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# https://stackoverflow.com/questions/74731353/is-there-any-all-coeffs-for-multivariable-polynomials-in-sympy\n", + "def all_coeffs(expr,*free):\n", + " x = IndexedBase('x')\n", + " expr = expr.expand()\n", + " free = list(free) or list(expr.free_symbols)\n", + " pows = [p.as_base_exp() for p in expr.atoms(Pow,Symbol)]\n", + " P = {}\n", + " for p,e in pows:\n", + " if p not in free:\n", + " continue\n", + " elif p not in P:\n", + " P[p]=e\n", + " elif e>P[p]:\n", + " P[p] = e\n", + " reps = dict([(f, x[i]) for i,f in enumerate(free)])\n", + " xzero = dict([(v,0) for k,v in reps.items()])\n", + " e = expr.xreplace(reps); reps = {v:k for k,v in reps.items()}\n", + " return dict([(m.xreplace(reps), e.coeff(m).xreplace(xzero) if m!=1 else e.xreplace(xzero)) for m in monoms(*[P[f] for f in free])])\n", + "\n", + "def monoms(*o):\n", + " x = IndexedBase('x')\n", + " f = []\n", + " for i,o in enumerate(o):\n", + " f.append(Poly([1]*(o+1),x[i]).as_expr())\n", + " return Mul(*f).expand().args\n", + "\n", + "all_coeffs(S(bf))" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "25" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/latex": [ + "$\\displaystyle \\left( 1, \\ a b^{2}, \\ b^{4}, \\ a b, \\ b, \\ a^{3} b^{2}, \\ a^{4} b^{3}, \\ a^{4}, \\ a^{4} b^{4}, \\ a^{2} b^{3}, \\ a^{2} b, \\ b^{3}, \\ a^{2} b^{4}, \\ a^{3}, \\ a^{3} b, \\ a, \\ a^{3} b^{3}, \\ a b^{3}, \\ a^{3} b^{4}, \\ a b^{4}, \\ a^{4} b^{2}, \\ b^{2}, \\ a^{2}, \\ a^{2} b^{2}, \\ a^{4} b\\right)$" + ], + "text/plain": [ + "(1, a*b**2, b**4, a*b, b, a**3*b**2, a**4*b**3, a**4, a**4*b**4, a**2*b**3, a**2*b, b**3, a**2*b**4, a**3, a**3*b, a, a**3*b**3, a*b**3, a**3*b**4, a*b**4, a**4*b**2, b**2, a**2, a**2*b**2, a**4*b)" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(0, 0, 0) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 1\n", + "(0, 0, 1) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 0\n", + "(0, 0, 2) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 0\n", + "(0, 0, 3) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 0\n", + "(0, 0, 4) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 0\n", + "(0, 1, 0) [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] b\n", + "(0, 1, 1) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 1\n", + "(0, 1, 2) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 0\n", + "(0, 1, 3) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 0\n", + "(0, 1, 4) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 0\n", + "(0, 2, 0) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0] b**2\n", + "(0, 2, 1) [0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 2*b\n", + "(0, 2, 2) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 1\n", + "(0, 2, 3) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 0\n", + "(0, 2, 4) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 0\n", + "(0, 3, 0) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] b**3\n", + "(0, 3, 1) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0] 3*b**2\n", + "(0, 3, 2) [0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 3*b\n", + "(0, 3, 3) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 1\n", + "(0, 3, 4) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 0\n", + "(0, 4, 0) [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] b**4\n", + "(0, 4, 1) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 4*b**3\n", + "(0, 4, 2) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0] 6*b**2\n", + "(0, 4, 3) [0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 4*b\n", + "(0, 4, 4) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 1\n", + "(1, 0, 0) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0] a\n", + "(1, 0, 1) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 1\n", + "(1, 0, 2) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 0\n", + "(1, 0, 3) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 0\n", + "(1, 0, 4) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 0\n", + "(1, 1, 0) [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] a*b\n", + "(1, 1, 1) [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0] a + b\n", + "(1, 1, 2) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 1\n", + "(1, 1, 3) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 0\n", + "(1, 1, 4) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 0\n", + "(1, 2, 0) [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] a*b**2\n", + "(1, 2, 1) [0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0] 2*a*b + b**2\n", + "(1, 2, 2) [0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0] a + 2*b\n", + "(1, 2, 3) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 1\n", + "(1, 2, 4) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 0\n", + "(1, 3, 0) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0] a*b**3\n", + "(1, 3, 1) [0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 3*a*b**2 + b**3\n", + "(1, 3, 2) [0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0] 3*a*b + 3*b**2\n", + "(1, 3, 3) [0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0] a + 3*b\n", + "(1, 3, 4) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 1\n", + "(1, 4, 0) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0] a*b**4\n", + "(1, 4, 1) [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0] 4*a*b**3 + b**4\n", + "(1, 4, 2) [0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 6*a*b**2 + 4*b**3\n", + "(1, 4, 3) [0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0] 4*a*b + 6*b**2\n", + "(1, 4, 4) [0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0] a + 4*b\n", + "(2, 0, 0) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0] a**2\n", + "(2, 0, 1) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0] 2*a\n", + "(2, 0, 2) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 1\n", + "(2, 0, 3) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 0\n", + "(2, 0, 4) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 0\n", + "(2, 1, 0) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] a**2*b\n", + "(2, 1, 1) [0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0] a**2 + 2*a*b\n", + "(2, 1, 2) [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0] 2*a + b\n", + "(2, 1, 3) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 1\n", + "(2, 1, 4) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 0\n", + "(2, 2, 0) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0] a**2*b**2\n", + "(2, 2, 1) [0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 2*a**2*b + 2*a*b**2\n", + "(2, 2, 2) [0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0] a**2 + 4*a*b + b**2\n", + "(2, 2, 3) [0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0] 2*a + 2*b\n", + "(2, 2, 4) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 1\n", + "(2, 3, 0) [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] a**2*b**3\n", + "(2, 3, 1) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 3, 0] 3*a**2*b**2 + 2*a*b**3\n", + "(2, 3, 2) [0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 3*a**2*b + 6*a*b**2 + b**3\n", + "(2, 3, 3) [0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 1, 0, 0] a**2 + 6*a*b + 3*b**2\n", + "(2, 3, 4) [0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0] 2*a + 3*b\n", + "(2, 4, 0) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] a**2*b**4\n", + "(2, 4, 1) [0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0] 4*a**2*b**3 + 2*a*b**4\n", + "(2, 4, 2) [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 6, 0] 6*a**2*b**2 + 8*a*b**3 + b**4\n", + "(2, 4, 3) [0, 12, 0, 0, 0, 0, 0, 0, 0, 0, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 4*a**2*b + 12*a*b**2 + 4*b**3\n", + "(2, 4, 4) [0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 1, 0, 0] a**2 + 8*a*b + 6*b**2\n", + "(3, 0, 0) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] a**3\n", + "(3, 0, 1) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0] 3*a**2\n", + "(3, 0, 2) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0] 3*a\n", + "(3, 0, 3) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 1\n", + "(3, 0, 4) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 0\n", + "(3, 1, 0) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] a**3*b\n", + "(3, 1, 1) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] a**3 + 3*a**2*b\n", + "(3, 1, 2) [0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0] 3*a**2 + 3*a*b\n", + "(3, 1, 3) [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0] 3*a + b\n", + "(3, 1, 4) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 1\n", + "(3, 2, 0) [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] a**3*b**2\n", + "(3, 2, 1) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0] 2*a**3*b + 3*a**2*b**2\n", + "(3, 2, 2) [0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] a**3 + 6*a**2*b + 3*a*b**2\n", + "(3, 2, 3) [0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 3, 0, 0] 3*a**2 + 6*a*b + b**2\n", + "(3, 2, 4) [0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0] 3*a + 2*b\n", + "(3, 3, 0) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0] a**3*b**3\n", + "(3, 3, 1) [0, 0, 0, 0, 0, 3, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 3*a**3*b**2 + 3*a**2*b**3\n", + "(3, 3, 2) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 3, 0, 0, 0, 0, 0, 9, 0] 3*a**3*b + 9*a**2*b**2 + 3*a*b**3\n", + "(3, 3, 3) [0, 9, 0, 0, 0, 0, 0, 0, 0, 0, 9, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] a**3 + 9*a**2*b + 9*a*b**2 + b**3\n", + "(3, 3, 4) [0, 0, 0, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 3, 0, 0] 3*a**2 + 9*a*b + 3*b**2\n", + "(3, 4, 0) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0] a**3*b**4\n", + "(3, 4, 1) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0] 4*a**3*b**3 + 3*a**2*b**4\n", + "(3, 4, 2) [0, 0, 0, 0, 0, 6, 0, 0, 0, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0] 6*a**3*b**2 + 12*a**2*b**3 + 3*a*b**4\n", + "(3, 4, 3) [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 12, 0, 0, 0, 0, 0, 18, 0] 4*a**3*b + 18*a**2*b**2 + 12*a*b**3 + b**4\n", + "(3, 4, 4) [0, 18, 0, 0, 0, 0, 0, 0, 0, 0, 12, 4, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] a**3 + 12*a**2*b + 18*a*b**2 + 4*b**3\n", + "(4, 0, 0) [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] a**4\n", + "(4, 0, 1) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 4*a**3\n", + "(4, 0, 2) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0] 6*a**2\n", + "(4, 0, 3) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0] 4*a\n", + "(4, 0, 4) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 1\n", + "(4, 1, 0) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1] a**4*b\n", + "(4, 1, 1) [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] a**4 + 4*a**3*b\n", + "(4, 1, 2) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 4*a**3 + 6*a**2*b\n", + "(4, 1, 3) [0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0] 6*a**2 + 4*a*b\n", + "(4, 1, 4) [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0] 4*a + b\n", + "(4, 2, 0) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0] a**4*b**2\n", + "(4, 2, 1) [0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2] 2*a**4*b + 4*a**3*b**2\n", + "(4, 2, 2) [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0] a**4 + 8*a**3*b + 6*a**2*b**2\n", + "(4, 2, 3) [0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 12, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 4*a**3 + 12*a**2*b + 4*a*b**2\n", + "(4, 2, 4) [0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 6, 0, 0] 6*a**2 + 8*a*b + b**2\n", + "(4, 3, 0) [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] a**4*b**3\n", + "(4, 3, 1) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 3, 0, 0, 0, 0] 3*a**4*b**2 + 4*a**3*b**3\n", + "(4, 3, 2) [0, 0, 0, 0, 0, 12, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3] 3*a**4*b + 12*a**3*b**2 + 6*a**2*b**3\n", + "(4, 3, 3) [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 12, 0, 0, 4, 0, 0, 0, 0, 0, 18, 0] a**4 + 12*a**3*b + 18*a**2*b**2 + 4*a*b**3\n", + "(4, 3, 4) [0, 12, 0, 0, 0, 0, 0, 0, 0, 0, 18, 1, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] 4*a**3 + 18*a**2*b + 12*a*b**2 + b**3\n", + "(4, 4, 0) [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] a**4*b**4\n", + "(4, 4, 1) [0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0] 4*a**4*b**3 + 4*a**3*b**4\n", + "(4, 4, 2) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 16, 0, 0, 0, 6, 0, 0, 0, 0] 6*a**4*b**2 + 16*a**3*b**3 + 6*a**2*b**4\n", + "(4, 4, 3) [0, 0, 0, 0, 0, 24, 0, 0, 0, 24, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 4] 4*a**4*b + 24*a**3*b**2 + 24*a**2*b**3 + 4*a*b**4\n", + "(4, 4, 4) [0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 16, 0, 0, 16, 0, 0, 0, 0, 0, 36, 0] a**4 + 16*a**3*b + 36*a**2*b**2 + 16*a*b**3 + b**4\n" + ] + } + ], + "source": [ + "LMAX = 5\n", + "all_monoms = set()\n", + "for i in range(LMAX):\n", + " for j in range(LMAX):\n", + " for s in range(LMAX):\n", + " bf = binom_factor(i,j,a,b,s)\n", + " coefs = all_coeffs(S(bf))\n", + " all_monoms = all_monoms.union(set(coefs.keys()))\n", + "all_monoms = tuple(all_monoms)\n", + "n = len(all_monoms)\n", + "display(n, S(all_monoms))\n", + "#monom_map = {k:i for i,k in enumerate(all_monoms)}\n", + "\n", + "weights = np.zeros((LMAX, LMAX, LMAX, n))\n", + "for i in range(LMAX):\n", + " for j in range(LMAX):\n", + " for s in range(LMAX):\n", + " bf = binom_factor(i,j,a,b,s)\n", + " coefs = all_coeffs(S(bf))\n", + " #display(bf, S(coefs))\n", + " val = [coefs.get(monom, 0) for monom in all_monoms]\n", + " print((i,j,s), val, bf)\n", + " weights[i,j,s,:] = val" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", + " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", + " 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n", + " 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,\n", + " 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,\n", + " 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,\n", + " 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]),\n", + " array([1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3,\n", + " 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 1, 1, 1, 1, 1, 2, 2,\n", + " 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4,\n", + " 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2,\n", + " 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4,\n", + " 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,\n", + " 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]),\n", + " array([0, 0, 1, 0, 1, 2, 0, 1, 2, 3, 0, 0, 1, 1, 0, 1, 1, 2, 2, 0, 1, 1,\n", + " 2, 2, 3, 3, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 1, 0, 1, 1, 2, 2, 0, 1,\n", + " 1, 2, 2, 2, 3, 3, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 0, 1, 1, 2, 2,\n", + " 2, 3, 3, 3, 4, 4, 4, 0, 1, 2, 0, 1, 1, 2, 2, 3, 3, 0, 1, 1, 2, 2,\n", + " 2, 3, 3, 3, 4, 4, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 0, 1, 1,\n", + " 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 0, 1, 2, 3, 0, 1, 1, 2, 2, 3, 3,\n", + " 4, 4, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 0, 1, 1, 2, 2, 2, 3, 3,\n", + " 3, 3, 4, 4, 4, 4, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4]),\n", + " array([ 4, 21, 4, 11, 21, 4, 2, 11, 21, 4, 15, 3, 4, 15, 1, 3, 21,\n", + " 4, 15, 17, 1, 11, 3, 21, 4, 15, 19, 2, 17, 1, 11, 3, 21, 4,\n", + " 15, 22, 15, 10, 3, 22, 4, 15, 23, 1, 10, 3, 21, 22, 4, 15, 9,\n", + " 17, 23, 1, 10, 11, 3, 21, 22, 4, 15, 12, 9, 19, 2, 17, 23, 1,\n", + " 10, 11, 3, 21, 22, 13, 22, 15, 14, 10, 13, 3, 22, 4, 15, 5, 14,\n", + " 23, 1, 10, 13, 3, 21, 22, 4, 15, 16, 5, 9, 14, 17, 23, 1, 10,\n", + " 11, 13, 3, 21, 22, 18, 12, 16, 5, 9, 19, 2, 14, 17, 23, 1, 10,\n", + " 11, 13, 7, 13, 22, 15, 24, 7, 14, 10, 13, 3, 22, 4, 15, 20, 5,\n", + " 24, 7, 14, 23, 1, 10, 13, 3, 21, 22, 6, 16, 20, 5, 9, 24, 7,\n", + " 14, 17, 23, 1, 10, 11, 13, 8, 6, 18, 12, 16, 20, 5, 9, 19, 24,\n", + " 2, 7, 14, 17, 23])),\n", + " array([ 1., 1., 2., 1., 3., 3., 1., 4., 6., 4., 1., 1., 1.,\n", + " 1., 1., 2., 1., 2., 1., 1., 3., 1., 3., 3., 3., 1.,\n", + " 1., 1., 4., 6., 4., 4., 6., 4., 1., 1., 2., 1., 2.,\n", + " 1., 1., 2., 1., 2., 2., 4., 1., 1., 2., 2., 1., 2.,\n", + " 3., 6., 3., 1., 6., 3., 1., 3., 2., 1., 4., 2., 1.,\n", + " 8., 6., 12., 4., 4., 8., 6., 1., 1., 3., 3., 1., 3.,\n", + " 1., 3., 3., 1., 3., 1., 2., 3., 3., 6., 1., 6., 1.,\n", + " 3., 2., 3., 1., 3., 3., 3., 3., 9., 9., 9., 1., 1.,\n", + " 9., 3., 3., 1., 3., 4., 6., 12., 3., 1., 4., 12., 18.,\n", + " 18., 12., 4., 1., 1., 4., 6., 4., 1., 1., 4., 6., 4.,\n", + " 4., 6., 1., 4., 1., 4., 2., 1., 8., 6., 4., 12., 4.,\n", + " 8., 1., 6., 1., 4., 3., 12., 6., 3., 1., 12., 4., 18.,\n", + " 12., 18., 1., 4., 1., 4., 4., 6., 16., 6., 24., 24., 4.,\n", + " 4., 1., 1., 16., 16., 36.]))" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "inds = np.nonzero(weights)\n", + "inds, weights[inds]\n", + "\n", + "with open(\"../pyscf_ipu/experimental/binom_factor_table.py\", \"w\") as f:\n", + " print(\"# Copyright (c) 2023 Graphcore Ltd. All rights reserved.\", file=f)\n", + " print(\"# AUTOGENERATED from notebooks/binom_factor_table.ipynb\", file=f)\n", + " print(\"# fmt: off\", file=f)\n", + " print(\"# flake8: noqa\", file=f)\n", + " print(\"# isort: skip_file\", file=f)\n", + " print(\"from numpy import array\", file=f)\n", + " print(\"binom_factor_table = \", repr((inds, weights[inds])), file=f)\n", + "\n", + "import pyscf_ipu.experimental.binom_factor_table\n", + "\n", + "pyscf_ipu.experimental.binom_factor_table.binom_factor_table" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "jaxipu", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.18" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pyscf_ipu/experimental/binom_factor_table.py b/pyscf_ipu/experimental/binom_factor_table.py new file mode 100644 index 0000000..6f4041f --- /dev/null +++ b/pyscf_ipu/experimental/binom_factor_table.py @@ -0,0 +1,51 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +# AUTOGENERATED from notebooks/binom_factor_table.ipynb +# fmt: off +# flake8: noqa +# isort: skip_file +from numpy import array +binom_factor_table = ((array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]), array([1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 0, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, + 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 1, 1, 1, 1, 1, 2, 2, + 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, + 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]), array([0, 0, 1, 0, 1, 2, 0, 1, 2, 3, 0, 0, 1, 1, 0, 1, 1, 2, 2, 0, 1, 1, + 2, 2, 3, 3, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 1, 0, 1, 1, 2, 2, 0, 1, + 1, 2, 2, 2, 3, 3, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 0, 1, 1, 2, 2, + 2, 3, 3, 3, 4, 4, 4, 0, 1, 2, 0, 1, 1, 2, 2, 3, 3, 0, 1, 1, 2, 2, + 2, 3, 3, 3, 4, 4, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 0, 1, 1, + 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 0, 1, 2, 3, 0, 1, 1, 2, 2, 3, 3, + 4, 4, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 0, 1, 1, 2, 2, 2, 3, 3, + 3, 3, 4, 4, 4, 4, 0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4]), array([ 4, 21, 4, 11, 21, 4, 2, 11, 21, 4, 15, 3, 4, 15, 1, 3, 21, + 4, 15, 17, 1, 11, 3, 21, 4, 15, 19, 2, 17, 1, 11, 3, 21, 4, + 15, 22, 15, 10, 3, 22, 4, 15, 23, 1, 10, 3, 21, 22, 4, 15, 9, + 17, 23, 1, 10, 11, 3, 21, 22, 4, 15, 12, 9, 19, 2, 17, 23, 1, + 10, 11, 3, 21, 22, 13, 22, 15, 14, 10, 13, 3, 22, 4, 15, 5, 14, + 23, 1, 10, 13, 3, 21, 22, 4, 15, 16, 5, 9, 14, 17, 23, 1, 10, + 11, 13, 3, 21, 22, 18, 12, 16, 5, 9, 19, 2, 14, 17, 23, 1, 10, + 11, 13, 7, 13, 22, 15, 24, 7, 14, 10, 13, 3, 22, 4, 15, 20, 5, + 24, 7, 14, 23, 1, 10, 13, 3, 21, 22, 6, 16, 20, 5, 9, 24, 7, + 14, 17, 23, 1, 10, 11, 13, 8, 6, 18, 12, 16, 20, 5, 9, 19, 24, + 2, 7, 14, 17, 23])), array([ 1., 1., 2., 1., 3., 3., 1., 4., 6., 4., 1., 1., 1., + 1., 1., 2., 1., 2., 1., 1., 3., 1., 3., 3., 3., 1., + 1., 1., 4., 6., 4., 4., 6., 4., 1., 1., 2., 1., 2., + 1., 1., 2., 1., 2., 2., 4., 1., 1., 2., 2., 1., 2., + 3., 6., 3., 1., 6., 3., 1., 3., 2., 1., 4., 2., 1., + 8., 6., 12., 4., 4., 8., 6., 1., 1., 3., 3., 1., 3., + 1., 3., 3., 1., 3., 1., 2., 3., 3., 6., 1., 6., 1., + 3., 2., 3., 1., 3., 3., 3., 3., 9., 9., 9., 1., 1., + 9., 3., 3., 1., 3., 4., 6., 12., 3., 1., 4., 12., 18., + 18., 12., 4., 1., 1., 4., 6., 4., 1., 1., 4., 6., 4., + 4., 6., 1., 4., 1., 4., 2., 1., 8., 6., 4., 12., 4., + 8., 1., 6., 1., 4., 3., 12., 6., 3., 1., 12., 4., 18., + 12., 18., 1., 4., 1., 4., 4., 6., 16., 6., 24., 24., 4., + 4., 1., 1., 16., 16., 36.]))