From fc9f69482958548b6da0338041cf2399552491d0 Mon Sep 17 00:00:00 2001 From: alexandermath Date: Wed, 8 Nov 2023 12:45:53 +0000 Subject: [PATCH 01/22] reimplemented d4ft gd idea. uses newest jax on cpu for simplicity. --- pyscf_ipu/exchange_correlation/b3lyp.py | 33 ++- pyscf_ipu/nanoDFT/gd.py | 315 ++++++++++++++++++++++++ 2 files changed, 347 insertions(+), 1 deletion(-) create mode 100644 pyscf_ipu/nanoDFT/gd.py diff --git a/pyscf_ipu/exchange_correlation/b3lyp.py b/pyscf_ipu/exchange_correlation/b3lyp.py index a5780ee7..731c8a43 100644 --- a/pyscf_ipu/exchange_correlation/b3lyp.py +++ b/pyscf_ipu/exchange_correlation/b3lyp.py @@ -14,6 +14,37 @@ def b3lyp(rho, EPSILON_B3LYP=0): rho = jnp.concatenate([jnp.clip(rho[:1], CLIP_RHO_MIN, CLIP_RHO_MAX), rho[1:4]*2]) + rho0 = rho.T[:, 0] + #norms = jnp.linalg.norm(rho[1:], axis=0).T**2+EPSILON_B3LYP + norms = jnp.linalg.norm(rho[1:]+CLIP_RHO_MIN, axis=0).T**2+EPSILON_B3LYP + + def lda(rho0): return jax.vmap(jax.value_and_grad(lambda x: __lda(x)*0.08)) (rho0) + def vwn(rho0): return jax.vmap(jax.value_and_grad(lambda x: __vwn(x)*0.19)) (rho0) + + # disabled gradient checkpointing + #def b88(rho0, norms): return jax.vmap(jax.value_and_grad(lambda rho0, norm: jax.checkpoint(__b88)(rho0, norm)*0.72, (0, 1))) (rho0, norms) + #def lyp(rho0, norms): return jax.vmap(jax.value_and_grad(lambda rho0, norm: jax.checkpoint(__lyp)(rho0, norm)*0.810, (0, 1))) (rho0, norms) + + def b88(rho0, norms): return jax.vmap(jax.value_and_grad(lambda rho0, norm: __b88(rho0, norm)*0.72, (0,1)))(rho0, norms) + def lyp(rho0, norms): return jax.vmap(jax.value_and_grad(lambda rho0, norm: __lyp(rho0, norm)*0.810, (0,1)))(rho0, norms) + + e_xc_lda, v_rho_lda = jax.jit(lda)(rho0) + e_xc_vwn, v_rho_vwn = jax.jit(vwn)(rho0) + e_xc_b88, (v_rho_b88, v_norm_b88) = jax.jit(b88)(rho0, norms) + e_xc_lyp, (v_rho_lyp, v_norm_lyp) = jax.jit(lyp)(rho0, norms) + + e_xc = e_xc_lda + (e_xc_vwn + e_xc_b88 + e_xc_lyp) / rho0 + #v_xc_rho = v_rho_lda*4*rho0 + v_rho_vwn + v_rho_b88 + v_rho_lyp + #v_xc_norms = v_norm_b88 + v_norm_lyp + + return e_xc#, v_xc_rho, v_xc_norms + + + +def vxc_b3lyp(rho, EPSILON_B3LYP=0): + + rho = jnp.concatenate([jnp.clip(rho[:1], CLIP_RHO_MIN, CLIP_RHO_MAX), rho[1:4]*2]) + rho0 = rho.T[:, 0] norms = jnp.linalg.norm(rho[1:], axis=0).T**2+EPSILON_B3LYP @@ -156,4 +187,4 @@ def plot(rho, b, a, g, grad, vnorm=None, name=""): # b is pyscf a is us ax[2].set_yscale("log") ax[2].set_xscale("log") plt.tight_layout() - plt.savefig("%s_3.jpg"%name) + plt.savefig("%s_3.jpg"%name) \ No newline at end of file diff --git a/pyscf_ipu/nanoDFT/gd.py b/pyscf_ipu/nanoDFT/gd.py new file mode 100644 index 00000000..c21386ed --- /dev/null +++ b/pyscf_ipu/nanoDFT/gd.py @@ -0,0 +1,315 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +# (assumes newest Jax) +import jax +jax.config.update('jax_enable_x64', True) +import jax.numpy as jnp +import numpy as np +import pyscf +import optax +from icecream import ic +from pyscf_ipu.exchange_correlation.b3lyp import b3lyp, vxc_b3lyp +from tqdm import tqdm + +HARTREE_TO_EV, EPSILON_B3LYP, HYB_B3LYP = 27.2114079527, 1e-20, 0.2 + +def dm_energy(density_matrix, state): + #eigvects = L_inv.T @ jnp.linalg.qr( L_inv @ density_matrix @ L_inv.T)[0] + eigvects = state.L_inv.T @ jnp.linalg.eigh(state.L_inv @ density_matrix @ state.L_inv.T)[1] + density_matrix = 2 * (eigvects * state.mask) @ eigvects.T # 'state.mask' does eigvects[:, :n_electrons//2] jit-ably. + E_xc = exchange_correlation(density_matrix, state.grid_AO, state.grid_weights) + diff_JK = get_JK(density_matrix, state.ERI) + return jnp.sum(density_matrix * (state.H_core + diff_JK/2)) + E_xc + state.E_nuc, density_matrix + +def exchange_correlation(density_matrix, grid_AO, grid_weights): + grid_AO_dm = grid_AO[0] @ density_matrix + grid_AO_dm = jnp.expand_dims(grid_AO_dm, axis=0) + mult = grid_AO_dm * grid_AO + rho = jnp.sum(mult, axis=2) + E_xc = b3lyp(rho, EPSILON_B3LYP) + E_xc = jnp.sum(rho[0] * grid_weights * E_xc) + return E_xc + +def get_JK(density_matrix, ERI): + J = jnp.einsum('ijkl,ji->kl', ERI, density_matrix) + K = jnp.einsum('ijkl,jk->il', ERI, density_matrix) + return J - (K / 2 * HYB_B3LYP) + +def nanoDFT(mol_str, opts, pyscf_E): + # Init DFT tensors on CPU using PySCF. + mol = build_mol(mol_str, opts.basis) + pyscf_E, pyscf_hlgap, pycsf_forces = reference(mol_str, opts) + + N = mol.nao_nr() + state = init_dft(mol, opts)[0] + target = pyscf_E[-1] + + w = np.eye(N) + np.random.normal(0, 0.01, (N, N)) + + vandg = jax.jit(jax.value_and_grad( dm_energy, has_aux=True), backend=opts.backend) + + # Build initializers for params + adam = optax.adam(opts.lr) + adam_state = adam.init(w) + + pbar = tqdm(range(opts.steps)) + for i in pbar: + (val, density_matrix), grad = vandg(w, state) + updates, adam_state = adam.update(grad, adam_state) + w = optax.apply_updates(w, updates) + pbar.set_description("energy=%.7f [eV] error=%.7f [eV]"%(val*HARTREE_TO_EV, target-val*HARTREE_TO_EV)) + if i == 0: print("") + + V_xc = jax.grad(exchange_correlation)(density_matrix, state.grid_AO, state.grid_weights) + V_xc = (V_xc + V_xc.T)/2 + diff_JK = get_JK(density_matrix, state.ERI) + H = state.H_core + diff_JK + V_xc + mo_energy, mo_coeff = np.linalg.eigh(state.L_inv @ H @ state.L_inv.T) + mo_coeff = state.L_inv.T @ mo_coeff + + return val, (0, mo_energy, mo_coeff, state.grid_coords, state.grid_weights, density_matrix, H) + + +import chex +@chex.dataclass +class IterationState: + E_nuc: np.array + density_matrix: np.array + kinetic: np.array + nuclear: np.array + O: np.array + mask: np.array + L_inv: np.array + L: np.array + H_core: np.array + grid_AO: np.array + grid_weights: np.array + atom_pos: np.array + ERI: np.array + grid_coords: np.array + +def init_dft(mol, opts): + N = mol.nao_nr() + n_electrons_half = mol.nelectron//2 + E_nuc = mol.energy_nuc() + + from pyscf import dft + grids = pyscf.dft.gen_grid.Grids(mol) + grids.level = opts.level + grids.build() + grid_weights = grids.weights + coord_str = 'GTOval_cart_deriv1' if mol.cart else 'GTOval_sph_deriv1' + grid_AO = mol.eval_gto(coord_str, grids.coords, 4) + grid_coords = grids.coords + density_matrix = pyscf.scf.hf.init_guess_by_minao(mol) + + # TODO(): Add integral math formulas for kinetic/nuclear/O/ERI. + kinetic = mol.intor_symmetric('int1e_kin') + nuclear = mol.intor_symmetric('int1e_nuc') + O = mol.intor_symmetric('int1e_ovlp') + L = np.linalg.cholesky(O) + L_inv = np.linalg.inv(L) + + mask = np.concatenate([np.ones(n_electrons_half), np.zeros(N-n_electrons_half)]) + + ERI = mol.intor("int2e_sph") + + state = IterationState(E_nuc=E_nuc, ERI=ERI, grid_coords=grid_coords, + density_matrix=density_matrix, kinetic=kinetic, + nuclear=nuclear, + O=O, + mask=mask, + H_core=nuclear+kinetic, + L_inv=L_inv, L=L, grid_AO=grid_AO, grid_weights=grid_weights, atom_pos=mol.atom_coords()) + + print("DFT Tensor Summary") + for field_name, field_def in state.__dataclass_fields__.items(): + field_value = getattr(state, field_name) + print(f"{field_name}: {getattr(field_value, 'shape', None)}") + + return state, n_electrons_half, E_nuc, N, L_inv, grid_weights, grid_coords, grid_AO + + +def grad_elec(weight, grid_AO, eri, s1, h1aos, natm, aoslices, mask, mo_energy, mo_coeff, mol, dm, H): + # Electronic part of RHF/RKS gradients + dm0 = 2 * (mo_coeff*mask) @ mo_coeff.T # (N, N) = (66, 66) for C6H6. + dme0 = 2 * (mo_coeff * mask*mo_energy) @ mo_coeff.T # (N, N) = (66, 66) for C6H6. + + # Code identical to exchange correlation. + rho = jnp.sum( grid_AO[:1] @ dm0 * grid_AO, axis=2) # (10, grid_size) = (10, 45624) for C6H6. + _, vrho, vgamma = vxc_b3lyp(rho, EPSILON_B3LYP) # (grid_size,) (grid_size,) + V_xc = jnp.concatenate([vrho.reshape(1, -1)/2, 4*vgamma.reshape(1, -1)*rho[1:4]], axis=0) # (4, grid_size) + + vmat = grid_AO[1:4].transpose(0, 2, 1) @ jnp.sum(grid_AO[:4] * jnp.expand_dims(weight * V_xc, axis=2), axis=0) # (3, N, N) + aos = jnp.concatenate([jnp.expand_dims(grid_AO[np.array([1,4,5,6])], 0), jnp.expand_dims(grid_AO[np.array([2,5,7,8])], 0), jnp.expand_dims(grid_AO[np.array([3,6,8,9])], 0)], axis=0) # (3, N, N) + V_xc = - vmat - jnp.transpose(jnp.einsum("snpi,np->spi", aos, weight*V_xc), axes=(0,2,1)) @ grid_AO[0] # (3, 4, grid_size, N) + + vj = - jnp.einsum('sijkl,lk->sij', eri, dm0) # (3, N, N) + vk = - jnp.einsum('sijkl,jk->sil', eri, dm0) # (3, N, N) + vhf = V_xc + vj - vk * .5 * HYB_B3LYP # (3, N, N) + + de = jnp.einsum('lxij,ij->lx', h1aos, dm0) # (natm, 3) + for k, ia in enumerate(range(natm)): + p0, p1 = aoslices[ia][2], aoslices[ia][3] + de = de.at[k].add(jnp.einsum('xij,ij->x', vhf[:, p0:p1], dm0[p0:p1]) * 2) + de = de.at[k].add(-jnp.einsum('xij,ij->x', s1[:, p0:p1], dme0[p0:p1]) * 2) + return de + +def grad_nuc(charges, coords): + # Derivatives of nuclear repulsion energy wrt nuclear coordinates + natm = charges.shape[0] + pairwise_charges = charges.reshape(natm, 1) * charges.reshape(1, natm) # (natm, natm) + pairwise_difference = coords.reshape(1, natm, 3) - coords.reshape(natm, 1, 3) # (natm, natm, 3) + pairwise_distances = jnp.linalg.norm(pairwise_difference, axis=2) ** 3 # (natm, natm) + pairwise_distances = jnp.where(pairwise_distances == 0, jnp.inf, pairwise_distances) # (natm, natm) + all = - pairwise_charges.reshape(natm, natm, 1) * pairwise_difference # (natm, natm, 3) + all = all / pairwise_distances.reshape(natm, natm, 1) # (natm, natm, 3) + all = all.at[jnp.diag_indices(natm)].set(0) # (natm, natm, 3) + return jnp.sum(all, axis=0) # (natm, natm) + +def grad(mol, coords, weight, mo_coeff, mo_energy, dm, H): + # Initialize DFT tensors on CPU using PySCF. + ao = pyscf.dft.numint.NumInt().eval_ao(mol, coords, deriv=2) + eri = mol.intor("int2e_ip1") + s1 = - mol.intor('int1e_ipovlp', comp=3) + kin = - mol.intor('int1e_ipkin', comp=3) + nuc = - mol.intor('int1e_ipnuc', comp=3) + + mask = np.ones(mol.nao_nr()) + mask[mol.nelectron//2:] = 0 + + aoslices = mol.aoslice_by_atom() + h1 = kin + nuc + def hcore_deriv(atm_id, aoslices, h1): # <\nabla|1/r|> + _, _, p0, p1 = aoslices[atm_id] + with mol.with_rinv_at_nucleus(atm_id): + vrinv = mol.intor('int1e_iprinv', comp=3) # + vrinv *= -mol.atom_charge(atm_id) + vrinv[:,p0:p1] += h1[:,p0:p1] + return vrinv + vrinv.transpose(0,2,1) + N = h1.shape[1] # (3, N , N) + h1aos = np.zeros((mol.natm, 3, N, N)) + for k, ia in enumerate(range(mol.natm)): + p0, p1 = aoslices[ia,2:] + h1aos[k] = hcore_deriv(ia, aoslices, h1) + + charges = np.zeros((mol.natm)) + coords = np.zeros((mol.natm,3)) + for j in range(mol.natm): + charges[j] = mol.atom_charge(j) + coords[j]= mol.atom_coord(j) + + #_grad_elec = jax.jit(grad_elec, static_argnames=["aoslices", "natm"], backend="cpu") + _grad_elec = grad_elec + _grad_nuc = jax.jit(grad_nuc, backend="cpu") + + return _grad_elec(weight, ao, eri, s1, h1aos, mol.natm, tuple([tuple(a) for a in aoslices.tolist()]), mask, mo_energy, mo_coeff, mol, dm, H) + _grad_nuc(charges, coords) + +def pyscf_reference(mol_str, opts): + from pyscf import __config__ + __config__.dft_rks_RKS_grids_level = opts.level + + mol = build_mol(mol_str, opts.basis) + mol.max_cycle = 50 + mf = pyscf.scf.RKS(mol) + mf.max_cycle = 50 + mf.xc = "b3lyp" + mf.diis_space = 8 + pyscf_energies = [] + pyscf_hlgaps = [] + lumo = mol.nelectron//2 + homo = lumo - 1 + def callback(envs): + pyscf_energies.append(envs["e_tot"]*HARTREE_TO_EV) + hl_gap_hartree = np.abs(envs["mo_energy"][homo] - envs["mo_energy"][lumo]) * HARTREE_TO_EV + pyscf_hlgaps.append(hl_gap_hartree) + print("\rPYSCF: ", pyscf_energies[-1] , end="") + mf.callback = callback + mf.kernel() + print("") + forces = mf.nuc_grad_method().kernel() + return np.array(pyscf_energies), np.array(pyscf_hlgaps), np.array(forces) + +def print_difference(nanoDFT_E, nanoDFT_forces, nanoDFT_logged_E, nanoDFT_hlgap, pyscf_E, pyscf_forces, pyscf_hlgap): + #TODO(HH): rename to match caller variable names + nanoDFT_E = nanoDFT_E*HARTREE_TO_EV + print("pyscf:\t\t%15f"%pyscf_E[-1]) + print("us:\t\t%15f"%nanoDFT_E) + print("diff:\t\t%15f"%np.abs(pyscf_E[-1]-nanoDFT_E)) + print("chemAcc: \t%15f"%0.043) + print("chemAcc/diff: \t%15f"%(0.043/np.abs(pyscf_E[-1]-nanoDFT_E))) + print("") + + # Forces + print() + print("np.max(|nanoDFT_F-PySCF_F|):", np.max(np.abs(nanoDFT_forces-pyscf_forces))) + + norm_X = np.linalg.norm(nanoDFT_forces, axis=1) + norm_Y = np.linalg.norm(pyscf_forces, axis=1) + dot_products = np.sum(nanoDFT_forces * pyscf_forces, axis=1) + cosine_similarity = dot_products / (norm_X * norm_Y) + print("Force cosine similarity:",cosine_similarity) + +def build_mol(mol_str, basis_name): + mol = pyscf.gto.mole.Mole() + mol.build(atom=mol_str, unit="Angstrom", basis=basis_name, spin=0, verbose=0) + return mol + +def reference(mol_str, opts): + import pickle + import hashlib + import os + os.makedirs("precomputed", exist_ok=True) + filename = "precomputed/%s.pkl"%hashlib.sha256((str(mol_str) + str(opts.basis) + str(opts.level)).encode('utf-8')).hexdigest() + print(filename) + if not os.path.exists(filename): + pyscf_E, pyscf_hlgap, pyscf_forces = pyscf_reference(mol_str, opts) + with open(filename, "wb") as file: + pickle.dump([pyscf_E, pyscf_hlgap, pyscf_forces], file) + else: + pyscf_E, pyscf_hlgap, pyscf_forces = pickle.load(open(filename, "rb")) + return pyscf_E, pyscf_hlgap, pyscf_forces + + +if __name__ == "__main__": + #jax.config.FLAGS.jax_platform_name = 'cpu' + import os + import argparse + + parser = argparse.ArgumentParser() + # DFT options + parser.add_argument('-basis', type=str, default="sto3g") + parser.add_argument('-level', type=int, default=0) + # GD options + parser.add_argument('-backend', type=str, default="cpu") + parser.add_argument('-lr', type=float, default=1e-3) + parser.add_argument('-steps', type=int, default=200) + opts = parser.parse_args() + + # benzene + mol_str = [ + ["C", ( 0.0000, 0.0000, 0.0000)], + ["C", ( 1.4000, 0.0000, 0.0000)], + ["C", ( 2.1000, 1.2124, 0.0000)], + ["C", ( 1.4000, 2.4249, 0.0000)], + ["C", ( 0.0000, 2.4249, 0.0000)], + ["C", (-0.7000, 1.2124, 0.0000)], + ["H", (-0.5500, -0.9526, 0.0000)], + ["H", (-0.5500, 3.3775, 0.0000)], + ["H", ( 1.9500, -0.9526, 0.0000)], + ["H", (-1.8000, 1.2124, 0.0000)], + ["H", ( 3.2000, 1.2124, 0.0000)], + ["H", ( 1.9500, 3.3775, 0.0000)] + ] + + + mol = build_mol(mol_str, opts.basis) + + ic(mol.nao_nr()) + ic(mol.nelectron) + + pyscf_E, pyscf_hlgap, pyscf_forces = reference(mol_str, opts) + + nanoDFT_E, (nanoDFT_hlgap, mo_energy, mo_coeff, grid_coords, grid_weights, dm, H) = nanoDFT(mol_str, opts, pyscf_E) + nanoDFT_forces = grad(mol, grid_coords, grid_weights, mo_coeff, mo_energy, np.array(dm), np.array(H)) + + print_difference(nanoDFT_E, nanoDFT_forces, 0 , nanoDFT_hlgap, pyscf_E, pyscf_forces, pyscf_hlgap) \ No newline at end of file From c145ef9149fc0e012822930cc5f5cc934382f714 Mon Sep 17 00:00:00 2001 From: alexandermath Date: Mon, 13 Nov 2023 13:38:23 +0000 Subject: [PATCH 02/22] batched gd. --- pyscf_ipu/nanoDFT/batched.py | 648 +++++++++++++++++++++++++++++++++++ 1 file changed, 648 insertions(+) create mode 100644 pyscf_ipu/nanoDFT/batched.py diff --git a/pyscf_ipu/nanoDFT/batched.py b/pyscf_ipu/nanoDFT/batched.py new file mode 100644 index 00000000..bf5f8394 --- /dev/null +++ b/pyscf_ipu/nanoDFT/batched.py @@ -0,0 +1,648 @@ +import jax +jax.config.update('jax_enable_x64', True) +import jax.numpy as jnp +import numpy as np +import pyscf +import optax +from icecream import ic +from pyscf_ipu.exchange_correlation.b3lyp import b3lyp, vxc_b3lyp +from tqdm import tqdm +import time + +HARTREE_TO_EV, EPSILON_B3LYP, HYB_B3LYP = 27.2114079527, 1e-20, 0.2 + +def T(x): return jnp.transpose(x, (0,2,1)) + +# Only need to recompute: L_inv, grid_AO, grid_weights, H_core, ERI and E_nuc. +def dm_energy(W, state, diff_state, normal): + B, N, k = W.shape + L_inv_Q = state.L_inv_T @ jnp.linalg.qr(W)[0] # O(N^2 * num_electrons * batch) instead of O(N^3 * batch)! + density_matrix = 2 * L_inv_Q @ T(L_inv_Q) + E_xc = exchange_correlation(density_matrix, state, diff_state, normal) + diff_JK = JK(density_matrix, state, diff_state, normal) + energies = jnp.sum((density_matrix * (state.H_core + diff_JK/2)).reshape(B, -1), axis=-1) + E_xc + state.E_nuc + return jnp.sum(energies), (energies, E_xc, density_matrix) + +def exchange_correlation(density_matrix, state, diff_state, normal): + B, _, gsize, N = state.grid_AO.shape + if normal: + grid_AO_dm = (state.grid_AO[:, 0] @ density_matrix) # (B,gsize,N) @ (N, N) = O(B gsize N^2) + rho = jnp.sum(grid_AO_dm * state.grid_AO , axis=3) # (B,1,gsize,N) * (B,4,gsize,N) = O(B gsize N) + else: + def sparse_mult(values, dm): + in_ = dm.take(diff_state.cols, axis=0) + prod = in_*values[:, None] + return jax.ops.segment_sum(prod, diff_state.rows, gsize) + + main = diff_state.main_grid_AO[:1, 0] @ density_matrix # (1, gsize, N) @ (N, N) = O(gsize N^2) + correction = jax.vmap(sparse_mult)(diff_state.sparse_diffs_grid_AO, density_matrix) + grid_AO_dm = (main - correction).reshape(B, 1, gsize, N) + diff = diff_state.main_grid_AO[:1, :] - diff_state.diffs_grid_AO + rho = jnp.sum(grid_AO_dm * diff, axis=3).reshape(B, 4, gsize) + + E_xc = jax.vmap(b3lyp, in_axes=(0,None))(rho, EPSILON_B3LYP).reshape(B, gsize) + E_xc = jnp.sum(rho[:, 0] * state.grid_weights * E_xc, axis=-1).reshape(B) + return E_xc + +def JK(density_matrix, state, diff_state, normal): + if normal: + J = jnp.einsum('bijkl,bji->bkl', state.ERI, density_matrix) + K = jnp.einsum('bijkl,bjk->bil', state.ERI, density_matrix) + diff_JK = J - K / 2 * HYB_B3LYP + else: + from pyscf_ipu.nanoDFT.sparse_symmetric_ERI import sparse_symmetric_einsum + # batched => flops = reads + #diff_JK = jax.vmap(sparse_symmetric_einsum, in_axes=(0, 0, 0))(state.nonzero_distinct_ERI, state.nonzero_indices, density_matrix) + # first + correction_remaining => floats = reads*batch_size + diff_JK = jax.vmap(sparse_symmetric_einsum, in_axes=(None, None, 0))(state.nonzero_distinct_ERI[0], state.nonzero_indices[0], density_matrix) + diff_JK = diff_JK - jax.vmap(sparse_symmetric_einsum, in_axes=(0, None, 0))(diff_state.diffs_ERI, diff_state.indxs, density_matrix) + + return diff_JK + + +def nanoDFT(mol_str, opts, pyscf_E): + # Init DFT tensors on CPU using PySCF. + # Try to re-use grid amongst all points. + state = init_dft(mol_str, opts) + c, w = state.grid_coords, state.grid_weights + print(mol_str[0][1]) + for _ in range(opts.bs-1): + mol_str[0][1] = (mol_str[0][1][0]+0.05, mol_str[0][1][1], mol_str[0][1][2]) + stateB = init_dft(mol_str, opts, c, w) + state = cat(state, stateB) + N = state.N[0] + + summary(state) + + if opts.normal: diff_state = None + else: + main_grid_AO = state.grid_AO[:1] + diffs_grid_AO = main_grid_AO - state.grid_AO + rows, cols = np.nonzero(np.max(diffs_grid_AO[:, 0]!=0, axis=0)) + sparse_diffs_grid_AO = diffs_grid_AO[:, 0, rows,cols] + + diff_ERIs = state.nonzero_distinct_ERI[:1] - state.nonzero_distinct_ERI + diff_indxs = state.nonzero_indices[0].reshape(1, -1, 4) + nzr = np.abs(diff_ERIs[1]).reshape(-1) != 0 + diff_ERIs = diff_ERIs[:, :, nzr] + diff_indxs = diff_indxs[:, nzr] + + diff_state = DiffState(indxs=diff_indxs, + rows=rows, cols=cols, + main_grid_AO=main_grid_AO, sparse_diffs_grid_AO=sparse_diffs_grid_AO, diffs_grid_AO=diffs_grid_AO, diffs_ERI=diff_ERIs) + summary(diff_state) + + if opts.visualize: + pass + + + w = state.init + vandg = jax.jit(jax.value_and_grad( dm_energy, has_aux=True), backend=opts.backend, static_argnames=("normal", )) + + # Build initializers for params + #adam = optax.adam(lr_schedule) + adam = optax.adabelief(opts.lr) + adam_state = adam.init(w) + + min_val = 0 + min_dm = 0 + + pbar = tqdm(range(opts.steps)) + + (val, _), grad = vandg(w, state, diff_state, opts.normal) + + for i in pbar: + #with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True): + (val, (vals, E_xc, density_matrix)), grad = vandg(w, state, diff_state, opts.normal) + updates, adam_state = adam.update(grad, adam_state) + w = optax.apply_updates(w, updates) + #pbar.set_description("energy=%.7f [eV] error=%.7f [eV] (best_error=%.7f[eV])"%(vals*HARTREE_TO_EV, target-vals[0]*HARTREE_TO_EV, target-min_val*HARTREE_TO_EV)) + + if opts.bs == 1: pbar.set_description("error=%.7f [eV] (%.7f %.7f) "%(np.mean(val*HARTREE_TO_EV-state.pyscf_E), val*HARTREE_TO_EV, state.pyscf_E)) + else: + str = "error=" + "".join(["%.7f "%(vals[i]*HARTREE_TO_EV-state.pyscf_E[i]) for i in range(min(5,opts.bs))]) + " [eV]" + #str += "E_xc=" + "".join(["%.7f "%(E_xc[i]*HARTREE_TO_EV) for i in range(opts.bs)]) + " [eV]" + pbar.set_description(str) + if i == 0: print("") + + if val < min_val: + min_val = val + min_dm = density_matrix + + val, density_matrix = min_val, min_dm + + # needs batching + exit() + V_xc = jax.grad(exchange_correlation)(density_matrix, state.grid_AO, state.grid_weights) + V_xc = (V_xc + V_xc.T)/2 + diff_JK = get_JK(density_matrix, state.ERI) + H = state.H_core + diff_JK + V_xc + mo_energy, mo_coeff = np.linalg.eigh(state.L_inv @ H @ state.L_inv.T) + mo_coeff = state.L_inv.T @ mo_coeff + + return val, (0, mo_energy, mo_coeff, state.grid_coords, state.grid_weights, density_matrix, H) + + +import chex +@chex.dataclass +class IterationState: + init: np.array + E_nuc: np.array + mask: np.array + L_inv: np.array + L_inv_T: np.array + H_core: np.array + grid_AO: np.array + grid_weights: np.array + grid_coords: np.array + pyscf_E: np.array + N: int + ERI: np.array + nonzero_distinct_ERI: np.array + nonzero_indices: np.array + +@chex.dataclass +class DiffState: + diffs_ERI: np.array + main_grid_AO: np.array + diffs_grid_AO: np.array + indxs: np.array + sparse_diffs_grid_AO: np.array#jax.experimental.sparse.csr.CSR + rows: np.array + cols: np.array + + +from pyscf.data.elements import charge as elements_proton +from pyscf.dft import gen_grid, radi + +def treutler_atomic_radii_adjust(mol, atomic_radii): + charges = [elements_proton(x) for x in mol.elements] + rad = np.sqrt(atomic_radii[charges]) + 1e-200 + rr = rad.reshape(-1, 1) * (1. / rad) + a = .25 * (rr.T - rr) + + a[a < -0.5] = -0.5 + a[a > 0.5] = 0.5 + a = jnp.array(a) + + def fadjust(i, j, g): + g1 = g**2 + g1 -= 1. + g1 *= -a[i, j] + g1 += g + return g1 + + return fadjust + + +def inter_distance(coords): + rr = np.linalg.norm(coords.reshape(-1, 1, 3) - coords, axis=2) + rr[np.diag_indices(rr.shape[0])] = 0 + return rr + +def original_becke(g): + g = (3 - g**2) * g * .5 + g = (3 - g**2) * g * .5 + g = (3 - g**2) * g * .5 + return g + +def get_partition( + mol, + atom_coords, + atom_grids_tab, + radii_adjust=treutler_atomic_radii_adjust, + atomic_radii=radi.BRAGG_RADII, + becke_scheme=original_becke, + concat=True +): + atm_dist = inter_distance(atom_coords) # [natom, natom] + + def gen_grid_partition(coords): + ngrids = coords.shape[0] + dc = coords[None] - atom_coords[:, None] + grid_dist = np.sqrt(np.einsum('ijk,ijk->ij', dc, dc)) # [natom, ngrid] + + ix, jx = np.tril_indices(mol.natm, k=-1) + + natm, ngrid = grid_dist.shape + g_ = -1 / atm_dist.reshape(natm, natm, 1) * (grid_dist.reshape(1, natm, ngrid) - grid_dist.reshape(natm, 1, ngrid)) + #g_ = jnp.array(g_) + + def pbecke_g(i, j): + g = g_[i, j] + charges = [elements_proton(x) for x in mol.elements] + rad = np.sqrt(atomic_radii[charges]) + 1e-200 + rr = rad.reshape(-1, 1) * (1. / rad) + a = .25 * (rr.T - rr) + a[a < -0.5] = -0.5 + a[a > 0.5] = 0.5 + g1 = g**2 + g1 -= 1. + g1 *= -a[i, j].reshape(-1, 1) + g1 += g + return g1 + + g = pbecke_g(ix, jx) + g = np.copy(becke_scheme(g)) + gp2 = (1+g)/2 + gm2 = (1-g)/2 + + pbecke = jnp.ones((mol.natm, ngrids)) # [natom, ngrid] + pbecke = pbecke.at[ix].mul(gm2) + pbecke = pbecke.at[jx].mul(gp2) + + return pbecke + + coords_all = [] + weights_all = [] + for ia in range(mol.natm): + coords, vol = atom_grids_tab[mol.atom_symbol(ia)] + coords = coords + atom_coords[ia] # [ngrid, 3] + pbecke = gen_grid_partition(coords) # [natom, ngrid] + weights = vol * pbecke[ia] / jnp.sum(pbecke, axis=0) + coords_all.append(coords) + weights_all.append(weights) + + if concat: + coords_all = jnp.vstack(coords_all) + weights_all = jnp.hstack(weights_all) + return coords_all, weights_all + + +class DifferentiableGrids(gen_grid.Grids): + """Differentiable alternative to the original pyscf.gen_grid.Grids.""" + + def build(self, atom_coords) : + mol = self.mol + + atom_grids_tab = self.gen_atomic_grids( + mol, self.atom_grid, self.radi_method, self.level, self.prune + ) + + coords, weights = get_partition( + mol, + atom_coords, + atom_grids_tab, + treutler_atomic_radii_adjust, + self.atomic_radii, + original_becke, + ) + self.coords = coords + self.weights = weights + return coords, weights + + +def grids_from_pyscf_mol( + mol: pyscf.gto.mole.Mole, quad_level: int = 1 +) : + g = gen_grid.Grids(mol) + g.level = quad_level + g.build() + grids = jnp.array(g.coords) + weights = jnp.array(g.weights) + return grids, weights + + +def init_dft(mol_str, opts, _coords=None, _weights=None): + mol = build_mol(mol_str, opts.basis) + pyscf_E, pyscf_hlgap, pycsf_forces = reference(mol_str, opts) + + N = mol.nao_nr() # N=66 for C6H6 (number of atomic **and** molecular orbitals) + n_electrons_half = mol.nelectron//2 # 21 for C6H6 + E_nuc = mol.energy_nuc() # float = 202.4065 [Hartree] for C6H6. TODO(): Port to jax. + + from pyscf import dft + #grids = pyscf.dft.gen_grid.Grids(mol) + grids = DifferentiableGrids(mol) + grids.level = opts.level + #grids.build() + grids.build(np.concatenate([np.array(a[1]).reshape(1, 3) for a in mol._atom])) + + grid_weights = grids.weights # (grid_size,) = (45624,) for C6H6 + grid_coords = grids.coords + coord_str = 'GTOval_cart_deriv1' if mol.cart else 'GTOval_sph_deriv1' + grid_AO = mol.eval_gto(coord_str, grids.coords, 4) # (4, grid_size, N) = (4, 45624, 9) for C6H6. + + # TODO(): Add integral math formulas for kinetic/nuclear/O/ERI. + kinetic = mol.intor_symmetric('int1e_kin') # (N,N) + nuclear = mol.intor_symmetric('int1e_nuc') # (N,N) + O = mol.intor_symmetric('int1e_ovlp') # (N,N) + L = np.linalg.cholesky(O) + L_inv = np.linalg.inv(L) # (N,N) + + + init = np.eye(N)[:, :n_electrons_half] + #I_nxk = init[:, :n_electrons_half] + + mask = np.concatenate([np.ones(n_electrons_half), np.zeros(N-n_electrons_half)]) + if opts.normal: + ERI = mol.intor("int2e_sph") + nonzero_distinct_ERI = np.zeros(1) + nonzero_indices = np.zeros(1) + else: + from pyscf_ipu.nanoDFT.sparse_symmetric_ERI import get_i_j, num_repetitions_fast + eri_threshold = 0 + batches = 1 + nipu = 1 + distinct_ERI = mol.intor("int2e_sph", aosym="s8") + #below_thr = np.abs(distinct_ERI) <= eri_threshold + #distinct_ERI[below_thr] = 0.0 + #ic(distinct_ERI.size, np.sum(below_thr), np.sum(below_thr)/distinct_ERI.size) + #nonzero_indices = np.nonzero(distinct_ERI)[0].astype(np.uint64) + nonzero_indices = np.arange(distinct_ERI.size)# ]np.nonzero(distinct_ERI)[0].astype(np.uint64) + nonzero_distinct_ERI = distinct_ERI[nonzero_indices]#.astype(np.float32) + + ij, kl = get_i_j(nonzero_indices) + rep = num_repetitions_fast(ij, kl) + nonzero_distinct_ERI = nonzero_distinct_ERI / rep + remainder = nonzero_indices.shape[0] % (nipu*batches) + + if remainder != 0: + ij = np.pad(ij, ((0,nipu*batches-remainder))) + kl = np.pad(kl, ((0,nipu*batches-remainder))) + nonzero_distinct_ERI = np.pad(nonzero_distinct_ERI, (0,nipu*batches-remainder)) + + ij = ij.reshape(batches, -1) + kl = kl.reshape(batches, -1) + nonzero_distinct_ERI = nonzero_distinct_ERI.reshape(batches, -1) + + i, j = get_i_j(ij.reshape(-1)) + k, l = get_i_j(kl.reshape(-1)) + nonzero_indices = np.vstack([i,j,k,l]).T.reshape(batches, -1, 4) + + #ERI = [nonzero_distinct_ERI, nonzero_indices] + #ERI = ERI + ERI = np.zeros(1) + #ERI = mol.intor("int2e_sph") + + def e(x): return np.expand_dims(x, axis=0) + + + state = IterationState(init = e(init), + E_nuc=e(E_nuc), + ERI=e(ERI), + nonzero_distinct_ERI=e(nonzero_distinct_ERI), + nonzero_indices=e(nonzero_indices), + mask=e(mask), + H_core=e(nuclear+kinetic), + L_inv=e(L_inv), + L_inv_T = e(L_inv.T), + grid_AO=e(grid_AO), + grid_weights=e(grid_weights), + grid_coords=e(grid_coords), + pyscf_E=e(pyscf_E[-1:]), + N=e(mol.nao_nr()), + ) + + + return state + + +def summary(state): + if state is None: return + print("_"*100) + for field_name, field_def in state.__dataclass_fields__.items(): + field_value = getattr(state, field_name) + try: + print("%20s %20s %20s"%(field_name,getattr(field_value, 'shape', None), getattr(field_value, "nbytes", None)/10**9)) + except: + print("BROKE FOR ", field_name) + try: + print(state.pyscf_E[:, -1]) + except: + pass + print("_"*100) + +def cat(dc1, dc2, axis=0): + # Use dictionary comprehension to iterate over the dataclass fields + concatenated_fields = { + field: jnp.concatenate([getattr(dc1, field), getattr(dc2, field)], axis=axis) + for field in dc1.__annotations__ + } + # Create a new dataclass instance with the concatenated fields + return IterationState(**concatenated_fields) + + +def grad_elec(weight, grid_AO, eri, s1, h1aos, natm, aoslices, mask, mo_energy, mo_coeff, mol, dm, H): + # Electronic part of RHF/RKS gradients + dm0 = 2 * (mo_coeff*mask) @ mo_coeff.T # (N, N) = (66, 66) for C6H6. + dme0 = 2 * (mo_coeff * mask*mo_energy) @ mo_coeff.T # (N, N) = (66, 66) for C6H6. + + # Code identical to exchange correlation. + rho = jnp.sum( grid_AO[:1] @ dm0 * grid_AO, axis=2) # (10, grid_size) = (10, 45624) for C6H6. + _, vrho, vgamma = vxc_b3lyp(rho, EPSILON_B3LYP) # (grid_size,) (grid_size,) + V_xc = jnp.concatenate([vrho.reshape(1, -1)/2, 4*vgamma.reshape(1, -1)*rho[1:4]], axis=0) # (4, grid_size) + + vmat = grid_AO[1:4].transpose(0, 2, 1) @ jnp.sum(grid_AO[:4] * jnp.expand_dims(weight * V_xc, axis=2), axis=0) # (3, N, N) + aos = jnp.concatenate([jnp.expand_dims(grid_AO[np.array([1,4,5,6])], 0), jnp.expand_dims(grid_AO[np.array([2,5,7,8])], 0), jnp.expand_dims(grid_AO[np.array([3,6,8,9])], 0)], axis=0) # (3, N, N) + V_xc = - vmat - jnp.transpose(jnp.einsum("snpi,np->spi", aos, weight*V_xc), axes=(0,2,1)) @ grid_AO[0] # (3, 4, grid_size, N) + + vj = - jnp.einsum('sijkl,lk->sij', eri, dm0) # (3, N, N) + vk = - jnp.einsum('sijkl,jk->sil', eri, dm0) # (3, N, N) + vhf = V_xc + vj - vk * .5 * HYB_B3LYP # (3, N, N) + + de = jnp.einsum('lxij,ij->lx', h1aos, dm0) # (natm, 3) + for k, ia in enumerate(range(natm)): + p0, p1 = aoslices[ia][2], aoslices[ia][3] + de = de.at[k].add(jnp.einsum('xij,ij->x', vhf[:, p0:p1], dm0[p0:p1]) * 2) + de = de.at[k].add(-jnp.einsum('xij,ij->x', s1[:, p0:p1], dme0[p0:p1]) * 2) + return de + +def grad_nuc(charges, coords): + # Derivatives of nuclear repulsion energy wrt nuclear coordinates + natm = charges.shape[0] + pairwise_charges = charges.reshape(natm, 1) * charges.reshape(1, natm) # (natm, natm) + pairwise_difference = coords.reshape(1, natm, 3) - coords.reshape(natm, 1, 3) # (natm, natm, 3) + pairwise_distances = jnp.linalg.norm(pairwise_difference, axis=2) ** 3 # (natm, natm) + pairwise_distances = jnp.where(pairwise_distances == 0, jnp.inf, pairwise_distances) # (natm, natm) + all = - pairwise_charges.reshape(natm, natm, 1) * pairwise_difference # (natm, natm, 3) + all = all / pairwise_distances.reshape(natm, natm, 1) # (natm, natm, 3) + all = all.at[jnp.diag_indices(natm)].set(0) # (natm, natm, 3) + return jnp.sum(all, axis=0) # (natm, natm) + +def grad(mol, coords, weight, mo_coeff, mo_energy, dm, H): + # Initialize DFT tensors on CPU using PySCF. + ao = pyscf.dft.numint.NumInt().eval_ao(mol, coords, deriv=2) + eri = mol.intor("int2e_ip1") + s1 = - mol.intor('int1e_ipovlp', comp=3) + kin = - mol.intor('int1e_ipkin', comp=3) + nuc = - mol.intor('int1e_ipnuc', comp=3) + + mask = np.ones(mol.nao_nr()) + mask[mol.nelectron//2:] = 0 + + aoslices = mol.aoslice_by_atom() + h1 = kin + nuc + def hcore_deriv(atm_id, aoslices, h1): # <\nabla|1/r|> + _, _, p0, p1 = aoslices[atm_id] + with mol.with_rinv_at_nucleus(atm_id): + vrinv = mol.intor('int1e_iprinv', comp=3) # + vrinv *= -mol.atom_charge(atm_id) + vrinv[:,p0:p1] += h1[:,p0:p1] + return vrinv + vrinv.transpose(0,2,1) + N = h1.shape[1] # (3, N , N) + h1aos = np.zeros((mol.natm, 3, N, N)) + for k, ia in enumerate(range(mol.natm)): + p0, p1 = aoslices[ia,2:] + h1aos[k] = hcore_deriv(ia, aoslices, h1) + + charges = np.zeros((mol.natm)) + coords = np.zeros((mol.natm,3)) + for j in range(mol.natm): + charges[j] = mol.atom_charge(j) + coords[j]= mol.atom_coord(j) + + #_grad_elec = jax.jit(grad_elec, static_argnames=["aoslices", "natm"], backend="cpu") + _grad_elec = grad_elec + _grad_nuc = jax.jit(grad_nuc, backend="cpu") + + return _grad_elec(weight, ao, eri, s1, h1aos, mol.natm, tuple([tuple(a) for a in aoslices.tolist()]), mask, mo_energy, mo_coeff, mol, dm, H) + _grad_nuc(charges, coords) + +def pyscf_reference(mol_str, opts): + from pyscf import __config__ + __config__.dft_rks_RKS_grids_level = opts.level + + mol = build_mol(mol_str, opts.basis) + mol.max_cycle = 50 + mf = pyscf.scf.RKS(mol) + mf.max_cycle = 50 + mf.xc = "b3lyp" + mf.diis_space = 8 + pyscf_energies = [] + pyscf_hlgaps = [] + lumo = mol.nelectron//2 + homo = lumo - 1 + def callback(envs): + pyscf_energies.append(envs["e_tot"]*HARTREE_TO_EV) + hl_gap_hartree = np.abs(envs["mo_energy"][homo] - envs["mo_energy"][lumo]) * HARTREE_TO_EV + pyscf_hlgaps.append(hl_gap_hartree) + print("\rPYSCF: ", pyscf_energies[-1] , end="") + mf.callback = callback + mf.kernel() + print("") + forces = mf.nuc_grad_method().kernel() + return np.array(pyscf_energies), np.array(pyscf_hlgaps), np.array(forces) + +def print_difference(nanoDFT_E, nanoDFT_forces, nanoDFT_logged_E, nanoDFT_hlgap, pyscf_E, pyscf_forces, pyscf_hlgap): + #TODO(HH): rename to match caller variable names + nanoDFT_E = nanoDFT_E*HARTREE_TO_EV + print("pyscf:\t\t%15f"%pyscf_E[-1]) + print("us:\t\t%15f"%nanoDFT_E) + print("diff:\t\t%15f"%np.abs(pyscf_E[-1]-nanoDFT_E)) + print("chemAcc: \t%15f"%0.043) + print("chemAcc/diff: \t%15f"%(0.043/np.abs(pyscf_E[-1]-nanoDFT_E))) + print("") + + # Forces + print() + print("np.max(|nanoDFT_F-PySCF_F|):", np.max(np.abs(nanoDFT_forces-pyscf_forces))) + + norm_X = np.linalg.norm(nanoDFT_forces, axis=1) + norm_Y = np.linalg.norm(pyscf_forces, axis=1) + dot_products = np.sum(nanoDFT_forces * pyscf_forces, axis=1) + cosine_similarity = dot_products / (norm_X * norm_Y) + print("Force cosine similarity:",cosine_similarity) + +def build_mol(mol_str, basis_name): + mol = pyscf.gto.mole.Mole() + mol.build(atom=mol_str, unit="Angstrom", basis=basis_name, spin=0, verbose=0) + return mol + +def reference(mol_str, opts): + import pickle + import hashlib + filename = "precomputed/%s.pkl"%hashlib.sha256((str(mol_str) + str(opts.basis) + str(opts.level)).encode('utf-8')).hexdigest() + print(filename) + if not os.path.exists(filename): + pyscf_E, pyscf_hlgap, pyscf_forces = pyscf_reference(mol_str, opts) + with open(filename, "wb") as file: + pickle.dump([pyscf_E, pyscf_hlgap, pyscf_forces], file) + else: + pyscf_E, pyscf_hlgap, pyscf_forces = pickle.load(open(filename, "rb")) + return pyscf_E, pyscf_hlgap, pyscf_forces + + +if __name__ == "__main__": + #jax.config.FLAGS.jax_platform_name = 'cpu' + import os + import argparse + + parser = argparse.ArgumentParser() + # DFT options + parser.add_argument('-basis', type=str, default="sto3g") + parser.add_argument('-level', type=int, default=0) + # GD options + parser.add_argument('-backend', type=str, default="cpu") + parser.add_argument('-lr', type=float, default=1e-3) + parser.add_argument('-steps', type=int, default=200) + parser.add_argument('-bs', type=int, default=2) + + parser.add_argument('-normal', action="store_true") + parser.add_argument('-visualize', action="store_true") + opts = parser.parse_args() + + # benzene + if True: + mol_str = [ + ["C", ( 0.0000, 0.0000, 0.0000)], + ["C", ( 1.4000, 0.0000, 0.0000)], + ["C", ( 2.1000, 1.2124, 0.0000)], + ["C", ( 1.4000, 2.4249, 0.0000)], + ["C", ( 0.0000, 2.4249, 0.0000)], + ["C", (-0.7000, 1.2124, 0.0000)], + ["H", (-0.5500, -0.9526, 0.0000)], + ["H", (-0.5500, 3.3775, 0.0000)], + ["H", ( 1.9500, -0.9526, 0.0000)], + ["H", (-1.8000, 1.2124, 0.0000)], + ["H", ( 3.2000, 1.2124, 0.0000)], + ["H", ( 1.9500, 3.3775, 0.0000)] + ] + else: + mol_str = [ + ["N", (-1.3289 , 1.0488 , -1.5596)], + ["C", ( 0.1286 , 1.0198 , -1.8261)], + ["C", ( 0.3335 , 0.8585 , -3.3268)], + ["O", (-0.0551 , -0.0282 , -4.0649)], + ["O", ( 1.0668 , 1.8338 , -3.9108)], + ["C", ( 0.8906 , -0.1043 , -1.0999)], + ["H", ( 1.9534 , -0.0888 , -1.4126)], + ["H", ( 0.4975 , -1.0987 , -1.3971)], + ["C", ( 0.8078 , 0.0465 , 0.3677)], + ["C", ( 1.5802 , 0.8809 , 1.1516)], + ["N", ( 1.1567 , 0.7746 , 2.4944)], + ["H", ( 1.7094 , 1.0499 , 3.2650)], + ["C", ( 0.1694 , -0.2350 , 2.5662)], + ["C", (-0.0897 , -0.6721 , 1.2403)], + ["C", (-1.0740 , -1.6418 , 1.0106)], + ["H", (-1.2812 , -1.9849 , -0.0088)], + ["C", (-1.7623 , -2.1470 , 2.0948)], + ["H", (-2.5346 , -2.9080 , 1.9416)], + ["C", (-1.4948 , -1.7069 , 3.4060)], + ["H", (-2.0660 , -2.1385 , 4.2348)], + ["C", (-0.5337 , -0.7507 , 3.6638)], + ["H", (-0.3249 , -0.4086 , 4.6819)], + ["H", ( 2.3719 , 1.5631 , 0.8380)], + ["H", (-1.4726 , 1.2086 , -0.5841)], + ["H", (-1.7404 , 0.1740 , -1.8129)], + ["H", ( 0.5299 , 2.0096 , -1.4901)], + ["H", ( 1.1361 , 1.6737 , -4.8470)], + ] + + + #pos = [np.array(a[1]).reshape(1, 1) for a in mol_str] + #distances = map(lambda x: np.linalg.norm(np.array(x[0]) - np.array(x[1])), combinations(coords, 2)) + #return min(distances) + + + + + mol = build_mol(mol_str, opts.basis) + ic(mol.nao_nr()) + ic(mol.nelectron) + + pyscf_E, pyscf_hlgap, pyscf_forces = reference(mol_str, opts) + + nanoDFT_E, (nanoDFT_hlgap, mo_energy, mo_coeff, grid_coords, grid_weights, dm, H) = nanoDFT(mol_str, opts, pyscf_E) + nanoDFT_forces = grad(mol, grid_coords, grid_weights, mo_coeff, mo_energy, np.array(dm), np.array(H)) + + print_difference(nanoDFT_E, nanoDFT_forces, 0 , nanoDFT_hlgap, pyscf_E, pyscf_forces, pyscf_hlgap) From 9b203d2ecf41755d291ddd1b12965973be68ed58 Mon Sep 17 00:00:00 2001 From: AlexanderMath Date: Wed, 15 Nov 2023 14:45:14 +0000 Subject: [PATCH 03/22] Update gd.py draft of annotation/simplifications with awf --- pyscf_ipu/nanoDFT/gd.py | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/pyscf_ipu/nanoDFT/gd.py b/pyscf_ipu/nanoDFT/gd.py index c21386ed..e906af72 100644 --- a/pyscf_ipu/nanoDFT/gd.py +++ b/pyscf_ipu/nanoDFT/gd.py @@ -11,27 +11,27 @@ from tqdm import tqdm HARTREE_TO_EV, EPSILON_B3LYP, HYB_B3LYP = 27.2114079527, 1e-20, 0.2 - -def dm_energy(density_matrix, state): - #eigvects = L_inv.T @ jnp.linalg.qr( L_inv @ density_matrix @ L_inv.T)[0] - eigvects = state.L_inv.T @ jnp.linalg.eigh(state.L_inv @ density_matrix @ state.L_inv.T)[1] - density_matrix = 2 * (eigvects * state.mask) @ eigvects.T # 'state.mask' does eigvects[:, :n_electrons//2] jit-ably. - E_xc = exchange_correlation(density_matrix, state.grid_AO, state.grid_weights) - diff_JK = get_JK(density_matrix, state.ERI) - return jnp.sum(density_matrix * (state.H_core + diff_JK/2)) + E_xc + state.E_nuc, density_matrix - -def exchange_correlation(density_matrix, grid_AO, grid_weights): - grid_AO_dm = grid_AO[0] @ density_matrix - grid_AO_dm = jnp.expand_dims(grid_AO_dm, axis=0) - mult = grid_AO_dm * grid_AO - rho = jnp.sum(mult, axis=2) - E_xc = b3lyp(rho, EPSILON_B3LYP) - E_xc = jnp.sum(rho[0] * grid_weights * E_xc) +def orth(x): return jnp.linalg.qr(x)[0] + +def dm_energy(weights: NxK, state): + eigvects: NxK = state.L_inv.T @ orth(weights) + density_matrix: NxN = 2 * eigvects @ eigvects.T + E_xc: float = exchange_correlation(density_matrix, state.grid_AO, state.grid_weights) + diff_JK: NxN = get_JK(density_matrix, state.ERI) + energy: float = jnp.sum(density_matrix * (state.H_core + diff_JK/2)) + E_xc + state.E_nuc + return energy, density_matrix + +def exchange_correlation(density_matrix: NxN, grid_AO: _4xGsizexN, grid_weights: gsize): + grid_AO_dm: _1xGsizexN = jnp.expand_dims(grid_AO[0] @ density_matrix) # O(gsize N^2) flops and gsizeN reads. + mult: _4xGsizexN = grid_AO_dm * grid_AO + rho: _4xGsize = jnp.sum(mult, axis=2) + E_xc: Gsize = b3lyp(rho, EPSILON_B3LYP) + E_xc: float = jnp.sum(rho[0] * grid_weights * E_xc) return E_xc -def get_JK(density_matrix, ERI): - J = jnp.einsum('ijkl,ji->kl', ERI, density_matrix) - K = jnp.einsum('ijkl,jk->il', ERI, density_matrix) +def get_JK(density_matrix: NxN, ERI: NxNxNxN): + J: (N, N) = jnp.einsum('ijkl,ji->kl', ERI, density_matrix) + K: (N, N) = jnp.einsum('ijkl,jk->il', ERI, density_matrix) return J - (K / 2 * HYB_B3LYP) def nanoDFT(mol_str, opts, pyscf_E): @@ -312,4 +312,4 @@ def reference(mol_str, opts): nanoDFT_E, (nanoDFT_hlgap, mo_energy, mo_coeff, grid_coords, grid_weights, dm, H) = nanoDFT(mol_str, opts, pyscf_E) nanoDFT_forces = grad(mol, grid_coords, grid_weights, mo_coeff, mo_energy, np.array(dm), np.array(H)) - print_difference(nanoDFT_E, nanoDFT_forces, 0 , nanoDFT_hlgap, pyscf_E, pyscf_forces, pyscf_hlgap) \ No newline at end of file + print_difference(nanoDFT_E, nanoDFT_forces, 0 , nanoDFT_hlgap, pyscf_E, pyscf_forces, pyscf_hlgap) From f9f772d9bd261ade59033fe2426d9af3aad493ed Mon Sep 17 00:00:00 2001 From: Alex Mathiasen Date: Sun, 3 Dec 2023 16:10:50 +0000 Subject: [PATCH 04/22] train transformer with dft loss. --- pyscf_ipu/direct/alchemy/download.sh | 12 + pyscf_ipu/direct/alchemy/reproduce_pyscf.py | 18 + pyscf_ipu/direct/alchemy/to_pickle.py | 81 ++ pyscf_ipu/direct/exchange_correlation/LICENSE | 373 ++++++++ .../direct/exchange_correlation/b3lyp.py | 138 +++ pyscf_ipu/direct/sparse_symmetric_ERI.py | 127 +++ pyscf_ipu/direct/train.py | 896 ++++++++++++++++++ pyscf_ipu/direct/transformer.py | 193 ++++ 8 files changed, 1838 insertions(+) create mode 100755 pyscf_ipu/direct/alchemy/download.sh create mode 100644 pyscf_ipu/direct/alchemy/reproduce_pyscf.py create mode 100644 pyscf_ipu/direct/alchemy/to_pickle.py create mode 100644 pyscf_ipu/direct/exchange_correlation/LICENSE create mode 100644 pyscf_ipu/direct/exchange_correlation/b3lyp.py create mode 100644 pyscf_ipu/direct/sparse_symmetric_ERI.py create mode 100644 pyscf_ipu/direct/train.py create mode 100644 pyscf_ipu/direct/transformer.py diff --git a/pyscf_ipu/direct/alchemy/download.sh b/pyscf_ipu/direct/alchemy/download.sh new file mode 100755 index 00000000..2d9d4913 --- /dev/null +++ b/pyscf_ipu/direct/alchemy/download.sh @@ -0,0 +1,12 @@ +wget -O dev.zip https://alchemy.tencent.com/data/dev_v20190730.zip +wget -O valid.zip https://alchemy.tencent.com/data/valid_v20190730.zip +wget -O test.zip https://alchemy.tencent.com/data/test_v20190730.zip + +unzip dev.zip +unzip valid.zip +unzip test.zip + +wget -O alchemy.zip https://alchemy.tencent.com/data/alchemy-v20191129.zip +unzip alchemy.zip +mv Alchemy-v20191129/* . +rmdir Alchemy-v20191129 \ No newline at end of file diff --git a/pyscf_ipu/direct/alchemy/reproduce_pyscf.py b/pyscf_ipu/direct/alchemy/reproduce_pyscf.py new file mode 100644 index 00000000..262d1a37 --- /dev/null +++ b/pyscf_ipu/direct/alchemy/reproduce_pyscf.py @@ -0,0 +1,18 @@ +import pandas as pd +import pyscf +from pyscf import __config__ +__config__.dft_rks_RKS_grids_level = 3 +from pyscf import dft +import numpy as np + +df = pd.read_pickle("atom_9.pickle") +mol = pyscf.gto.Mole(atom=df["pyscf"].values[0], basis="6-31G(2df,p)", spin=0) +mol.build() +mf = pyscf.dft.RKS(mol) +mf.verbose = 4 +mf.xc = 'B3LYP5' # pyscf changed b3lyp from vwn5 to vwn3 to agree with gaussian. +print(mf.kernel()) +print(df["energy"].values[0]) +print(df["homo"].values[0]) +print(df["lumo"].values[0]) +print(df["gap"].values[0]) \ No newline at end of file diff --git a/pyscf_ipu/direct/alchemy/to_pickle.py b/pyscf_ipu/direct/alchemy/to_pickle.py new file mode 100644 index 00000000..9735f299 --- /dev/null +++ b/pyscf_ipu/direct/alchemy/to_pickle.py @@ -0,0 +1,81 @@ +import pandas as pd +import os +from rdkit import Chem +import numpy as np +import pyscf +from natsort import natsorted +from tqdm import tqdm + +# we test loading by reproducing labels with pyscf. +# (instead of checking e.g. np_to_sdf) + +def sdf_to_np(filename): + s = open('%s'%filename, 'r').read() + lines = s.split('V2000')[1].split('\n')[1:-1] + lines = [[a for a in line.split(' ') if a != '' ][:4] for line in lines if len(line)>35] + lines = [[line[3], (float(line[0]), float(line[1]), float(line[2]))] for line in lines ] + atom_str = [line[0] for line in lines] + atom_pos = np.concatenate([np.array(line[1:]).reshape(1, -1) for line in lines ] ) + return atom_str, atom_pos + +def np_to_pyscf(str, xyz): + atom_list = [] + + for i, atom_type in enumerate(str): + x, y, z = xyz[i] + atom_list.append([atom_type, (x, y, z)]) + + return atom_list + +def spin(pyscf_format): + try: + mol = pyscf.gto.Mole(atom=pyscf_format, basis="6-31g(2df,p)") + mol.build() + return 0 + except: + mol = pyscf.gto.Mole(atom=pyscf_format, basis="6-31g(2df,p)", spin=1) + mol.build() + return 1 + +def nao(pyscf_format, spin): + mol = pyscf.gto.Mole(atom=pyscf_format, basis="6-31g(2df,p)", spin=spin) + mol.build() + return mol.nao_nr() + +# load all labels in the final 200k version +df = pd.read_csv("final_version.csv") + +# add info on train/test/val split +train = pd.read_csv("dev/dev_target.csv") +valid = pd.read_csv("valid/valid_target.csv") +test = pd.DataFrame({"gdb_idx": os.listdir("test/sdf/atom_11") + os.listdir("test/sdf/atom_12")}) +df["train"] = df["gdb_idx"].isin(train["gdb_idx"]) +df["test"] = df["gdb_idx"].isin(test["gdb_idx"]) +df["valid"] = df["gdb_idx"].isin(valid["gdb_idx"]) + +# alchemy computes u0 = results['E_0K' ] = (E0 + ZPE, 'Eh'), so need to subtract zpve +# https://github.com/tencent-alchemy/alchemy-pyscf/blob/fa4f7ff46be308302ba1e95754701142b6c4bf7f/alchemypyscf/thermo.py#L215 +df["energy"] = df["U0\n(Ha, internal energy at 0 K)"] - df["zpve\n(Ha, zero point vibrational energy)"] + +for folder in ["atom_9", "atom_10", "atom_11", "atom_12"]: + files = natsorted(os.listdir(folder)) + + strs, xyzs, pyscfs, gdb_idxs, naos, spins = [], [], [], [], [], [] + + for f in tqdm(files): + try: + str, xyz = sdf_to_np("%s/%s"%(folder, f)) + strs.append(str) + xyzs.append(xyz) + pyscfs.append( np_to_pyscf(str, xyz) ) + gdb_idxs.append(int(f.replace(".sdf", ""))) + spins.append(spin(pyscfs[-1])) + naos.append(-1)#nao(pyscfs[-1], spins[-1])) + except: + print("broke %s"%f) + + df2 = pd.DataFrame({"gdb_idx": gdb_idxs, "pyscf": pyscfs, "str": strs, "xyz": xyzs, "nao": naos, "spin": spins}) + merged = pd.merge(df, df2, on="gdb_idx", how="inner") + + merged.to_pickle("%s.pickle"%folder) + break \ No newline at end of file diff --git a/pyscf_ipu/direct/exchange_correlation/LICENSE b/pyscf_ipu/direct/exchange_correlation/LICENSE new file mode 100644 index 00000000..fa0086a9 --- /dev/null +++ b/pyscf_ipu/direct/exchange_correlation/LICENSE @@ -0,0 +1,373 @@ +Mozilla Public License Version 2.0 +================================== + +1. Definitions +-------------- + +1.1. "Contributor" + means each individual or legal entity that creates, contributes to + the creation of, or owns Covered Software. + +1.2. "Contributor Version" + means the combination of the Contributions of others (if any) used + by a Contributor and that particular Contributor's Contribution. + +1.3. "Contribution" + means Covered Software of a particular Contributor. + +1.4. "Covered Software" + means Source Code Form to which the initial Contributor has attached + the notice in Exhibit A, the Executable Form of such Source Code + Form, and Modifications of such Source Code Form, in each case + including portions thereof. + +1.5. "Incompatible With Secondary Licenses" + means + + (a) that the initial Contributor has attached the notice described + in Exhibit B to the Covered Software; or + + (b) that the Covered Software was made available under the terms of + version 1.1 or earlier of the License, but not also under the + terms of a Secondary License. + +1.6. "Executable Form" + means any form of the work other than Source Code Form. + +1.7. "Larger Work" + means a work that combines Covered Software with other material, in + a separate file or files, that is not Covered Software. + +1.8. "License" + means this document. + +1.9. "Licensable" + means having the right to grant, to the maximum extent possible, + whether at the time of the initial grant or subsequently, any and + all of the rights conveyed by this License. + +1.10. "Modifications" + means any of the following: + + (a) any file in Source Code Form that results from an addition to, + deletion from, or modification of the contents of Covered + Software; or + + (b) any new file in Source Code Form that contains any Covered + Software. + +1.11. "Patent Claims" of a Contributor + means any patent claim(s), including without limitation, method, + process, and apparatus claims, in any patent Licensable by such + Contributor that would be infringed, but for the grant of the + License, by the making, using, selling, offering for sale, having + made, import, or transfer of either its Contributions or its + Contributor Version. + +1.12. "Secondary License" + means either the GNU General Public License, Version 2.0, the GNU + Lesser General Public License, Version 2.1, the GNU Affero General + Public License, Version 3.0, or any later versions of those + licenses. + +1.13. "Source Code Form" + means the form of the work preferred for making modifications. + +1.14. "You" (or "Your") + means an individual or a legal entity exercising rights under this + License. For legal entities, "You" includes any entity that + controls, is controlled by, or is under common control with You. For + purposes of this definition, "control" means (a) the power, direct + or indirect, to cause the direction or management of such entity, + whether by contract or otherwise, or (b) ownership of more than + fifty percent (50%) of the outstanding shares or beneficial + ownership of such entity. + +2. License Grants and Conditions +-------------------------------- + +2.1. Grants + +Each Contributor hereby grants You a world-wide, royalty-free, +non-exclusive license: + +(a) under intellectual property rights (other than patent or trademark) + Licensable by such Contributor to use, reproduce, make available, + modify, display, perform, distribute, and otherwise exploit its + Contributions, either on an unmodified basis, with Modifications, or + as part of a Larger Work; and + +(b) under Patent Claims of such Contributor to make, use, sell, offer + for sale, have made, import, and otherwise transfer either its + Contributions or its Contributor Version. + +2.2. Effective Date + +The licenses granted in Section 2.1 with respect to any Contribution +become effective for each Contribution on the date the Contributor first +distributes such Contribution. + +2.3. Limitations on Grant Scope + +The licenses granted in this Section 2 are the only rights granted under +this License. No additional rights or licenses will be implied from the +distribution or licensing of Covered Software under this License. +Notwithstanding Section 2.1(b) above, no patent license is granted by a +Contributor: + +(a) for any code that a Contributor has removed from Covered Software; + or + +(b) for infringements caused by: (i) Your and any other third party's + modifications of Covered Software, or (ii) the combination of its + Contributions with other software (except as part of its Contributor + Version); or + +(c) under Patent Claims infringed by Covered Software in the absence of + its Contributions. + +This License does not grant any rights in the trademarks, service marks, +or logos of any Contributor (except as may be necessary to comply with +the notice requirements in Section 3.4). + +2.4. Subsequent Licenses + +No Contributor makes additional grants as a result of Your choice to +distribute the Covered Software under a subsequent version of this +License (see Section 10.2) or under the terms of a Secondary License (if +permitted under the terms of Section 3.3). + +2.5. Representation + +Each Contributor represents that the Contributor believes its +Contributions are its original creation(s) or it has sufficient rights +to grant the rights to its Contributions conveyed by this License. + +2.6. Fair Use + +This License is not intended to limit any rights You have under +applicable copyright doctrines of fair use, fair dealing, or other +equivalents. + +2.7. Conditions + +Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted +in Section 2.1. + +3. Responsibilities +------------------- + +3.1. Distribution of Source Form + +All distribution of Covered Software in Source Code Form, including any +Modifications that You create or to which You contribute, must be under +the terms of this License. You must inform recipients that the Source +Code Form of the Covered Software is governed by the terms of this +License, and how they can obtain a copy of this License. You may not +attempt to alter or restrict the recipients' rights in the Source Code +Form. + +3.2. Distribution of Executable Form + +If You distribute Covered Software in Executable Form then: + +(a) such Covered Software must also be made available in Source Code + Form, as described in Section 3.1, and You must inform recipients of + the Executable Form how they can obtain a copy of such Source Code + Form by reasonable means in a timely manner, at a charge no more + than the cost of distribution to the recipient; and + +(b) You may distribute such Executable Form under the terms of this + License, or sublicense it under different terms, provided that the + license for the Executable Form does not attempt to limit or alter + the recipients' rights in the Source Code Form under this License. + +3.3. Distribution of a Larger Work + +You may create and distribute a Larger Work under terms of Your choice, +provided that You also comply with the requirements of this License for +the Covered Software. If the Larger Work is a combination of Covered +Software with a work governed by one or more Secondary Licenses, and the +Covered Software is not Incompatible With Secondary Licenses, this +License permits You to additionally distribute such Covered Software +under the terms of such Secondary License(s), so that the recipient of +the Larger Work may, at their option, further distribute the Covered +Software under the terms of either this License or such Secondary +License(s). + +3.4. Notices + +You may not remove or alter the substance of any license notices +(including copyright notices, patent notices, disclaimers of warranty, +or limitations of liability) contained within the Source Code Form of +the Covered Software, except that You may alter any license notices to +the extent required to remedy known factual inaccuracies. + +3.5. Application of Additional Terms + +You may choose to offer, and to charge a fee for, warranty, support, +indemnity or liability obligations to one or more recipients of Covered +Software. However, You may do so only on Your own behalf, and not on +behalf of any Contributor. You must make it absolutely clear that any +such warranty, support, indemnity, or liability obligation is offered by +You alone, and You hereby agree to indemnify every Contributor for any +liability incurred by such Contributor as a result of warranty, support, +indemnity or liability terms You offer. You may include additional +disclaimers of warranty and limitations of liability specific to any +jurisdiction. + +4. Inability to Comply Due to Statute or Regulation +--------------------------------------------------- + +If it is impossible for You to comply with any of the terms of this +License with respect to some or all of the Covered Software due to +statute, judicial order, or regulation then You must: (a) comply with +the terms of this License to the maximum extent possible; and (b) +describe the limitations and the code they affect. Such description must +be placed in a text file included with all distributions of the Covered +Software under this License. Except to the extent prohibited by statute +or regulation, such description must be sufficiently detailed for a +recipient of ordinary skill to be able to understand it. + +5. Termination +-------------- + +5.1. The rights granted under this License will terminate automatically +if You fail to comply with any of its terms. However, if You become +compliant, then the rights granted under this License from a particular +Contributor are reinstated (a) provisionally, unless and until such +Contributor explicitly and finally terminates Your grants, and (b) on an +ongoing basis, if such Contributor fails to notify You of the +non-compliance by some reasonable means prior to 60 days after You have +come back into compliance. Moreover, Your grants from a particular +Contributor are reinstated on an ongoing basis if such Contributor +notifies You of the non-compliance by some reasonable means, this is the +first time You have received notice of non-compliance with this License +from such Contributor, and You become compliant prior to 30 days after +Your receipt of the notice. + +5.2. If You initiate litigation against any entity by asserting a patent +infringement claim (excluding declaratory judgment actions, +counter-claims, and cross-claims) alleging that a Contributor Version +directly or indirectly infringes any patent, then the rights granted to +You by any and all Contributors for the Covered Software under Section +2.1 of this License shall terminate. + +5.3. In the event of termination under Sections 5.1 or 5.2 above, all +end user license agreements (excluding distributors and resellers) which +have been validly granted by You or Your distributors under this License +prior to termination shall survive termination. + +************************************************************************ +* * +* 6. Disclaimer of Warranty * +* ------------------------- * +* * +* Covered Software is provided under this License on an "as is" * +* basis, without warranty of any kind, either expressed, implied, or * +* statutory, including, without limitation, warranties that the * +* Covered Software is free of defects, merchantable, fit for a * +* particular purpose or non-infringing. The entire risk as to the * +* quality and performance of the Covered Software is with You. * +* Should any Covered Software prove defective in any respect, You * +* (not any Contributor) assume the cost of any necessary servicing, * +* repair, or correction. This disclaimer of warranty constitutes an * +* essential part of this License. No use of any Covered Software is * +* authorized under this License except under this disclaimer. * +* * +************************************************************************ + +************************************************************************ +* * +* 7. Limitation of Liability * +* -------------------------- * +* * +* Under no circumstances and under no legal theory, whether tort * +* (including negligence), contract, or otherwise, shall any * +* Contributor, or anyone who distributes Covered Software as * +* permitted above, be liable to You for any direct, indirect, * +* special, incidental, or consequential damages of any character * +* including, without limitation, damages for lost profits, loss of * +* goodwill, work stoppage, computer failure or malfunction, or any * +* and all other commercial damages or losses, even if such party * +* shall have been informed of the possibility of such damages. This * +* limitation of liability shall not apply to liability for death or * +* personal injury resulting from such party's negligence to the * +* extent applicable law prohibits such limitation. Some * +* jurisdictions do not allow the exclusion or limitation of * +* incidental or consequential damages, so this exclusion and * +* limitation may not apply to You. * +* * +************************************************************************ + +8. Litigation +------------- + +Any litigation relating to this License may be brought only in the +courts of a jurisdiction where the defendant maintains its principal +place of business and such litigation shall be governed by laws of that +jurisdiction, without reference to its conflict-of-law provisions. +Nothing in this Section shall prevent a party's ability to bring +cross-claims or counter-claims. + +9. Miscellaneous +---------------- + +This License represents the complete agreement concerning the subject +matter hereof. If any provision of this License is held to be +unenforceable, such provision shall be reformed only to the extent +necessary to make it enforceable. Any law or regulation which provides +that the language of a contract shall be construed against the drafter +shall not be used to construe this License against a Contributor. + +10. Versions of the License +--------------------------- + +10.1. New Versions + +Mozilla Foundation is the license steward. Except as provided in Section +10.3, no one other than the license steward has the right to modify or +publish new versions of this License. Each version will be given a +distinguishing version number. + +10.2. Effect of New Versions + +You may distribute the Covered Software under the terms of the version +of the License under which You originally received the Covered Software, +or under the terms of any subsequent version published by the license +steward. + +10.3. Modified Versions + +If you create software not governed by this License, and you want to +create a new license for such software, you may create and use a +modified version of this License if you rename the license and remove +any references to the name of the license steward (except to note that +such modified license differs from this License). + +10.4. Distributing Source Code Form that is Incompatible With Secondary +Licenses + +If You choose to distribute Source Code Form that is Incompatible With +Secondary Licenses under the terms of this version of the License, the +notice described in Exhibit B of this License must be attached. + +Exhibit A - Source Code Form License Notice +------------------------------------------- + + This Source Code Form is subject to the terms of the Mozilla Public + License, v. 2.0. If a copy of the MPL was not distributed with this + file, You can obtain one at http://mozilla.org/MPL/2.0/. + +If it is not possible or desirable to put the notice in a particular +file, then You may include the notice in a location (such as a LICENSE +file in a relevant directory) where a recipient would be likely to look +for such a notice. + +You may add additional accurate notices of copyright ownership. + +Exhibit B - "Incompatible With Secondary Licenses" Notice +--------------------------------------------------------- + + This Source Code Form is "Incompatible With Secondary Licenses", as + defined by the Mozilla Public License, v. 2.0. \ No newline at end of file diff --git a/pyscf_ipu/direct/exchange_correlation/b3lyp.py b/pyscf_ipu/direct/exchange_correlation/b3lyp.py new file mode 100644 index 00000000..827236d3 --- /dev/null +++ b/pyscf_ipu/direct/exchange_correlation/b3lyp.py @@ -0,0 +1,138 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +import jax.numpy as jnp +import jax +import numpy as np + +def __lyp(n, gnn): + # precompute + A = 0.04918 + B = 0.132 + C = 0.2533 + Dd = 0.349 + CF = 0.3 * (3.0 * np.pi * np.pi) ** (2.0 / 3.0) + c0 = 2.0 ** (11.0 / 3.0) * (1/2)**(8/3) + c1 = (1/3 + 1/8)*4 + + # actual compute + log_n = jnp.log(n) + icbrtn = jnp.exp(log_n * (-1.0 / 3.0) ) + + P = 1.0 / (1.0 + Dd * icbrtn) + omega = jnp.exp(-C * icbrtn) * P + delta = icbrtn * (C + Dd * P) + + n_five_three = jnp.exp(log_n*(-5/3)) + + result = -A * ( + n * P + + B + * omega + * 1/ 4 *( + 2 * CF * n * c0+ + gnn * (60 - 14.0 * delta) /36 * n_five_three + - gnn *c1 * n_five_three + ) + ) + + return result + +def __vwn(n): + # Precompute stuff in np.float64 + p = np.array( [-0.10498, 0.0621813817393097900698817274255, 3.72744, 12.9352]) + f = p[0] * p[2] / (p[0] * p[0] + p[0] * p[2] + p[3]) - 1.0 + f_inv_p1 = 1/f+1 + f_2 = f * 0.5 + sqrt = np.sqrt(4.0 * p[3] - p[2] * p[2]) + precompute = p[2] * ( 1.0 / sqrt + - p[0] + / ( + (p[0] * p[0] + p[0] * p[2] + p[3]) + * sqrt + / (p[2] + 2.0 * p[0]) + ) + ) + log_s_c = np.log( 3.0 /(4*np.pi) ) / 6 + + # Below cast to same dtype as input (allow easier comparison between f32/f64). + dtype = n.dtype + p = p.astype(dtype) + f = f.astype(dtype) + f_inv_p1 = (f_inv_p1).astype(dtype) + f_2 = f_2.astype(dtype) + sqrt = sqrt.astype(dtype) + precompute = precompute.astype(dtype) + log_s_c =log_s_c.astype(dtype) + + # compute stuff that depends on n + log_s = - jnp.log(n) / 6 + log_s_c + s_2 = jnp.exp( log_s *2) + s = jnp.exp( log_s ) + z = sqrt / (2.0 * s + p[2]) + + result = n * p[1] * ( + log_s + #+ f * jnp.log( jnp.sqrt( s_2 + p[2] * s + p[3] ) / (s-p[0])**(1/f+1) ) # problem with float, 1/f+1 was done in np which automatically sticks to float64 + + f * jnp.log( jnp.sqrt( s_2 + p[2] * s + p[3] ) / (s-p[0])**(f_inv_p1) ) + + precompute * jnp.arctan(z) + + ) + + return result + +def __b88(a, gaa): + # precompute + c1 = (4.0 / 3.0) + c2 = (-8.0 / 3.0) + c3 = (-3.0 / 4.0) * (6.0 / np.pi) ** (1.0 / 3.0) * 2 + d = 0.0042 + d2 = d * 2. + d12 = d *12. + + # actual compute + log_a = jnp.log(a/2) + na43 = jnp.exp(log_a * c1) + chi2 = gaa / 4* jnp.exp(log_a * c2 ) + chi = jnp.exp(jnp.log( chi2 ) / 2 ) + b88 = -(d * na43 * chi2) / (1.0 + 6*d * chi * jnp.arcsinh(chi)) *2 + slaterx_a = c3 * na43 + return slaterx_a + b88 + +def __lda(rho): return -jnp.exp(1/3*jnp.log(rho) - 0.30305460484554375) + +CLIP_RHO_MIN = 1e-10 +CLIP_RHO_MAX = 1e15 + +def b3lyp(rho, EPSILON_B3LYP=0): + rho0 = jnp.clip(rho[0], CLIP_RHO_MIN, CLIP_RHO_MAX) + norms = jnp.linalg.norm(rho[1:4]*2+CLIP_RHO_MIN, axis=0).T**2+EPSILON_B3LYP + return __lda(rho0)*0.08 + (__vwn(rho0)*0.19 + __b88(rho0, norms)*0.72 + __lyp(rho0, norms)*0.81) / rho0 + + +def vxc_b3lyp(rho, EPSILON_B3LYP=0): + rho = jnp.concatenate([jnp.clip(rho[:1], CLIP_RHO_MIN, CLIP_RHO_MAX), rho[1:4]*2]) + + rho0 = rho.T[:, 0] + norms = jnp.linalg.norm(rho[1:], axis=0).T**2+EPSILON_B3LYP + + def lda(rho0): return jax.vmap(jax.value_and_grad(lambda x: __lda(x)*0.08)) (rho0) + def vwn(rho0): return jax.vmap(jax.value_and_grad(lambda x: __vwn(x)*0.19)) (rho0) + + # disabled gradient checkpointing + #def b88(rho0, norms): return jax.vmap(jax.value_and_grad(lambda rho0, norm: jax.checkpoint(__b88)(rho0, norm)*0.72, (0, 1))) (rho0, norms) + #def lyp(rho0, norms): return jax.vmap(jax.value_and_grad(lambda rho0, norm: jax.checkpoint(__lyp)(rho0, norm)*0.810, (0, 1))) (rho0, norms) + + def b88(rho0, norms): return jax.vmap(jax.value_and_grad(lambda rho0, norm: __b88(rho0, norm)*0.72, (0,1)))(rho0, norms) + def lyp(rho0, norms): return jax.vmap(jax.value_and_grad(lambda rho0, norm: __lyp(rho0, norm)*0.810, (0,1)))(rho0, norms) + + e_xc_lda, v_rho_lda = jax.jit(lda)(rho0) + e_xc_vwn, v_rho_vwn = jax.jit(vwn)(rho0) + e_xc_b88, (v_rho_b88, v_norm_b88) = jax.jit(b88)(rho0, norms) + e_xc_lyp, (v_rho_lyp, v_norm_lyp) = jax.jit(lyp)(rho0, norms) + + e_xc = e_xc_lda + (e_xc_vwn + e_xc_b88 + e_xc_lyp) / rho0 + v_xc_rho = v_rho_lda*4*rho0 + v_rho_vwn + v_rho_b88 + v_rho_lyp + v_xc_norms = v_norm_b88 + v_norm_lyp + + return e_xc, v_xc_rho, v_xc_norms + + diff --git a/pyscf_ipu/direct/sparse_symmetric_ERI.py b/pyscf_ipu/direct/sparse_symmetric_ERI.py new file mode 100644 index 00000000..3efea154 --- /dev/null +++ b/pyscf_ipu/direct/sparse_symmetric_ERI.py @@ -0,0 +1,127 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +import pyscf +import numpy as np +import jax +import jax.numpy as jnp +from functools import partial +from icecream import ic +HYB_B3LYP = 0.2 + +@partial(jax.jit, backend="cpu") +def get_i_j(val): + import jax.numpy as np + i = (np.sqrt(1 + 8*val.astype(np.uint64)) - 1)//2 # no need for floor, integer division acts as floor. + j = (((val - i) - (i**2 - val))//2) + return i, j + +def ijkl(value, symmetry, N, f): + import jax.numpy as np + i, j, k, l = value[0].astype(np.uint32), value[1].astype(np.uint32), value[2].astype(np.uint32), value[3].astype(np.uint32) + return f(i,j,k,l,symmetry,N) +ijkl = jax.vmap(ijkl, in_axes=(0, None, None, None)) + + +@partial(jax.jit, backend="cpu") +def num_repetitions_fast(ij, kl): + import jax.numpy as np + i, j = get_i_j(ij) + k, l = get_i_j(kl) + + # compute: repetitions = 2^((i==j) + (k==l) + (k==i and l==j or k==j and l==i)) + repetitions = 2**( + np.equal(i,j).astype(np.uint64) + + np.equal(k,l).astype(np.uint64) + + (1 - ((1 - np.equal(k,i) * np.equal(l,j)) * + (1- np.equal(k,j) * np.equal(l,i))).astype(np.uint64)) + ) + return repetitions + +indices_func = lambda i,j,k,l,symmetry,N: jnp.array([i*N+j, j*N+i, i*N+j, j*N+i, k*N+l, l*N+k, k*N+l, l*N+k, + k*N+l, k*N+l, l*N+k, l*N+k, i*N+j, i*N+j, j*N+i, j*N+i, + k*N+j, k*N+i, l*N+j, l*N+i, i*N+l, i*N+k, j*N+l, j*N+k, + i*N+l, j*N+l, i*N+k, j*N+k, k*N+j, l*N+j, k*N+i, l*N+i])[symmetry] + +def _indices_func(i, j, k, l, symmetry, N): + if symmetry == 0: return i * N + j + elif symmetry == 1: return j * N + i + elif symmetry == 2: return i * N + j + elif symmetry == 3: return j * N + i + elif symmetry == 4: return k * N + l + elif symmetry == 5: return l * N + k + elif symmetry == 6: return k * N + l + elif symmetry == 7: return l * N + k + elif symmetry == 8 or symmetry == 9: return k * N + l + elif symmetry == 10 or symmetry == 11: return l * N + k + elif symmetry == 12 or symmetry == 13: return i * N + j + elif symmetry == 14 or symmetry == 15: return j * N + i + elif symmetry == 16: return k * N + j + elif symmetry == 17: return k * N + i + elif symmetry == 18: return l * N + j + elif symmetry == 19: return l * N + i + elif symmetry == 20: return i * N + l + elif symmetry == 21: return i * N + k + elif symmetry == 22: return j * N + l + elif symmetry == 23: return j * N + k + elif symmetry == 24: return i * N + l #j*N+l, i*N+k, j*N+k, + elif symmetry == 25: return j*N+l + elif symmetry == 26: return i*N+k + elif symmetry == 27: return j*N+k + elif symmetry == 28: return k * N + j + elif symmetry == 29: return l * N + j + elif symmetry == 30: return k * N + i + elif symmetry == 31: return l * N + i + + +def sparse_symmetric_einsum(nonzero_distinct_ERI, nonzero_indices, dm): + dm = dm.reshape(-1) + diff_JK = jnp.zeros(dm.shape) + N = int(np.sqrt(dm.shape[0])) + + dnums = jax.lax.GatherDimensionNumbers( + offset_dims=(), + collapsed_slice_dims=(0,), + start_index_map=(0,)) + scatter_dnums = jax.lax.ScatterDimensionNumbers( + update_window_dims=(), + inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(0,)) + + def iteration(symmetry, vals): + diff_JK = vals + is_K_matrix = (symmetry >= 8) + + def sequentialized_iter(i, vals): + # Generalized J/K computation: does J when symmetry is in range(0,8) and K when symmetry is in range(8,16) + # Trade-off: Using one function leads to smaller always-live memory. + diff_JK = vals + indices = nonzero_indices[i].astype(np.int32) + eris = nonzero_distinct_ERI[i] + + dm_indices = ijkl(indices, symmetry+is_K_matrix*8, N, indices_func).reshape(-1, 1) + #dm_values = jnp.take(dm, dm_indices, axis=0)[:, 0] # for our special case the 50 lines of code reduces to the one line below. + dm_values = jax.lax.gather(dm, dm_indices, dimension_numbers=dnums, slice_sizes=(1,), mode=jax.lax.GatherScatterMode.FILL_OR_DROP) + + dm_values = dm_values * eris + + ss_indices = ijkl(indices, symmetry+8+is_K_matrix*8, N, indices_func) .reshape(-1,1) + # diff_JK = diff_JK + jax.lax.segment_sum( ...) # for our special case the 100 lines of code reduces to the one line below. + diff_JK = diff_JK + jax.lax.scatter_add(jnp.zeros((N**2,)), + ss_indices, dm_values, + scatter_dnums, indices_are_sorted=True, unique_indices=False, mode=jax.lax.GatherScatterMode.FILL_OR_DROP)\ + *(-HYB_B3LYP/2)**is_K_matrix + + return diff_JK + + batches = nonzero_indices.shape[0] + diff_JK = jax.lax.fori_loop(0, batches, sequentialized_iter, diff_JK) + #for i in range(batches): + # diff_JK = sequentialized_iter(i, diff_JK) + return diff_JK + + diff_JK = jax.lax.fori_loop(0, 16, iteration, diff_JK) + #for i in range(0, 16): + # diff_JK = iteration(i, diff_JK) + #diff_JK = jax.lax.fori_loop(0, 16, iteration, diff_JK) + #return jax.lax.psum(diff_JK, axis_name="p") + return diff_JK.reshape(N, N) + diff --git a/pyscf_ipu/direct/train.py b/pyscf_ipu/direct/train.py new file mode 100644 index 00000000..18b99e5c --- /dev/null +++ b/pyscf_ipu/direct/train.py @@ -0,0 +1,896 @@ +import jax +jax.config.update('jax_enable_x64', True) +import jax.numpy as jnp +import numpy as np +import pyscf +import optax +from icecream import ic +from exchange_correlation.b3lyp import b3lyp, vxc_b3lyp +from tqdm import tqdm +import time +from transformer import transformer, transformer_init +import pandas as pd + +cfg, HARTREE_TO_EV, EPSILON_B3LYP, HYB_B3LYP = None, 27.2114079527, 1e-20, 0.2 + +def T(x): return jnp.transpose(x, (0,2,1)) + +B, BxNxN, BxNxK = None, None, None + +# Only need to recompute: L_inv, grid_AO, grid_weights, H_core, ERI and E_nuc. +def dm_energy(W: BxNxK, state, normal, nn): + if nn: + W = jnp.mean(jax.vmap(transformer, in_axes=(None, None, 0, 0, 0), out_axes=(0))(cfg, W, state.ao_types, state.pos, state.H_core) , axis=1) + W = W @ state.init + + L_inv_Q: BxNxN = state.L_inv_T @ jnp.linalg.qr(W)[0] # O(B*N*K^2) FLOP O(B*N*K) FLOP/FLIO + density_matrix: BxNxN = 2 * L_inv_Q @ T(L_inv_Q) # O(B*N*K^2) FLOP/FLIO + E_xc: B = exchange_correlation(density_matrix, state, normal) # O(B*gsize*N^2) FLOP O(gsize*N^2) FLIO + diff_JK: BxNxN = JK(density_matrix, state, normal) # O(B*num_ERIs) FLOP O(num_ERIs) FLIO + energies: B = E_xc + state.E_nuc + jnp.sum((density_matrix * (state.H_core + diff_JK/2)).reshape(W.shape[0], -1), axis=-1) + energy: float = jnp.sum(energies) + return energy, (energies, E_xc, density_matrix) + +def sparse_mult(values, dm, state, gsize): + in_ = dm.take(state.cols, axis=0) + prod = in_*values[:, None] + return jax.ops.segment_sum(prod, state.rows, gsize) + +def exchange_correlation(density_matrix, state, normal): + _, _, gsize, N = state.grid_AO.shape + B = density_matrix.shape[0] + if normal: + grid_AO_dm = (state.grid_AO[:, 0] @ density_matrix) # (B,gsize,N) @ (B, N, N) = O(B gsize N^2) + rho = jnp.sum(grid_AO_dm * state.grid_AO, axis=3) # (B,1,gsize,N) * (B,4,gsize,N) = O(B gsize N) + else: + main = state.main_grid_AO @ density_matrix # (1, gsize, N) @ (B, N, N) = O(B gsize N^2) FLOPs and O(gsize*N + N^2 +B * gsize * N) FLIOs + correction = jax.vmap(sparse_mult, in_axes=(0,0,None, None))(state.sparse_diffs_grid_AO, density_matrix, state, gsize) + + # subtract before/after einsum. + if True: + grid_AO_dm = (main - correction).reshape(B, 1, gsize, N) # (B * gsize * N) + rho = jnp.einsum("bpij,bqij->bpi", state.grid_AO, grid_AO_dm) + else: + rho_a = jnp.einsum("bpij,bqij->bpi", state.grid_AO, main.reshape(B,1,gsize,N)) + rho_b = jnp.einsum("bpij,bqij->bpi", state.grid_AO, correction.reshape(B,1,gsize,N)) + rho = rho_a - rho_b + + E_xc = jax.vmap(b3lyp, in_axes=(0,None))(rho, EPSILON_B3LYP).reshape(B, gsize) + E_xc = jnp.sum(rho[:, 0] * state.grid_weights * E_xc, axis=-1).reshape(B) + return E_xc + +def JK(density_matrix, state, normal): + if normal: + J = jnp.einsum('bijkl,bji->bkl', state.ERI, density_matrix) + K = jnp.einsum('bijkl,bjk->bil', state.ERI, density_matrix) + diff_JK = J - K / 2 * HYB_B3LYP + else: + from sparse_symmetric_ERI import sparse_symmetric_einsum + # batched => flops = reads + #diff_JK = jax.vmap(sparse_symmetric_einsum, in_axes=(0, 0, 0))(state.nonzero_distinct_ERI, state.nonzero_indices, density_matrix) + # first + correction_remaining => floats = reads*batch_size + diff_JK: BxNxN = jax.vmap(sparse_symmetric_einsum, in_axes=(None, None, 0))(state.nonzero_distinct_ERI[0], state.nonzero_indices[0], density_matrix) + diff_JK: BxNxN = diff_JK - jax.vmap(sparse_symmetric_einsum, in_axes=(0, None, 0))(state.diffs_ERI, state.indxs, density_matrix) + + return diff_JK + +def nao(atom, basis): + m = pyscf.gto.Mole(atom='%s 0 0 0; %s 0 0 1;'%(atom, atom), basis=basis) + m.build() + return m.nao_nr()//2 + +def batched_state(mol_str, opts, bs, wiggle_num=0, do_pyscf=True): + t0 = time.time() + state = init_dft(mol_str, opts, do_pyscf=do_pyscf) + c, w = state.grid_coords, state.grid_weights + + + np.random.seed(42) + p = np.array(mol_str[0][1]) + states = [state] + for i in tqdm(range(bs-1)): + x = p + np.random.normal(0, opts.wiggle_var, (3)) + mol_str[wiggle_num][1] = (x[0], x[1], x[2]) + + # when profiling create fake molecule to skip waiting + if i == 0 or not opts.prof: + stateB = init_dft(mol_str, opts, c, w, do_pyscf=do_pyscf and i < 5) + + states.append(stateB) + + state = cats(states) + N = state.N[0] + + # Compute ERI sparsity. + nonzero = [] + for e,i in zip(state.nonzero_distinct_ERI, state.nonzero_indices): + abs = np.abs(e) + indxs = abs < opts.eri_threshold #1e-10 + e[indxs] = 0 + nonzero.append(np.nonzero(e)[0]) + + # Merge nonzero indices and prepare (ij, kl). + # rep is the number of repetitions we include in the sparse representation. + # TODO: the union1d should include all nonzero, not just first. + nonzero_indices = np.union1d(nonzero[0], nonzero[1]) + from sparse_symmetric_ERI import get_i_j, num_repetitions_fast + ij, kl = get_i_j(nonzero_indices) + rep = num_repetitions_fast(ij, kl) + + batches = 8 + es = [] + for e,i in zip(state.nonzero_distinct_ERI, state.nonzero_indices): + nonzero_distinct_ERI = e[nonzero_indices] / rep + remainder = nonzero_indices.shape[0] % (batches) + if remainder != 0: nonzero_distinct_ERI = np.pad(nonzero_distinct_ERI, (0,batches-remainder)) + + nonzero_distinct_ERI = nonzero_distinct_ERI.reshape(batches, -1) + es.append(nonzero_distinct_ERI) + + state.nonzero_distinct_ERI = np.concatenate([np.expand_dims(a, axis=0) for a in es]) + + i, j = get_i_j(ij.reshape(-1)) + k, l = get_i_j(kl.reshape(-1)) + + if remainder != 0: + i = np.pad(i, ((0,batches-remainder))) + j = np.pad(j, ((0,batches-remainder))) + k = np.pad(k, ((0,batches-remainder))) + l = np.pad(l, ((0,batches-remainder))) + nonzero_indices = np.vstack([i,j,k,l]).T.reshape(batches, -1, 4).astype(np.int16) + + state.nonzero_indices = nonzero_indices + + if opts.normal: diff_state = None + else: + main_grid_AO = state.grid_AO[:1] + diffs_grid_AO = main_grid_AO - state.grid_AO + rows, cols = np.nonzero(np.max(diffs_grid_AO[:, 0]!=0, axis=0)) + sparse_diffs_grid_AO = diffs_grid_AO[:, 0, rows,cols] + + # use the same sparsity pattern across a batch. + diff_ERIs = state.nonzero_distinct_ERI[:1] - state.nonzero_distinct_ERI + diff_indxs = state.nonzero_indices.reshape(1, batches, -1, 4) + nzr = np.abs(diff_ERIs[1]).reshape(batches, -1) > 1e-10 + + diff_ERIs = diff_ERIs[:, nzr].reshape(bs, -1) + diff_indxs = diff_indxs[:, nzr].reshape(-1, 4) + + remainder = np.sum(nzr) % batches + if remainder != 0: + diff_ERIs = np.pad(diff_ERIs, ((0,0),(0,batches-remainder))) + diff_indxs = np.pad(diff_indxs, ((0,batches-remainder),(0,0))) + + diff_ERIs = diff_ERIs.reshape(bs, batches, -1) + diff_indxs = diff_indxs.reshape(batches, -1, 4) + + state.indxs=diff_indxs + state.rows=rows + state.cols=cols + + state.main_grid_AO=main_grid_AO[:1, 0] + + state.sparse_diffs_grid_AO = sparse_diffs_grid_AO + state.diffs_grid_AO = diffs_grid_AO + state.diffs_ERI=diff_ERIs + + #state.grid_AO = state.grid_AO[:1] + state.nonzero_distinct_ERI = state.nonzero_distinct_ERI[:1] + state.nonzero_indices = np.expand_dims(state.nonzero_indices, axis=0) + + indxs = np.abs(state.nonzero_distinct_ERI ) > 1e-9 + state.nonzero_distinct_ERI = state.nonzero_distinct_ERI[indxs] + state.nonzero_indices = state.nonzero_indices[indxs] + remainder = state.nonzero_indices.shape[0] % batches + + if remainder != 0: + state.nonzero_distinct_ERI = np.pad(state.nonzero_distinct_ERI, (0,batches-remainder)) + state.nonzero_indices = np.pad(state.nonzero_indices, ((0,batches-remainder), (0,0))) + + state.nonzero_distinct_ERI = state.nonzero_distinct_ERI.reshape(1, batches, -1) + state.nonzero_indices = state.nonzero_indices.reshape(1, batches, -1, 4) + + print("batch: ", time.time()-t0) + return state + + +# todo: wiggle asynch each 10 steps or so. +def wiggle(state): # idea: keep bs=64 state, but add say another 64 wiggles which change. + pass + + +def nanoDFT(mol_str, opts): + print() + # Initialize validation set. + # This consists of DFT tensors initialized with PySCF/CPU. + np.random.seed(42) + + + # Step 1. + # Take benzene; randomly sample first carbon. + # Get model generalize to new randomly sampled carbons. + # + # Step 2. + # Change benzene to have CNOF. + # Do single atom wiggles on CNOF, get model to generalize to simultaneous wiggles. + # + # Step 3/4. Scale grid-size and basis set. + + # Data Creation / Data loader + # [ ] Move to dataloader; whenever a new data point is ready, use that. + # [ ] Consider pre-compute/load. + # [ ] Consider 'wiggle' function which generate new wiggle datapoints in state. + val_state = batched_state(mol_str[0], opts, opts.val_bs, do_pyscf=True) + states = [batched_state(mol_str[0], opts, opts.bs, do_pyscf=True)] + [batched_state(mol_str[i], opts, opts.bs, do_pyscf=False) for i in range(opts.states-1)] + + rnd_key = jax.random.PRNGKey(42) + n_vocab = nao("C", opts.basis) + nao("N", opts.basis) + \ + nao("O", opts.basis) + nao("F", opts.basis) + \ + nao("H", opts.basis) + + global cfg + '''Model ViT model embedding #heads #layers #params training throughput + dimension resolution (im/sec) + DeiT-Ti N/A 192 3 12 5M 224 2536 + DeiT-S N/A 384 6 12 22M 224 940 + DeiT-B ViT-B 768 12 12 86M 224 292 + ''' + if opts.tiny: # 5M + d_model= 192 + n_heads = 6 + n_layers = 12 + elif opts.small: + d_model= 384 + n_heads = 6 + n_layers = 12 + elif opts.base: + d_model= 768 + n_heads = 12 + n_layers = 12 + + if opts.nn: + rnd_key, cfg, params, total_params = transformer_init( + rnd_key, + n_vocab, + d_model =d_model, + n_layers=n_layers, + n_heads =n_heads, + d_ff =d_model*4, + ) + + if opts.nn: + #https://arxiv.org/pdf/1706.03762.pdf see 5.3 optimizer + def custom_schedule(step_num, d_model, warmup_steps): + arg1 = step_num ** -0.5 + arg2 = step_num * warmup_steps ** -1.5 + return d_model ** -0.5 * min(arg1, arg2) + + optimizer = optax.adam(learning_rate=lambda step: custom_schedule(step, d_model=d_model, warmup_steps=4000), + b1=0.9, b2=0.98, eps=1e-9) + + adam = optax.adam(opts.lr) + w = params + else: + w = states[0].init + adam = optax.adabelief(opts.lr) + + vandg = jax.jit(jax.value_and_grad(dm_energy, has_aux=True), backend=opts.backend, static_argnames=("normal", 'nn')) + valf = jax.jit(dm_energy, backend=opts.backend, static_argnames=("normal", 'nn')) + + # Build initializers for params + adam_state = adam.init(w) + + min_val = 0 + min_dm = 0 + + pbar = tqdm(range(opts.steps)) + + summary(states[0]) + + mins = np.ones(opts.bs) * 1e6 + if opts.wandb: + import wandb + wandb.init(project='ndft') + wandb.log({'total_params': total_params}) + + print("jitting...") + t0 = time.time() + (val, (vals, E_xc, density_matrix)), grad = vandg(w, states[0], opts.normal, opts.nn) + print("done!", time.time()-t0) + + + def update(w, state, adam_state): + (val, (vals, E_xc, density_matrix)), grad = vandg(w, state, opts.normal, opts.nn) + updates, adam_state = adam.update(grad, adam_state) + w = optax.apply_updates(w, updates) + return w, vals, density_matrix, adam_state + + update = jax.jit(update, backend=opts.backend) + + for i in pbar: + state = states[i%opts.states] + w, vals, density_matrix, adam_state = update(w, state, adam_state) + + # valid + if i % 10 == 0 and opts.nn: + _, (valid_vals, _, _) = valf(w, val_state, opts.normal, opts.nn) + print("validation:") + str = "error=" + "".join(["%.7f "%(valid_vals[i]*HARTREE_TO_EV-val_state.pyscf_E[i]) for i in range(1, opts.val_bs)]) + " [eV]" + print(str) + print() + if opts.wandb: + dct = {} + for i in range(1, opts.val_bs): + dct['valid_l%i'%i ] = valid_vals[i]*HARTREE_TO_EV-val_state.pyscf_E[i] + wandb.log(dct) + + if opts.nn and opts.wandb and i > 0: + current_lr = custom_schedule(i, d_model, 4000) + wandb.log({'lr': current_lr}) + + if opts.bs == 1: pbar.set_description("error=%.7f [eV] (%.7f %.7f) "%(np.mean(val*HARTREE_TO_EV-state.pyscf_E), val*HARTREE_TO_EV, state.pyscf_E)) + else: + if opts.wandb: + wandb.log( + {'l1': vals[0]*HARTREE_TO_EV-state.pyscf_E[0], + 'l2': vals[1]*HARTREE_TO_EV-state.pyscf_E[1]}) + + str = "error=" + "".join(["%.7f "%(vals[i]*HARTREE_TO_EV-state.pyscf_E[i]) for i in range(2)]) + " [eV]" + #str += "E_xc=" + "".join(["%.7f "%(E_xc[i]*HARTREE_TO_EV) for i in range(opts.bs)]) + " [eV]" + try: + mins = np.minimum(mins, np.abs(vals*HARTREE_TO_EV - state.pyscf_E[:, 0])) + str += " best=" + "".join(["%.7f "%(mins[i]) for i in range(2)]) + " [eV]" + except: + pass + pbar.set_description(str) + if i == 0: print("") + + if val < min_val: + min_val = val + min_dm = density_matrix + + + val, density_matrix = min_val, min_dm + + # needs batching + exit() + V_xc = jax.grad(exchange_correlation)(density_matrix, state.grid_AO, state.grid_weights) + V_xc = (V_xc + V_xc.T)/2 + diff_JK = get_JK(density_matrix, state.ERI) + H = state.H_core + diff_JK + V_xc + mo_energy, mo_coeff = np.linalg.eigh(state.L_inv @ H @ state.L_inv.T) + mo_coeff = state.L_inv.T @ mo_coeff + + return val, (0, mo_energy, mo_coeff, state.grid_coords, state.grid_weights, density_matrix, H) + + +import chex +@chex.dataclass +class IterationState: + init: np.array + E_nuc: np.array + L_inv: np.array + L_inv_T: np.array + H_core: np.array + grid_AO: np.array + grid_weights: np.array + grid_coords: np.array + pyscf_E: np.array + N: int + ERI: np.array + nonzero_distinct_ERI: list + nonzero_indices: list + diffs_ERI: np.array + main_grid_AO: np.array + diffs_grid_AO: np.array + indxs: np.array + sparse_diffs_grid_AO: np.array + rows: np.array + cols: np.array + pos: np.array + ao_types: np.array + +from pyscf.data.elements import charge as elements_proton +from pyscf.dft import gen_grid, radi + +def treutler_atomic_radii_adjust(mol, atomic_radii): + charges = [elements_proton(x) for x in mol.elements] + rad = np.sqrt(atomic_radii[charges]) + 1e-200 + rr = rad.reshape(-1, 1) * (1. / rad) + a = .25 * (rr.T - rr) + + a[a < -0.5] = -0.5 + a[a > 0.5] = 0.5 + a = jnp.array(a) + + def fadjust(i, j, g): + g1 = g**2 + g1 -= 1. + g1 *= -a[i, j] + g1 += g + return g1 + + return fadjust + + +def inter_distance(coords): + rr = np.linalg.norm(coords.reshape(-1, 1, 3) - coords, axis=2) + rr[np.diag_indices(rr.shape[0])] = 0 + return rr + +def original_becke(g): + g = (3 - g**2) * g * .5 + g = (3 - g**2) * g * .5 + g = (3 - g**2) * g * .5 + return g + +def get_partition( + mol, + atom_coords, + atom_grids_tab, + radii_adjust=treutler_atomic_radii_adjust, + atomic_radii=radi.BRAGG_RADII, + becke_scheme=original_becke, + concat=True +): + atm_dist = inter_distance(atom_coords) # [natom, natom] + + def gen_grid_partition(coords): + ngrids = coords.shape[0] + dc = coords[None] - atom_coords[:, None] + grid_dist = np.sqrt(np.einsum('ijk,ijk->ij', dc, dc)) # [natom, ngrid] + + ix, jx = np.tril_indices(mol.natm, k=-1) + + natm, ngrid = grid_dist.shape + #g_ = -1 / atm_dist.reshape(natm, natm, 1) * (grid_dist.reshape(1, natm, ngrid) - grid_dist.reshape(natm, 1, ngrid)) + g_ = -1 / (atm_dist.reshape(natm, natm, 1) + np.eye(natm).reshape(natm, natm,1)) * (grid_dist.reshape(1, natm, ngrid) - grid_dist.reshape(natm, 1, ngrid)) + #g_ = jnp.array(g_) + + def pbecke_g(i, j): + g = g_[i, j] + charges = [elements_proton(x) for x in mol.elements] + rad = np.sqrt(atomic_radii[charges]) + 1e-200 + rr = rad.reshape(-1, 1) * (1. / rad) + a = .25 * (rr.T - rr) + a[a < -0.5] = -0.5 + a[a > 0.5] = 0.5 + g1 = g**2 + g1 -= 1. + g1 *= -a[i, j].reshape(-1, 1) + g1 += g + return g1 + + g = pbecke_g(ix, jx) + g = np.copy(becke_scheme(g)) + gp2 = (1+g)/2 + gm2 = (1-g)/2 + + pbecke = jnp.ones((mol.natm, ngrids)) # [natom, ngrid] + pbecke = pbecke.at[ix].mul(gm2) + pbecke = pbecke.at[jx].mul(gp2) + + return pbecke + + coords_all = [] + weights_all = [] + for ia in range(mol.natm): + coords, vol = atom_grids_tab[mol.atom_symbol(ia)] + coords = coords + atom_coords[ia] # [ngrid, 3] + pbecke = gen_grid_partition(coords) # [natom, ngrid] + weights = vol * pbecke[ia] / jnp.sum(pbecke, axis=0) + coords_all.append(coords) + weights_all.append(weights) + + if concat: + coords_all = jnp.vstack(coords_all) + weights_all = jnp.hstack(weights_all) + return coords_all, weights_all + + +class DifferentiableGrids(gen_grid.Grids): + """Differentiable alternative to the original pyscf.gen_grid.Grids.""" + + def build(self, atom_coords) : + mol = self.mol + + atom_grids_tab = self.gen_atomic_grids( + mol, self.atom_grid, self.radi_method, self.level, self.prune + ) + + coords, weights = get_partition( + mol, + atom_coords, + atom_grids_tab, + treutler_atomic_radii_adjust, + self.atomic_radii, + original_becke, + ) + self.coords = coords + self.weights = weights + return coords, weights + + +def grids_from_pyscf_mol( + mol: pyscf.gto.mole.Mole, quad_level: int = 1 +) : + g = gen_grid.Grids(mol) + g.level = quad_level + g.build() + grids = jnp.array(g.coords) + weights = jnp.array(g.weights) + return grids, weights + + +def init_dft(mol_str, opts, _coords=None, _weights=None, first=False, do_pyscf=True): + mol = build_mol(mol_str, opts.basis) + if do_pyscf: pyscf_E, pyscf_hlgap, pycsf_forces = reference(mol_str, opts) + else: pyscf_E, pyscf_hlgap, pyscf_forces = np.zeros(1), np.zeros(1), np.zeros(1) + + N = mol.nao_nr() # N=66 for C6H6 (number of atomic **and** molecular orbitals) + n_electrons_half = mol.nelectron//2 # 21 for C6H6 + E_nuc = mol.energy_nuc() # float = 202.4065 [Hartree] for C6H6. TODO(): Port to jax. + + from pyscf import dft + #grids = pyscf.dft.gen_grid.Grids(mol) + grids = DifferentiableGrids(mol) + grids.level = opts.level + #grids.build() + grids.build(np.concatenate([np.array(a[1]).reshape(1, 3) for a in mol._atom])) + + grid_weights = grids.weights # (grid_size,) = (45624,) for C6H6 + grid_coords = grids.coords + coord_str = 'GTOval_cart_deriv1' if mol.cart else 'GTOval_sph_deriv1' + grid_AO = mol.eval_gto(coord_str, grids.coords, 4) # (4, grid_size, N) = (4, 45624, 9) for C6H6. + + # TODO(): Add integral math formulas for kinetic/nuclear/O/ERI. + kinetic = mol.intor_symmetric('int1e_kin') # (N,N) + nuclear = mol.intor_symmetric('int1e_nuc') # (N,N) + O = mol.intor_symmetric('int1e_ovlp') # (N,N) + L = np.linalg.cholesky(O) + L_inv = np.linalg.inv(L) # (N,N) + init = np.eye(N)[:, :n_electrons_half] + + if opts.normal: + ERI = mol.intor("int2e_sph") + nonzero_distinct_ERI = np.zeros(1) + nonzero_indices = np.zeros(1) + else: + eri_threshold = 0 + batches = 1 + nipu = 1 + nonzero_distinct_ERI = mol.intor("int2e_sph", aosym="s8") + #ERI = [nonzero_distinct_ERI, nonzero_indices] + #ERI = ERI + ERI = np.zeros(1) + #ERI = mol.intor("int2e_sph") + + def e(x): return np.expand_dims(x, axis=0) + + n_C = nao('C', opts.basis) + n_N = nao('N', opts.basis) + n_O = nao('O', opts.basis) + n_F = nao('F', opts.basis) + n_H = nao('H', opts.basis) + n_vocab = n_C + n_N + n_O + n_F + n_H + start, stop = 0, n_C + c = list(range(n_vocab))[start:stop] + start, stop = stop, stop+n_N + n = list(range(n_vocab))[start:stop] + start, stop = stop, stop+n_O + o = list(range(n_vocab))[start:stop] + start, stop = stop, stop+n_F + f = list(range(n_vocab))[start:stop] + start, stop = stop, stop+n_H + h = list(range(n_vocab))[start:stop] + types = [] + pos = [] + for a, p in mol_str: + if a.lower() == 'h': + types += h + pos += [np.array(p).reshape(1, -1)]*len(h) + elif a.lower() == 'c': + types += c + pos += [np.array(p).reshape(1, -1)]*len(c) + elif a.lower() == 'n': + types += n + pos += [np.array(p).reshape(1, -1)]*len(n) + elif a.lower() == 'o': + types += o + pos += [np.array(p).reshape(1, -1)]*len(o) + elif a.lower() == 'f': + types += f + pos += [np.array(p).reshape(1, -1)]*len(f) + else: raise Exception() + ao_types = np.array(types) + pos = np.concatenate(pos) + + state = IterationState( + diffs_ERI = np.zeros((1,1)), + main_grid_AO = np.zeros((1,1)), + diffs_grid_AO = np.zeros((1,1)), + indxs = np.zeros((1,1)), + sparse_diffs_grid_AO = np.zeros((1,1)), + rows = np.zeros((1,1)), + cols = np.zeros((1,1)), + pos=e(pos), + ao_types=e(ao_types), + init = e(init), + E_nuc=e(E_nuc), + ERI=e(ERI), + nonzero_distinct_ERI=[nonzero_distinct_ERI], + nonzero_indices=[0], + H_core=e(nuclear+kinetic), + L_inv=e(L_inv), + L_inv_T = e(L_inv.T), + grid_AO=e(grid_AO), + grid_weights=e(grid_weights), + grid_coords=e(grid_coords), + pyscf_E=e(pyscf_E[-1:]), + N=e(mol.nao_nr()), + ) + + + return state + + + +def summary(state): + if state is None: return + print("_"*100) + total = 0 + for field_name, field_def in state.__dataclass_fields__.items(): + field_value = getattr(state, field_name) + try: + print("%20s %20s %20s"%(field_name,getattr(field_value, 'shape', None), getattr(field_value, "nbytes", None)/10**9)) + total += getattr(field_value, "nbytes", None)/10**9 + + except: + try: + print("%20s %20s %20s"%(field_name,getattr(field_value[0], 'shape', None), getattr(field_value[0], "nbytes", None)/10**9)) + total += getattr(field_value, "nbytes", None)/10**9 + except: + print("BROKE FOR ", field_name) + + print("%20s %20s %20s"%("-", "total", total)) + try: + print(state.pyscf_E[:, -1]) + except: + pass + print("_"*100) + +def _cat(x,y,name): + if "list" in str(type(x)): + return x + y + else: + return np.concatenate([x,y]) + +def cat(dc1, dc2, axis=0): + # Use dictionary comprehension to iterate over the dataclass fields + concatenated_fields = { + field: _cat(getattr(dc1, field), getattr(dc2, field), field) + for field in dc1.__annotations__ + } + # Create a new dataclass instance with the concatenated fields + return IterationState(**concatenated_fields) + +def _cats(xs): + if "list" in str(type(xs[0])): + return sum(xs, [])#x + y + else: + return np.concatenate(xs) + + +def cats(dcs): + concatenated_fields = { + field: _cats([getattr(dc, field) for dc in dcs]) + for field in dcs[0].__annotations__ + } + # Create a new dataclass instance with the concatenated fields + return IterationState(**concatenated_fields) + + + +def grad_elec(weight, grid_AO, eri, s1, h1aos, natm, aoslices, mask, mo_energy, mo_coeff, mol, dm, H): + # Electronic part of RHF/RKS gradients + dm0 = 2 * (mo_coeff*mask) @ mo_coeff.T # (N, N) = (66, 66) for C6H6. + dme0 = 2 * (mo_coeff * mask*mo_energy) @ mo_coeff.T # (N, N) = (66, 66) for C6H6. + + # Code identical to exchange correlation. + rho = jnp.sum( grid_AO[:1] @ dm0 * grid_AO, axis=2) # (10, grid_size) = (10, 45624) for C6H6. + _, vrho, vgamma = vxc_b3lyp(rho, EPSILON_B3LYP) # (grid_size,) (grid_size,) + V_xc = jnp.concatenate([vrho.reshape(1, -1)/2, 4*vgamma.reshape(1, -1)*rho[1:4]], axis=0) # (4, grid_size) + + vmat = grid_AO[1:4].transpose(0, 2, 1) @ jnp.sum(grid_AO[:4] * jnp.expand_dims(weight * V_xc, axis=2), axis=0) # (3, N, N) + aos = jnp.concatenate([jnp.expand_dims(grid_AO[np.array([1,4,5,6])], 0), jnp.expand_dims(grid_AO[np.array([2,5,7,8])], 0), jnp.expand_dims(grid_AO[np.array([3,6,8,9])], 0)], axis=0) # (3, N, N) + V_xc = - vmat - jnp.transpose(jnp.einsum("snpi,np->spi", aos, weight*V_xc), axes=(0,2,1)) @ grid_AO[0] # (3, 4, grid_size, N) + + vj = - jnp.einsum('sijkl,lk->sij', eri, dm0) # (3, N, N) + vk = - jnp.einsum('sijkl,jk->sil', eri, dm0) # (3, N, N) + vhf = V_xc + vj - vk * .5 * HYB_B3LYP # (3, N, N) + + de = jnp.einsum('lxij,ij->lx', h1aos, dm0) # (natm, 3) + for k, ia in enumerate(range(natm)): + p0, p1 = aoslices[ia][2], aoslices[ia][3] + de = de.at[k].add(jnp.einsum('xij,ij->x', vhf[:, p0:p1], dm0[p0:p1]) * 2) + de = de.at[k].add(-jnp.einsum('xij,ij->x', s1[:, p0:p1], dme0[p0:p1]) * 2) + return de + +def grad_nuc(charges, coords): + # Derivatives of nuclear repulsion energy wrt nuclear coordinates + natm = charges.shape[0] + pairwise_charges = charges.reshape(natm, 1) * charges.reshape(1, natm) # (natm, natm) + pairwise_difference = coords.reshape(1, natm, 3) - coords.reshape(natm, 1, 3) # (natm, natm, 3) + pairwise_distances = jnp.linalg.norm(pairwise_difference, axis=2) ** 3 # (natm, natm) + pairwise_distances = jnp.where(pairwise_distances == 0, jnp.inf, pairwise_distances) # (natm, natm) + all = - pairwise_charges.reshape(natm, natm, 1) * pairwise_difference # (natm, natm, 3) + all = all / pairwise_distances.reshape(natm, natm, 1) # (natm, natm, 3) + all = all.at[jnp.diag_indices(natm)].set(0) # (natm, natm, 3) + return jnp.sum(all, axis=0) # (natm, natm) + +def grad(mol, coords, weight, mo_coeff, mo_energy, dm, H): + # Initialize DFT tensors on CPU using PySCF. + ao = pyscf.dft.numint.NumInt().eval_ao(mol, coords, deriv=2) + eri = mol.intor("int2e_ip1") + s1 = - mol.intor('int1e_ipovlp', comp=3) + kin = - mol.intor('int1e_ipkin', comp=3) + nuc = - mol.intor('int1e_ipnuc', comp=3) + + aoslices = mol.aoslice_by_atom() + h1 = kin + nuc + def hcore_deriv(atm_id, aoslices, h1): # <\nabla|1/r|> + _, _, p0, p1 = aoslices[atm_id] + with mol.with_rinv_at_nucleus(atm_id): + vrinv = mol.intor('int1e_iprinv', comp=3) # + vrinv *= -mol.atom_charge(atm_id) + vrinv[:,p0:p1] += h1[:,p0:p1] + return vrinv + vrinv.transpose(0,2,1) + N = h1.shape[1] # (3, N , N) + h1aos = np.zeros((mol.natm, 3, N, N)) + for k, ia in enumerate(range(mol.natm)): + p0, p1 = aoslices[ia,2:] + h1aos[k] = hcore_deriv(ia, aoslices, h1) + + charges = np.zeros((mol.natm)) + coords = np.zeros((mol.natm,3)) + for j in range(mol.natm): + charges[j] = mol.atom_charge(j) + coords[j]= mol.atom_coord(j) + + #_grad_elec = jax.jit(grad_elec, static_argnames=["aoslices", "natm"], backend="cpu") + _grad_elec = grad_elec + _grad_nuc = jax.jit(grad_nuc, backend="cpu") + + return _grad_elec(weight, ao, eri, s1, h1aos, mol.natm, tuple([tuple(a) for a in aoslices.tolist()]), mask, mo_energy, mo_coeff, mol, dm, H) + _grad_nuc(charges, coords) + +def pyscf_reference(mol_str, opts): + from pyscf import __config__ + __config__.dft_rks_RKS_grids_level = opts.level + mol = build_mol(mol_str, opts.basis) + mol.max_cycle = 50 + mf = pyscf.scf.RKS(mol) + mf.max_cycle = 50 + mf.xc = "b3lyp5" + mf.diis_space = 8 + pyscf_energies = [] + pyscf_hlgaps = [] + lumo = mol.nelectron//2 + homo = lumo - 1 + def callback(envs): + pyscf_energies.append(envs["e_tot"]*HARTREE_TO_EV) + hl_gap_hartree = np.abs(envs["mo_energy"][homo] - envs["mo_energy"][lumo]) * HARTREE_TO_EV + pyscf_hlgaps.append(hl_gap_hartree) + print("\rPYSCF: ", pyscf_energies[-1] , end="") + mf.callback = callback + mf.kernel() + print("") + if False: + forces = mf.nuc_grad_method().kernel() + else: forces = 0 + return np.array(pyscf_energies), np.array(pyscf_hlgaps), np.array(forces) + +def print_difference(nanoDFT_E, nanoDFT_forces, nanoDFT_logged_E, nanoDFT_hlgap, pyscf_E, pyscf_forces, pyscf_hlgap): + #TODO(HH): rename to match caller variable names + nanoDFT_E = nanoDFT_E*HARTREE_TO_EV + print("pyscf:\t\t%15f"%pyscf_E[-1]) + print("us:\t\t%15f"%nanoDFT_E) + print("diff:\t\t%15f"%np.abs(pyscf_E[-1]-nanoDFT_E)) + print("chemAcc: \t%15f"%0.043) + print("chemAcc/diff: \t%15f"%(0.043/np.abs(pyscf_E[-1]-nanoDFT_E))) + print("") + + # Forces + print() + print("np.max(|nanoDFT_F-PySCF_F|):", np.max(np.abs(nanoDFT_forces-pyscf_forces))) + + norm_X = np.linalg.norm(nanoDFT_forces, axis=1) + norm_Y = np.linalg.norm(pyscf_forces, axis=1) + dot_products = np.sum(nanoDFT_forces * pyscf_forces, axis=1) + cosine_similarity = dot_products / (norm_X * norm_Y) + print("Force cosine similarity:",cosine_similarity) + +def build_mol(mol_str, basis_name): + mol = pyscf.gto.mole.Mole() + mol.build(atom=mol_str, unit="Angstrom", basis=basis_name, spin=0, verbose=0) + return mol + +def reference(mol_str, opts): + import pickle + import hashlib + if opts.skip: return np.zeros(1), np.zeros(1), np.zeros(1) + filename = "precomputed/%s.pkl"%hashlib.sha256((str(mol_str) + str(opts.basis) + str(opts.level)).encode('utf-8')).hexdigest() + print(filename) + if not os.path.exists(filename): + pyscf_E, pyscf_hlgap, pyscf_forces = pyscf_reference(mol_str, opts) + with open(filename, "wb") as file: + pickle.dump([pyscf_E, pyscf_hlgap, pyscf_forces], file) + else: + pyscf_E, pyscf_hlgap, pyscf_forces = pickle.load(open(filename, "rb")) + return pyscf_E, pyscf_hlgap, pyscf_forces + + +if __name__ == "__main__": + import os + import argparse + + parser = argparse.ArgumentParser() + # DFT options + parser.add_argument('-basis', type=str, default="sto3g") + parser.add_argument('-level', type=int, default=0) + + # GD options + parser.add_argument('-backend', type=str, default="cpu") + parser.add_argument('-lr', type=float, default=2.5e-4) + parser.add_argument('-steps', type=int, default=100000) + parser.add_argument('-bs', type=int, default=2) + parser.add_argument('-val_bs', type=int, default=4) + + parser.add_argument('-normal', action="store_true") + parser.add_argument('-wandb', action="store_true") + parser.add_argument('-prof', action="store_true") + parser.add_argument('-visualize', action="store_true") + parser.add_argument('-skip', action="store_true", help="skip pyscf test case") + parser.add_argument('-repeats', type=int, default=1, help="times to repeat molecule") + + # dataset + parser.add_argument('-benzene', action="store_true") + parser.add_argument('-states', type=int, default=1) + parser.add_argument('-wiggle_var', type=float, default=1.0, help="wiggle N(0, wiggle_var)") + parser.add_argument('-eri_threshold', type=float, default=1e-10, help="loss function threshold only") + + # models + parser.add_argument('-nn', action="store_true", help="train nn, defaults to GD") + parser.add_argument('-tiny', action="store_true") + parser.add_argument('-small', action="store_true") + parser.add_argument('-base', action="store_true") + opts = parser.parse_args() + if opts.tiny or opts.small or opts.base: opts.nn = True + + if True: + df = pd.read_pickle("alchemy/atom_9.pickle") + df = df[df["spin"] == 0] # only consider spin=0 + mol_strs = df["pyscf"].values + + # benzene + if opts.benzene: + mol_strs = [[ + ["C", ( 0.0000, 0.0000, 0.0000)], + ["C", ( 1.4000, 0.0000, 0.0000)], + ["C", ( 2.1000, 1.2124, 0.0000)], + ["C", ( 1.4000, 2.4249, 0.0000)], + ["C", ( 0.0000, 2.4249, 0.0000)], + ["C", (-0.7000, 1.2124, 0.0000)], + ["H", (-0.5500, -0.9526, 0.0000)], + ["H", (-0.5500, 3.3775, 0.0000)], + ["H", ( 1.9500, -0.9526, 0.0000)], + ["H", (-1.8000, 1.2124, 0.0000)], + ["H", ( 3.2000, 1.2124, 0.0000)], + ["H", ( 1.9500, 3.3775, 0.0000)] + ]] + + nanoDFT_E, (nanoDFT_hlgap, mo_energy, mo_coeff, grid_coords, grid_weights, dm, H) = nanoDFT(mol_strs, opts) + + exit() + pyscf_E, pyscf_hlgap, pyscf_forces = reference(mol_str, opts) + nanoDFT_forces = grad(mol, grid_coords, grid_weights, mo_coeff, mo_energy, np.array(dm), np.array(H)) + print_difference(nanoDFT_E, nanoDFT_forces, 0 , nanoDFT_hlgap, pyscf_E, pyscf_forces, pyscf_hlgap) diff --git a/pyscf_ipu/direct/transformer.py b/pyscf_ipu/direct/transformer.py new file mode 100644 index 00000000..5486b1da --- /dev/null +++ b/pyscf_ipu/direct/transformer.py @@ -0,0 +1,193 @@ +""" Pure-from-the-ground-up transformer, based on https://github.com/vpj/jax_transformer/blob/master/transformer.py """ +import jax +from jax import vmap +import jax.numpy as jnp +from functools import partial +import jax.experimental.host_callback +import math +import numpy as np + +def rand(rng, f, shape, **kwargs): + rng, rng1 = jax.random.split(rng) + return rng, f(rng1, shape, **kwargs) + +def linear_init_uniform(rng: jax.random.KeyArray, in_features: int, out_features: int): + params = ParamsDict() + rnd_range = 1 / in_features**0.5 + rng, params.weight = rand( rng, jax.random.uniform, (in_features, out_features), minval=-rnd_range, maxval=rnd_range,) + params.bias = jnp.zeros((out_features,)) + return rng, params, (in_features, out_features) + +def elementwise_linear_init_identity(shape): return ParamsDict(gain=jnp.ones(shape), bias=jnp.zeros(shape)) + +def linear(params, x: jnp.ndarray): return x @ params.weight + params.bias[None, :] + +def elementwise_linear(params, x: jnp.ndarray): return params.gain[None, :] * x + params.bias[None, :] + +def standardize(x, eps=1e-5): return (x - x.mean()) / (x.std() + eps) + +def transformer_init( + rng: jax.random.KeyArray, + n_vocab: int, + d_model: int, + n_layers: int, + n_heads: int, + d_ff: int, + max_len=4096, +): + total_params = 0 + + # Build config struct for call + config = ParamsDict() + config.heads = n_heads + if True: #flip_pe_coef(): + config.lambda_e = d_model**-0.5 + config.lambda_pe = 1.0 + else: + config.lambda_e = d_model**-0.5 + config.lambda_pe = 1.0 + + # Build initializers for params + params = ParamsDict() + + print("_"*100) + + # Create embedding layer + rng, params.embeddings = rand(rng, jax.random.normal, (n_vocab, d_model)) + total_params += np.prod(params.embeddings.shape) + print("%26s %26s %26s"%("params.embeddings",params.embeddings.shape, np.prod(params.embeddings.shape))) + + + # For transformer layers + params.layers = [] + for i in range(n_layers): + layer = ParamsDict() + layer.norm_self_attn = elementwise_linear_init_identity(d_model) + total_params += np.prod(d_model*2) + print("%26s %26s %26s"%("layer%i.norm_self_attn"%i, (d_model,2), np.prod((d_model, 2)))) + + rng, layer.kqv, shape = linear_init_uniform(rng, d_model, d_model*3) + total_params += np.prod(shape) # omitting bias in calculation for now + print("%26s %26s %26s"%("layer%i.kqv"%i, shape, np.prod(shape))) + + layer.norm_ff = elementwise_linear_init_identity(d_model) + total_params += np.prod(d_model*2) + print("%26s %26s %26s"%("layer%i.norm_ff"%i, (d_model,2), np.prod((d_model, 2)))) + + rng, layer.ffn1, shape = linear_init_uniform(rng, d_model, d_ff) + total_params += np.prod(shape) + print("%26s %26s %26s"%("layer%i.ffn1"%i, shape, np.prod(shape))) + + rng, layer.ffn2, shape = linear_init_uniform(rng, d_ff, d_model) + total_params += np.prod(shape) + print("%26s %26s %26s"%("layer%i.ffn2"%i, shape, np.prod(shape))) + + params.layers.append(layer) + + # Final normalization and output layer + print("total: ", total_params) + + return rng, config, params, total_params + + +@partial(jax.jit, static_argnums=0) +def transformer(cfg, params, x: jnp.ndarray, position: jnp.ndarray, H_core: jnp.ndarray): + """ + cfg: Config, from transformer_init, holds hyperparameters + params: Current transformer parameters, initialized in init + x: 1D array of L integers, representing the input sequence + output: L x n_vocab logits + """ + L, = x.shape # x is just 1D. Vmap/pmap will handle batching + + embeddings = cfg.lambda_e * params.embeddings[x, :] # L x Dm + + # Add (learned) positional encodings + x = jnp.concatenate([embeddings[:, :-3], position], -1) + L, dm = x.shape + + # Apply the transformer layers + for layer in params.layers: + # Layer-normalize embeddings + #t1 = vmap(standardize)(embeddings) + t1 = elementwise_linear(layer.norm_self_attn, x) # L x Dm + + L, Dm = t1.shape + nheads = cfg.heads + qkv = linear(layer.kqv, t1)#.reshape(L, Dm, 3) + #q, k, v = [qkv[:, :, i].reshape(nheads, L, Dm//nheads) for i in range(3)] + q = jnp.transpose(qkv[:, 0*Dm:1*Dm].reshape(L, nheads, Dm//nheads), (1, 0, 2)) + k = jnp.transpose(qkv[:, 1*Dm:2*Dm].reshape(L, nheads, Dm//nheads), (1, 0, 2)) + v = jnp.transpose(qkv[:, 2*Dm:3*Dm].reshape(L, nheads, Dm//nheads), (1, 0, 2)) + score = (q @ jnp.transpose(k, (0, 2, 1))) / math.sqrt(Dm) + + if layer == 0: # doesn't look like it helps + score += H_core + + attn = jax.nn.softmax(score, axis=1) + x = x + (attn @ v).reshape(L, Dm) + + # Layer-normalize embeddings + #t2 = vmap(standardize)(embeddings) + t2 = elementwise_linear(layer.norm_ff, x) # L x Dm + + # Feedforward fully connected + t2 = linear(layer.ffn1, t2) # L x Dm*4 + t2 = jax.nn.gelu(t2) + t2 = linear(layer.ffn2, t2) # L x Dm + + # Add this layer's contribution into embeddings + x = x + t2 + + return score #attn #linear(params.output, embeddings) # L x n_vocab + + +import types +import json +import jax + +import numbers + +def is_simple_type(x): + return isinstance(x, (numbers.Number, bool, str)) + +@jax.tree_util.register_pytree_node_class +class ParamsDict(types.SimpleNamespace): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def tree_flatten(self): + return jax.tree_flatten(self.__dict__, lambda a: a is not self.__dict__) # only flatten one step + + @classmethod + def tree_unflatten(cls, aux, values): + return ParamsDict(**jax.tree_unflatten(aux, values)) + + def toJSON(self): + return json.dumps(self, default=lambda o: o.__dict__, + sort_keys=True, indent=4) + + def __hash__(self): + # Should overload setattr to warn if setattr is called after hash has been computed + return hash(tuple(hash(x) for (_,x) in self.__dict__.items())) + + def print(self, path = ''): + for (k,v) in self.items(path): + print(k + ':',v) + + @classmethod + def labels_aux(cls, path, obj): + if isinstance(obj, (list, tuple)) and any(not is_simple_type(x) for x in obj): + for i,vi in enumerate(obj): + yield from cls.labels_aux(f'{path}[{i}]', vi) + elif isinstance(obj, dict): + for (k,v) in obj.items(): + yield from cls.labels_aux(path + '/' + k, v) + elif isinstance(obj, ParamsDict): + yield from cls.labels_aux(path, obj.__dict__) + else: + yield (path, obj) + + def items(self, path = ''): + yield from self.labels_aux(path, self) + From 67e0e99164fa71f3ece11d81120702f8ebdcce6a Mon Sep 17 00:00:00 2001 From: Alexander Mathiasen Date: Mon, 4 Dec 2023 15:08:29 +0000 Subject: [PATCH 05/22] reproduce qm9 --- pyscf_ipu/direct/qm9/download.sh | 3 +++ pyscf_ipu/direct/qm9/reproduce.py | 42 +++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+) create mode 100755 pyscf_ipu/direct/qm9/download.sh create mode 100644 pyscf_ipu/direct/qm9/reproduce.py diff --git a/pyscf_ipu/direct/qm9/download.sh b/pyscf_ipu/direct/qm9/download.sh new file mode 100755 index 00000000..7332f63e --- /dev/null +++ b/pyscf_ipu/direct/qm9/download.sh @@ -0,0 +1,3 @@ +wget https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/gdb9.tar.gz +tar -xvzf gdb9.tar.gz +wget https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/qm9.csv diff --git a/pyscf_ipu/direct/qm9/reproduce.py b/pyscf_ipu/direct/qm9/reproduce.py new file mode 100644 index 00000000..d2b446e3 --- /dev/null +++ b/pyscf_ipu/direct/qm9/reproduce.py @@ -0,0 +1,42 @@ +# reproduce qm9 labels; run download.sh to download qm9 first. +from pyscf import gto, scf, dft, __config__ +import pyscf +import pandas as pd +print(pyscf.__version__) +df = pd.read_csv('qm9.csv') +qm9_energy = df['u0'][0] - df['zpve'][0] +qm9_hlgap = df['gap'][0] + +mol = gto.Mole() +mol.atom = ''' + C -0.0127 1.0858 0.0080 + H 0.0022 -0.0060 0.0020 + H 1.0117 1.4638 0.0003 + H -0.5408 1.4475 -0.8766 + H -0.5238 1.4379 0.9064 +''' +mol.basis = '6-31G(2df,p)' +mol.build() + +# Run B3LYP calculation +method = dft.RKS(mol) +method.verbose = 4 +method.xc = 'B3LYPG' # b3lypG (G as in gaussain) +method.max_cycle = 50 +method.DIIS = pyscf.scf.diis.CDIIS +method.small_rho_cutoff = 1e-10 +method.diis_space = 8 +method.diis_start_cycle = 1 +method.damp = 5e-1 # damping factor +method.conv_tol = 1e-9 +method.conv_tol_grad = None # 1e-9 +method.grids.level = 3 +method.kernel() + +# Get total energy and HOMO-LUMO gap +energy = method.e_tot +homo, lumo = method.mo_energy[method.mo_occ>0].max(), method.mo_energy[method.mo_occ==0].min() +hlgap = lumo - homo + +print('qm9\t %10f %10f'%(qm9_energy, qm9_hlgap)) +print('pyscf\t %10f %10f'%( energy, hlgap)) From 9efde5ccc8096795f1128a49c4b1d09bf802930f Mon Sep 17 00:00:00 2001 From: Alexander Mathiasen Date: Thu, 7 Dec 2023 09:42:59 +0000 Subject: [PATCH 06/22] Prepares tensors asynch/parallel using pytorch dataloader. --- .../direct/exchange_correlation/b3lyp.py | 5 +- pyscf_ipu/direct/sparse_symmetric_ERI.py | 5 - pyscf_ipu/direct/train.py | 423 +++++++++++++----- pyscf_ipu/direct/transformer.py | 11 +- 4 files changed, 310 insertions(+), 134 deletions(-) diff --git a/pyscf_ipu/direct/exchange_correlation/b3lyp.py b/pyscf_ipu/direct/exchange_correlation/b3lyp.py index 827236d3..77a3ad01 100644 --- a/pyscf_ipu/direct/exchange_correlation/b3lyp.py +++ b/pyscf_ipu/direct/exchange_correlation/b3lyp.py @@ -102,11 +102,14 @@ def __lda(rho): return -jnp.exp(1/3*jnp.log(rho) - 0.30305460484554375) CLIP_RHO_MIN = 1e-10 CLIP_RHO_MAX = 1e15 -def b3lyp(rho, EPSILON_B3LYP=0): +def _b3lyp(rho, EPSILON_B3LYP=0): rho0 = jnp.clip(rho[0], CLIP_RHO_MIN, CLIP_RHO_MAX) norms = jnp.linalg.norm(rho[1:4]*2+CLIP_RHO_MIN, axis=0).T**2+EPSILON_B3LYP return __lda(rho0)*0.08 + (__vwn(rho0)*0.19 + __b88(rho0, norms)*0.72 + __lyp(rho0, norms)*0.81) / rho0 +def b3lyp(rho0, norms, EPSILON_B3LYP=0): + return __lda(rho0)*0.08 + (__vwn(rho0)*0.19 + __b88(rho0, norms)*0.72 + __lyp(rho0, norms)*0.81) / rho0 + def vxc_b3lyp(rho, EPSILON_B3LYP=0): rho = jnp.concatenate([jnp.clip(rho[:1], CLIP_RHO_MIN, CLIP_RHO_MAX), rho[1:4]*2]) diff --git a/pyscf_ipu/direct/sparse_symmetric_ERI.py b/pyscf_ipu/direct/sparse_symmetric_ERI.py index 3efea154..f7646c5d 100644 --- a/pyscf_ipu/direct/sparse_symmetric_ERI.py +++ b/pyscf_ipu/direct/sparse_symmetric_ERI.py @@ -7,23 +7,18 @@ from icecream import ic HYB_B3LYP = 0.2 -@partial(jax.jit, backend="cpu") def get_i_j(val): - import jax.numpy as np i = (np.sqrt(1 + 8*val.astype(np.uint64)) - 1)//2 # no need for floor, integer division acts as floor. j = (((val - i) - (i**2 - val))//2) return i, j def ijkl(value, symmetry, N, f): - import jax.numpy as np i, j, k, l = value[0].astype(np.uint32), value[1].astype(np.uint32), value[2].astype(np.uint32), value[3].astype(np.uint32) return f(i,j,k,l,symmetry,N) ijkl = jax.vmap(ijkl, in_axes=(0, None, None, None)) -@partial(jax.jit, backend="cpu") def num_repetitions_fast(ij, kl): - import jax.numpy as np i, j = get_i_j(ij) k, l = get_i_j(kl) diff --git a/pyscf_ipu/direct/train.py b/pyscf_ipu/direct/train.py index 18b99e5c..830b4052 100644 --- a/pyscf_ipu/direct/train.py +++ b/pyscf_ipu/direct/train.py @@ -5,12 +5,20 @@ import pyscf import optax from icecream import ic -from exchange_correlation.b3lyp import b3lyp, vxc_b3lyp +from exchange_correlation.b3lyp import b3lyp, _b3lyp, vxc_b3lyp from tqdm import tqdm import time from transformer import transformer, transformer_init import pandas as pd + +# [ ] add train loss +# [ ] (algo) may get transformer to be 4x faster by fixing AO_grid stuff; +# [ ] (algo) can speed up dataloader by re-using AO_grid init (and ERI etc). +# [ ] (opt/ml) the learning rate schedule is weird when we do multiple steps per batch +# [ ] (opt/ml) validate with 10 molecules, not just 1. +# [ ] add dropout/dropres/... + cfg, HARTREE_TO_EV, EPSILON_B3LYP, HYB_B3LYP = None, 27.2114079527, 1e-20, 0.2 def T(x): return jnp.transpose(x, (0,2,1)) @@ -20,11 +28,11 @@ def T(x): return jnp.transpose(x, (0,2,1)) # Only need to recompute: L_inv, grid_AO, grid_weights, H_core, ERI and E_nuc. def dm_energy(W: BxNxK, state, normal, nn): if nn: - W = jnp.mean(jax.vmap(transformer, in_axes=(None, None, 0, 0, 0), out_axes=(0))(cfg, W, state.ao_types, state.pos, state.H_core) , axis=1) + W = jnp.mean(jax.vmap(transformer, in_axes=(None, None, 0, 0, 0), out_axes=(0))(cfg, \ + W, state.ao_types, state.pos, state.H_core) , axis=1) W = W @ state.init - - L_inv_Q: BxNxN = state.L_inv_T @ jnp.linalg.qr(W)[0] # O(B*N*K^2) FLOP O(B*N*K) FLOP/FLIO - density_matrix: BxNxN = 2 * L_inv_Q @ T(L_inv_Q) # O(B*N*K^2) FLOP/FLIO + L_inv_Q: BxNxN = state.L_inv_T @ jnp.linalg.qr(W)[0] # O(B*N*K^2) FLOP O(B*N*K) FLOP/FLIO + density_matrix: BxNxN = 2 * (L_inv_Q*state.mask) @ T(L_inv_Q) # O(B*N*K^2) FLOP/FLIO E_xc: B = exchange_correlation(density_matrix, state, normal) # O(B*gsize*N^2) FLOP O(gsize*N^2) FLIO diff_JK: BxNxN = JK(density_matrix, state, normal) # O(B*num_ERIs) FLOP O(num_ERIs) FLIO energies: B = E_xc + state.E_nuc + jnp.sum((density_matrix * (state.H_core + diff_JK/2)).reshape(W.shape[0], -1), axis=-1) @@ -36,26 +44,20 @@ def sparse_mult(values, dm, state, gsize): prod = in_*values[:, None] return jax.ops.segment_sum(prod, state.rows, gsize) -def exchange_correlation(density_matrix, state, normal): +def exchange_correlation(density_matrix: BxNxN, state, normal): _, _, gsize, N = state.grid_AO.shape B = density_matrix.shape[0] if normal: grid_AO_dm = (state.grid_AO[:, 0] @ density_matrix) # (B,gsize,N) @ (B, N, N) = O(B gsize N^2) rho = jnp.sum(grid_AO_dm * state.grid_AO, axis=3) # (B,1,gsize,N) * (B,4,gsize,N) = O(B gsize N) else: - main = state.main_grid_AO @ density_matrix # (1, gsize, N) @ (B, N, N) = O(B gsize N^2) FLOPs and O(gsize*N + N^2 +B * gsize * N) FLIOs - correction = jax.vmap(sparse_mult, in_axes=(0,0,None, None))(state.sparse_diffs_grid_AO, density_matrix, state, gsize) + main: BxGsizexN = state.main_grid_AO @ density_matrix # (1, gsize, N) @ (B, N, N) = O(B gsize N^2) FLOPs and O(gsize*N + N^2 +B * gsize * N) FLIOs + correction: BxGsizexN = jax.vmap(sparse_mult, in_axes=(0,0,None, None))(state.sparse_diffs_grid_AO, density_matrix, state, gsize) + rho_a = jnp.einsum("bpij,bqij->bpi", state.grid_AO, main.reshape(B,1,gsize,N)) + rho_b = jnp.einsum("bpij,bqij->bpi", state.grid_AO, correction.reshape(B,1,gsize,N)) + rho = rho_a - rho_b - # subtract before/after einsum. - if True: - grid_AO_dm = (main - correction).reshape(B, 1, gsize, N) # (B * gsize * N) - rho = jnp.einsum("bpij,bqij->bpi", state.grid_AO, grid_AO_dm) - else: - rho_a = jnp.einsum("bpij,bqij->bpi", state.grid_AO, main.reshape(B,1,gsize,N)) - rho_b = jnp.einsum("bpij,bqij->bpi", state.grid_AO, correction.reshape(B,1,gsize,N)) - rho = rho_a - rho_b - - E_xc = jax.vmap(b3lyp, in_axes=(0,None))(rho, EPSILON_B3LYP).reshape(B, gsize) + E_xc = jax.vmap(_b3lyp, in_axes=(0, None))(rho, EPSILON_B3LYP).reshape(B, gsize) E_xc = jnp.sum(rho[:, 0] * state.grid_weights * E_xc, axis=-1).reshape(B) return E_xc @@ -66,9 +68,6 @@ def JK(density_matrix, state, normal): diff_JK = J - K / 2 * HYB_B3LYP else: from sparse_symmetric_ERI import sparse_symmetric_einsum - # batched => flops = reads - #diff_JK = jax.vmap(sparse_symmetric_einsum, in_axes=(0, 0, 0))(state.nonzero_distinct_ERI, state.nonzero_indices, density_matrix) - # first + correction_remaining => floats = reads*batch_size diff_JK: BxNxN = jax.vmap(sparse_symmetric_einsum, in_axes=(None, None, 0))(state.nonzero_distinct_ERI[0], state.nonzero_indices[0], density_matrix) diff_JK: BxNxN = diff_JK - jax.vmap(sparse_symmetric_einsum, in_axes=(0, None, 0))(state.diffs_ERI, state.indxs, density_matrix) @@ -79,22 +78,42 @@ def nao(atom, basis): m.build() return m.nao_nr()//2 -def batched_state(mol_str, opts, bs, wiggle_num=0, do_pyscf=True): +def batched_state(mol_str, opts, bs, wiggle_num=0, + do_pyscf=True, extrapolate=False, validation=False, + pad_electrons=40, + pad_diff_ERIs=50000, + pad_distinct_ERIs=200000, + pad_grid_AO=22000, + pad_nonzero_distinct_ERI=200000, + pad_sparse_diff_grid=200000, + ): + if opts.wandb: import wandb + #pad_electrons, pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = \ + # -1, -1, -1, -1, -1, -1 t0 = time.time() - state = init_dft(mol_str, opts, do_pyscf=do_pyscf) + state = init_dft(mol_str, opts, do_pyscf=do_pyscf, pad_electrons=pad_electrons) c, w = state.grid_coords, state.grid_weights - np.random.seed(42) p = np.array(mol_str[0][1]) states = [state] - for i in tqdm(range(bs-1)): - x = p + np.random.normal(0, opts.wiggle_var, (3)) + #for i in tqdm(range(bs-1)): + for i in range(bs-1): + if opts.benzene: + x = p + x[2] += np.random.normal(0, opts.wiggle_var*(1+ (extrapolate and i > bs//2)), (1)) + else: + x = p + np.random.normal(0, opts.wiggle_var, (3)) + + if extrapolate and i > bs//2: + x = p + np.random.normal(0, opts.wiggle_var, (3)) + mol_str[wiggle_num][1] = (x[0], x[1], x[2]) # when profiling create fake molecule to skip waiting - if i == 0 or not opts.prof: - stateB = init_dft(mol_str, opts, c, w, do_pyscf=do_pyscf and i < 5) + if i == 0 or not opts.prof: + stateB = init_dft(mol_str, opts, c, w, do_pyscf=do_pyscf and i < 2, state=state, + pad_electrons=pad_electrons) states.append(stateB) @@ -103,7 +122,7 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, do_pyscf=True): # Compute ERI sparsity. nonzero = [] - for e,i in zip(state.nonzero_distinct_ERI, state.nonzero_indices): + for e, i in zip(state.nonzero_distinct_ERI, state.nonzero_indices): abs = np.abs(e) indxs = abs < opts.eri_threshold #1e-10 e[indxs] = 0 @@ -146,6 +165,7 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, do_pyscf=True): main_grid_AO = state.grid_AO[:1] diffs_grid_AO = main_grid_AO - state.grid_AO rows, cols = np.nonzero(np.max(diffs_grid_AO[:, 0]!=0, axis=0)) + sparse_diffs_grid_AO = diffs_grid_AO[:, 0, rows,cols] # use the same sparsity pattern across a batch. @@ -164,7 +184,19 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, do_pyscf=True): diff_ERIs = diff_ERIs.reshape(bs, batches, -1) diff_indxs = diff_indxs.reshape(batches, -1, 4) - state.indxs=diff_indxs + if pad_diff_ERIs == -1: + state.indxs=diff_indxs + state.diffs_ERI=diff_ERIs + else: + # pad ERIs with 0 and indices with -1 so they point to 0. + assert diff_indxs.shape[1] == diff_ERIs.shape[2] + pad = pad_diff_ERIs - diff_indxs.shape[1] + assert pad > 0 + state.indxs = np.pad(diff_indxs, ((0,0), (0, pad), (0, 0)), 'constant', constant_values=(-1)) + state.diffs_ERI = np.pad(diff_ERIs, ((0,0), (0, 0), (0, pad))) # pad zeros + + if opts.wandb: wandb.log({"pad_diff_ERIs": pad/diff_ERIs.shape[2]}) + state.rows=rows state.cols=cols @@ -172,11 +204,45 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, do_pyscf=True): state.sparse_diffs_grid_AO = sparse_diffs_grid_AO state.diffs_grid_AO = diffs_grid_AO - state.diffs_ERI=diff_ERIs + + if pad_sparse_diff_grid != -1: + assert state.sparse_diffs_grid_AO.shape[1] == state.rows.shape[0] + assert state.sparse_diffs_grid_AO.shape[1] == state.cols.shape[0] + pad = pad_sparse_diff_grid - state.rows.shape[0] + assert pad >= 0 + state.rows = np.pad(state.rows, (0,pad)) + state.cols = np.pad(state.cols, (0,pad)) + state.sparse_diffs_grid_AO = np.pad(state.sparse_diffs_grid_AO, ((0,0),(0,pad))) + + if opts.wandb: wandb.log({"pad_sparse_diff_grid": pad/state.sparse_diffs_grid_AO.shape[1]}) #state.grid_AO = state.grid_AO[:1] state.nonzero_distinct_ERI = state.nonzero_distinct_ERI[:1] state.nonzero_indices = np.expand_dims(state.nonzero_indices, axis=0) + if pad_distinct_ERIs != -1: + assert state.nonzero_distinct_ERI.shape[2] == state.nonzero_indices.shape[2] + pad = pad_distinct_ERIs - state.nonzero_distinct_ERI.shape[2] + assert pad > 0, (pad_distinct_ERIs, state.nonzero_distinct_ERI.shape[2]) + state.nonzero_indices = np.pad(state.nonzero_indices, ((0,0), (0,0), (0, pad), (0,0)), 'constant', constant_values=(-1)) + state.nonzero_distinct_ERI = np.pad(state.nonzero_distinct_ERI, ((0,0), (0,0), (0, pad))) # pad zeros + + if opts.wandb: wandb.log({"pad_distinct_ERIs": pad/state.nonzero_distinct_ERI.shape[2]}) + + if pad_grid_AO != -1: + assert state.grid_AO.shape[2] == state.grid_weights.shape[1] + assert state.grid_AO.shape[2] == state.grid_coords.shape[1] + assert state.grid_AO.shape[2] == state.main_grid_AO.shape[1] + assert state.grid_AO.shape[2] == state.diffs_grid_AO.shape[2] + pad = pad_grid_AO - state.grid_AO.shape[2] + assert pad > 0 + state.grid_AO = np.pad(state.grid_AO, ((0,0),(0,0), (0,pad), (0,0))) + state.grid_weights = np.pad(state.grid_weights, ((0,0),(0,pad))) + state.grid_coords = np.pad(state.grid_coords, ((0,0),(0,pad),(0,0))) + state.main_grid_AO = np.pad(state.main_grid_AO, ((0,0),(0,pad),(0,0))) + state.diffs_grid_AO = np.pad(state.diffs_grid_AO, ((0,0),(0,0),(0,pad),(0,0))) + + if opts.wandb: wandb.log({"pad_grid_AO": pad/state.grid_AO.shape[2]}) + indxs = np.abs(state.nonzero_distinct_ERI ) > 1e-9 state.nonzero_distinct_ERI = state.nonzero_distinct_ERI[indxs] @@ -190,13 +256,17 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, do_pyscf=True): state.nonzero_distinct_ERI = state.nonzero_distinct_ERI.reshape(1, batches, -1) state.nonzero_indices = state.nonzero_indices.reshape(1, batches, -1, 4) - print("batch: ", time.time()-t0) - return state + if pad_nonzero_distinct_ERI != -1: + assert state.nonzero_distinct_ERI.shape[2] == state.nonzero_indices.shape[2] + pad = pad_nonzero_distinct_ERI - state.nonzero_distinct_ERI.shape[2] + assert pad >= 0, (pad_nonzero_distinct_ERI, state.nonzero_distinct_ERI.shape[2]) + state.nonzero_distinct_ERI = np.pad(state.nonzero_distinct_ERI, ((0,0),(0,0),(0,pad))) + state.nonzero_indices = np.pad(state.nonzero_indices, ((0,0),(0,0),(0,pad), (0,0)), 'constant', constant_values=(-1)) + if opts.wandb: wandb.log({"pad_grid_AO": pad/state.grid_AO.shape[2]}) -# todo: wiggle asynch each 10 steps or so. -def wiggle(state): # idea: keep bs=64 state, but add say another 64 wiggles which change. - pass + + return state def nanoDFT(mol_str, opts): @@ -205,23 +275,9 @@ def nanoDFT(mol_str, opts): # This consists of DFT tensors initialized with PySCF/CPU. np.random.seed(42) - - # Step 1. - # Take benzene; randomly sample first carbon. - # Get model generalize to new randomly sampled carbons. - # - # Step 2. - # Change benzene to have CNOF. - # Do single atom wiggles on CNOF, get model to generalize to simultaneous wiggles. - # - # Step 3/4. Scale grid-size and basis set. - - # Data Creation / Data loader - # [ ] Move to dataloader; whenever a new data point is ready, use that. - # [ ] Consider pre-compute/load. - # [ ] Consider 'wiggle' function which generate new wiggle datapoints in state. - val_state = batched_state(mol_str[0], opts, opts.val_bs, do_pyscf=True) - states = [batched_state(mol_str[0], opts, opts.bs, do_pyscf=True)] + [batched_state(mol_str[i], opts, opts.bs, do_pyscf=False) for i in range(opts.states-1)] + if opts.wandb: + import wandb + wandb.init(project='ndft') rnd_key = jax.random.PRNGKey(42) n_vocab = nao("C", opts.basis) + nao("N", opts.basis) + \ @@ -265,51 +321,108 @@ def custom_schedule(step_num, d_model, warmup_steps): arg2 = step_num * warmup_steps ** -1.5 return d_model ** -0.5 * min(arg1, arg2) - optimizer = optax.adam(learning_rate=lambda step: custom_schedule(step, d_model=d_model, warmup_steps=4000), + optimizer = optax.adam(learning_rate = opts.lr,#)lambda step: custom_schedule(step, d_model=d_model, warmup_steps=4000), b1=0.9, b2=0.98, eps=1e-9) adam = optax.adam(opts.lr) w = params + + from torch.utils.data import DataLoader, Dataset + class OnTheFlyQM9(Dataset): + # prepares dft tensors with pyscf "on the fly". + # dataloader is very keen on throwing segfaults (e.g. using jnp in dataloader throws segfaul). + # problem: second epoch always gives segfault. + # hacky fix; make __len__ = real_length*num_epochs and __getitem__ do idx%real_num_examples + def __init__(self, opts, nao=294, train=True, num_epochs=1000): + # only take molecules with use {CNOFH}, nao=nao and spin=0. + df = pd.read_pickle("alchemy/processed_atom_9.pickle") # spin=0 and only CNOFH molecules + if nao != -1: df = df[df["nao"]==nao] + df = df.sample(frac=1).reset_index(drop=True) + + if train: self.mol_strs = df["pyscf"].values[:-10] + else: self.mol_strs = df["pyscf"].values[-10:] + self.num_epochs = num_epochs + self.opts = opts + self.validation = not train + + def __len__(self): + return len(self.mol_strs)*self.num_epochs + + def __getitem__(self, idx): + return batched_state(self.mol_strs[idx%len(self.mol_strs)], self.opts, self.opts.bs, wiggle_num=0, do_pyscf=self.validation, extrapolate=False, validation=False) + + qm9 = OnTheFlyQM9(opts, train=True) + train_dataloader = DataLoader(qm9, batch_size=1, pin_memory=True, shuffle=False, drop_last=True, num_workers=5, prefetch_factor=2, collate_fn=lambda x: x[0]) + pbar = tqdm(train_dataloader) + val_qm9 = OnTheFlyQM9(opts, train=False) + else: + states = [batched_state(mol_str[0], opts, opts.bs, do_pyscf=True)] + [batched_state(mol_str[i], opts, opts.bs, do_pyscf=False) for i in range(opts.states-1)] + class DummyIterator: + def __init__(self, item): self.item = item + def __iter__(self): return self + def __next__(self): return self.item + train_dataloader = DummyIterator(states[0]) + pbar = tqdm(train_dataloader) w = states[0].init adam = optax.adabelief(opts.lr) + summary(states[0]) vandg = jax.jit(jax.value_and_grad(dm_energy, has_aux=True), backend=opts.backend, static_argnames=("normal", 'nn')) valf = jax.jit(dm_energy, backend=opts.backend, static_argnames=("normal", 'nn')) - - # Build initializers for params adam_state = adam.init(w) - min_val = 0 - min_dm = 0 - - pbar = tqdm(range(opts.steps)) - - summary(states[0]) - - mins = np.ones(opts.bs) * 1e6 - if opts.wandb: - import wandb - wandb.init(project='ndft') - wandb.log({'total_params': total_params}) + if opts.wandb: wandb.log({'total_params': total_params, 'batch_size': opts.bs, 'lr': opts.lr }) - print("jitting...") - t0 = time.time() - (val, (vals, E_xc, density_matrix)), grad = vandg(w, states[0], opts.normal, opts.nn) - print("done!", time.time()-t0) - - def update(w, state, adam_state): + print("rejitting... (if this is printed more than once something is wrong!)") (val, (vals, E_xc, density_matrix)), grad = vandg(w, state, opts.normal, opts.nn) updates, adam_state = adam.update(grad, adam_state) w = optax.apply_updates(w, updates) return w, vals, density_matrix, adam_state - update = jax.jit(update, backend=opts.backend) - for i in pbar: - state = states[i%opts.states] - w, vals, density_matrix, adam_state = update(w, state, adam_state) + min_val, min_dm, mins, val_state, valid_str, step = 0, 0, np.ones(opts.bs)*1e6, None, "", 0 + t0, load_time, train_time, val_time = time.time(), 0, 0, 0 + + # training loop (epochs are handled inside dataloader due to segfault). + for i, state in enumerate(pbar): + load_time, t0 = time.time()-t0, time.time() + + for j in range(16): + if j == 0: _t0 =time.time() + w, vals, density_matrix, adam_state = update(w, state, adam_state) + if j == 0: time_step1 = time.time()-_t0 + step += 1 + if not opts.nn: + str = "error=" + "".join(["%.7f "%(vals[i]*HARTREE_TO_EV-state.pyscf_E[i]) for i in range(2)]) + " [eV]" + pbar.set_description(str) + else: + pbar.set_description("train=".join(["%.5f"%i for i in vals[:2]]) + " " + valid_str + "time=%.1f %.1f %.1f %.1f"%(load_time, time_step1, train_time, val_time)) + + train_time, t0 = time.time()-t0, time.time() + + if opts.nn:# and i % 5 == 0: + if val_state is None: val_state = val_qm9[0] + _, (valid_vals, _, _) = valf(w, val_state, opts.normal, opts.nn) + valid_str = "val_error=" + "".join(["%.7f "%(valid_vals[i]*HARTREE_TO_EV-val_state.pyscf_E[i]) for i in range(0, 3)]) + " [eV]" + if opts.wandb: + dct = {} + for i in range(0, opts.val_bs): + dct['valid_l%i'%i ] = valid_vals[i]*HARTREE_TO_EV-val_state.pyscf_E[i] + dct["lr"] = custom_schedule(step, d_model, 4000) + wandb.log(dct) + + valid_time, t0 = time.time()-t0, time.time() + continue + + # move to train loop; this is done inside dataloader multithreaded + '''if opts.wandb and i < 6: + from plot import create_rdkit_mol + import wandb + str = [mol_str[j][0] for j in range(len(mol_str))] + pos = np.concatenate([np.array(mol_str[j][1]).reshape(1, 3) for j in range(len(mol_str))]) + wandb.log({"%s_mol_%i"%({True: "valid", False: "train"}[validation], i): create_rdkit_mol(str, pos) })''' # valid if i % 10 == 0 and opts.nn: @@ -339,6 +452,8 @@ def update(w, state, adam_state): #str += "E_xc=" + "".join(["%.7f "%(E_xc[i]*HARTREE_TO_EV) for i in range(opts.bs)]) + " [eV]" try: mins = np.minimum(mins, np.abs(vals*HARTREE_TO_EV - state.pyscf_E[:, 0])) + if np.max(mins) < 1e-5: + break str += " best=" + "".join(["%.7f "%(mins[i]) for i in range(2)]) + " [eV]" except: pass @@ -367,6 +482,7 @@ def update(w, state, adam_state): import chex @chex.dataclass class IterationState: + mask: np.array init: np.array E_nuc: np.array L_inv: np.array @@ -424,23 +540,13 @@ def original_becke(g): g = (3 - g**2) * g * .5 return g -def get_partition( - mol, - atom_coords, - atom_grids_tab, - radii_adjust=treutler_atomic_radii_adjust, - atomic_radii=radi.BRAGG_RADII, - becke_scheme=original_becke, - concat=True -): - atm_dist = inter_distance(atom_coords) # [natom, natom] - - def gen_grid_partition(coords): +def gen_grid_partition(coords, atom_coords, natm, atm_dist, elements, + atomic_radii, becke_scheme=original_becke,): ngrids = coords.shape[0] dc = coords[None] - atom_coords[:, None] grid_dist = np.sqrt(np.einsum('ijk,ijk->ij', dc, dc)) # [natom, ngrid] - ix, jx = np.tril_indices(mol.natm, k=-1) + ix, jx = np.tril_indices(natm, k=-1) natm, ngrid = grid_dist.shape #g_ = -1 / atm_dist.reshape(natm, natm, 1) * (grid_dist.reshape(1, natm, ngrid) - grid_dist.reshape(natm, 1, ngrid)) @@ -449,7 +555,7 @@ def gen_grid_partition(coords): def pbecke_g(i, j): g = g_[i, j] - charges = [elements_proton(x) for x in mol.elements] + charges = [elements_proton(x) for x in elements] rad = np.sqrt(atomic_radii[charges]) + 1e-200 rr = rad.reshape(-1, 1) * (1. / rad) a = .25 * (rr.T - rr) @@ -466,32 +572,65 @@ def pbecke_g(i, j): gp2 = (1+g)/2 gm2 = (1-g)/2 - pbecke = jnp.ones((mol.natm, ngrids)) # [natom, ngrid] + t0 = time.time() + #pbecke = f(gm2, gp2, natm, ngrids, ix, jx ) + pbecke = np.ones((natm, ngrids)) + c = 0 + # this goes up to n choose two + for i in range(natm): + for j in range(i): + pbecke[i] *= gm2[c] + pbecke[j] *= gp2[c] + c += 1 + #print("\t", time.time()-t0) + return pbecke + +from functools import partial +@partial(jax.jit, backend="gpu", static_argnums=(2,3)) +def f(gm2, gp2, natm, ngrids, ix, jx): + pbecke = jnp.ones((natm, ngrids)) # [natom, ngrid] pbecke = pbecke.at[ix].mul(gm2) pbecke = pbecke.at[jx].mul(gp2) + return pbecke - return pbecke + +def get_partition( + mol, + atom_coords, + atom_grids_tab, + radii_adjust=treutler_atomic_radii_adjust, + atomic_radii=radi.BRAGG_RADII, + becke_scheme=original_becke, + concat=True, state=None +): + t0 = time.time() + atm_dist = inter_distance(atom_coords) # [natom, natom] coords_all = [] weights_all = [] + + # [ ] consider another grid? for ia in range(mol.natm): coords, vol = atom_grids_tab[mol.atom_symbol(ia)] coords = coords + atom_coords[ia] # [ngrid, 3] - pbecke = gen_grid_partition(coords) # [natom, ngrid] - weights = vol * pbecke[ia] / jnp.sum(pbecke, axis=0) + pbecke = gen_grid_partition(coords, atom_coords, mol.natm, atm_dist, mol.elements, atomic_radii) # [natom, ngrid] + weights = vol * pbecke[ia] / np.sum(pbecke, axis=0) coords_all.append(coords) weights_all.append(weights) if concat: - coords_all = jnp.vstack(coords_all) - weights_all = jnp.hstack(weights_all) + coords_all = np.vstack(coords_all) + weights_all = np.hstack(weights_all) + + coords = (coords_all, weights_all) return coords_all, weights_all class DifferentiableGrids(gen_grid.Grids): """Differentiable alternative to the original pyscf.gen_grid.Grids.""" - def build(self, atom_coords) : + def build(self, atom_coords, state=None) : + t0 = time.time() mol = self.mol atom_grids_tab = self.gen_atomic_grids( @@ -505,7 +644,9 @@ def build(self, atom_coords) : treutler_atomic_radii_adjust, self.atomic_radii, original_becke, + state=state, ) + self.coords = coords self.weights = weights return coords, weights @@ -522,7 +663,8 @@ def grids_from_pyscf_mol( return grids, weights -def init_dft(mol_str, opts, _coords=None, _weights=None, first=False, do_pyscf=True): +def init_dft(mol_str, opts, _coords=None, _weights=None, first=False, do_pyscf=True, state=None, pad_electrons=-1): + #t0 = time.time() mol = build_mol(mol_str, opts.basis) if do_pyscf: pyscf_E, pyscf_hlgap, pycsf_forces = reference(mol_str, opts) else: pyscf_E, pyscf_hlgap, pyscf_forces = np.zeros(1), np.zeros(1), np.zeros(1) @@ -536,7 +678,7 @@ def init_dft(mol_str, opts, _coords=None, _weights=None, first=False, do_pyscf=T grids = DifferentiableGrids(mol) grids.level = opts.level #grids.build() - grids.build(np.concatenate([np.array(a[1]).reshape(1, 3) for a in mol._atom])) + grids.build(np.concatenate([np.array(a[1]).reshape(1, 3) for a in mol._atom]), state=state) grid_weights = grids.weights # (grid_size,) = (45624,) for C6H6 grid_coords = grids.coords @@ -549,7 +691,15 @@ def init_dft(mol_str, opts, _coords=None, _weights=None, first=False, do_pyscf=T O = mol.intor_symmetric('int1e_ovlp') # (N,N) L = np.linalg.cholesky(O) L_inv = np.linalg.inv(L) # (N,N) - init = np.eye(N)[:, :n_electrons_half] + + if pad_electrons == -1: + init = np.eye(N)[:, :n_electrons_half] + mask = np.ones((1, n_electrons_half)) + else: + assert pad_electrons > n_electrons_half + init = np.eye(N)[:, :pad_electrons] + mask = np.zeros((1, pad_electrons)) + mask[:, :n_electrons_half] = 1 if opts.normal: ERI = mol.intor("int2e_sph") @@ -616,25 +766,25 @@ def e(x): return np.expand_dims(x, axis=0) pos=e(pos), ao_types=e(ao_types), init = e(init), - E_nuc=e(E_nuc), - ERI=e(ERI), - nonzero_distinct_ERI=[nonzero_distinct_ERI], - nonzero_indices=[0], - H_core=e(nuclear+kinetic), - L_inv=e(L_inv), - L_inv_T = e(L_inv.T), - grid_AO=e(grid_AO), - grid_weights=e(grid_weights), - grid_coords=e(grid_coords), - pyscf_E=e(pyscf_E[-1:]), - N=e(mol.nao_nr()), - ) + E_nuc=e(E_nuc), + ERI=e(ERI), + nonzero_distinct_ERI=[nonzero_distinct_ERI], + nonzero_indices=[0], + H_core=e(nuclear+kinetic), + L_inv=e(L_inv), + L_inv_T = e(L_inv.T), + grid_AO=e(grid_AO), + grid_weights=e(grid_weights), + grid_coords=e(grid_coords), + pyscf_E=e(pyscf_E[-1:]), + N=e(mol.nao_nr()), + mask=e(mask), + ) return state - def summary(state): if state is None: return print("_"*100) @@ -842,30 +992,35 @@ def reference(mol_str, opts): parser.add_argument('-backend', type=str, default="cpu") parser.add_argument('-lr', type=float, default=2.5e-4) parser.add_argument('-steps', type=int, default=100000) - parser.add_argument('-bs', type=int, default=2) - parser.add_argument('-val_bs', type=int, default=4) + parser.add_argument('-bs', type=int, default=8) + parser.add_argument('-val_bs', type=int, default=8) parser.add_argument('-normal', action="store_true") parser.add_argument('-wandb', action="store_true") parser.add_argument('-prof', action="store_true") parser.add_argument('-visualize', action="store_true") parser.add_argument('-skip', action="store_true", help="skip pyscf test case") - parser.add_argument('-repeats', type=int, default=1, help="times to repeat molecule") # dataset parser.add_argument('-benzene', action="store_true") + parser.add_argument('-hydrogens', action="store_true") + parser.add_argument('-water', action="store_true") + parser.add_argument('-waters', action="store_true") parser.add_argument('-states', type=int, default=1) - parser.add_argument('-wiggle_var', type=float, default=1.0, help="wiggle N(0, wiggle_var)") + parser.add_argument('-wiggle_var', type=float, default=0.3, help="wiggle N(0, wiggle_var)") parser.add_argument('-eri_threshold', type=float, default=1e-10, help="loss function threshold only") # models - parser.add_argument('-nn', action="store_true", help="train nn, defaults to GD") + parser.add_argument('-nn', action="store_true", help="train nn, defaults to GD") parser.add_argument('-tiny', action="store_true") - parser.add_argument('-small', action="store_true") + parser.add_argument('-small', action="store_true") parser.add_argument('-base', action="store_true") opts = parser.parse_args() if opts.tiny or opts.small or opts.base: opts.nn = True + args_dict = vars(opts) + print(args_dict) + if True: df = pd.read_pickle("alchemy/atom_9.pickle") df = df[df["spin"] == 0] # only consider spin=0 @@ -887,6 +1042,30 @@ def reference(mol_str, opts): ["H", ( 3.2000, 1.2124, 0.0000)], ["H", ( 1.9500, 3.3775, 0.0000)] ]] + # hydrogens + if opts.hydrogens: + mol_strs = [[ + ["H", ( 0.0000, 0.0000, 0.0000)], + ["H", ( 1.4000, 0.0000, 0.0000)], + ]] + if opts.water: + mol_strs = [[ + ["O", ( 0.0000, 0.0000, 0.0000)], + ["H", ( 0.0000, 1.4000, 0.0000)], + ["H", ( 1.4000, 0.0000, 0.0000)], + ]] + if opts.waters: + mol_strs = [[ + ["O",(-0.1858140, -1.1749469, 0.7662596)], + ["H",(-0.1285513, -0.8984365, 1.6808606)], + ["H",(-0.0582782, -0.3702550, 0.2638279)], + ["O",( 0.1747051, 1.1050002, -0.7244430)], + ["H",(-0.5650842, 1.3134964, -1.2949455)], + ["H",( 0.9282185, 1.0652990, -1.3134026)], ]] + + + + nanoDFT_E, (nanoDFT_hlgap, mo_energy, mo_coeff, grid_coords, grid_weights, dm, H) = nanoDFT(mol_strs, opts) diff --git a/pyscf_ipu/direct/transformer.py b/pyscf_ipu/direct/transformer.py index 5486b1da..6d127476 100644 --- a/pyscf_ipu/direct/transformer.py +++ b/pyscf_ipu/direct/transformer.py @@ -57,7 +57,6 @@ def transformer_init( total_params += np.prod(params.embeddings.shape) print("%26s %26s %26s"%("params.embeddings",params.embeddings.shape, np.prod(params.embeddings.shape))) - # For transformer layers params.layers = [] for i in range(n_layers): @@ -109,7 +108,7 @@ def transformer(cfg, params, x: jnp.ndarray, position: jnp.ndarray, H_core: jnp. # Apply the transformer layers for layer in params.layers: # Layer-normalize embeddings - #t1 = vmap(standardize)(embeddings) + t1 = vmap(standardize)(embeddings) t1 = elementwise_linear(layer.norm_self_attn, x) # L x Dm L, Dm = t1.shape @@ -128,7 +127,7 @@ def transformer(cfg, params, x: jnp.ndarray, position: jnp.ndarray, H_core: jnp. x = x + (attn @ v).reshape(L, Dm) # Layer-normalize embeddings - #t2 = vmap(standardize)(embeddings) + t2 = vmap(standardize)(embeddings) t2 = elementwise_linear(layer.norm_ff, x) # L x Dm # Feedforward fully connected @@ -139,7 +138,7 @@ def transformer(cfg, params, x: jnp.ndarray, position: jnp.ndarray, H_core: jnp. # Add this layer's contribution into embeddings x = x + t2 - return score #attn #linear(params.output, embeddings) # L x n_vocab + return score import types @@ -157,11 +156,11 @@ def __init__(self, **kwargs): super().__init__(**kwargs) def tree_flatten(self): - return jax.tree_flatten(self.__dict__, lambda a: a is not self.__dict__) # only flatten one step + return jax.tree_util.tree_flatten(self.__dict__, lambda a: a is not self.__dict__) # only flatten one step @classmethod def tree_unflatten(cls, aux, values): - return ParamsDict(**jax.tree_unflatten(aux, values)) + return ParamsDict(**jax.tree_util.tree_unflatten(aux, values)) def toJSON(self): return json.dumps(self, default=lambda o: o.__dict__, From 6799289c470944ffa35fcd2d6ec9e36fe6a34868 Mon Sep 17 00:00:00 2001 From: Alexander Mathiasen Date: Thu, 7 Dec 2023 09:45:30 +0000 Subject: [PATCH 07/22] Removed extrapolate. --- pyscf_ipu/direct/train.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/pyscf_ipu/direct/train.py b/pyscf_ipu/direct/train.py index 830b4052..55f36aea 100644 --- a/pyscf_ipu/direct/train.py +++ b/pyscf_ipu/direct/train.py @@ -79,7 +79,7 @@ def nao(atom, basis): return m.nao_nr()//2 def batched_state(mol_str, opts, bs, wiggle_num=0, - do_pyscf=True, extrapolate=False, validation=False, + do_pyscf=True, validation=False, pad_electrons=40, pad_diff_ERIs=50000, pad_distinct_ERIs=200000, @@ -101,20 +101,15 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, for i in range(bs-1): if opts.benzene: x = p - x[2] += np.random.normal(0, opts.wiggle_var*(1+ (extrapolate and i > bs//2)), (1)) + x[2] += np.random.normal(0, opts.wiggle_var, (1)) else: x = p + np.random.normal(0, opts.wiggle_var, (3)) - if extrapolate and i > bs//2: - x = p + np.random.normal(0, opts.wiggle_var, (3)) - mol_str[wiggle_num][1] = (x[0], x[1], x[2]) # when profiling create fake molecule to skip waiting - if i == 0 or not opts.prof: - stateB = init_dft(mol_str, opts, c, w, do_pyscf=do_pyscf and i < 2, state=state, - pad_electrons=pad_electrons) - + if i == 0 or not opts.prof: stateB = init_dft(mol_str, opts, c, w, do_pyscf=do_pyscf and i < 2, state=state, pad_electrons=pad_electrons) + states.append(stateB) state = cats(states) @@ -349,7 +344,7 @@ def __len__(self): return len(self.mol_strs)*self.num_epochs def __getitem__(self, idx): - return batched_state(self.mol_strs[idx%len(self.mol_strs)], self.opts, self.opts.bs, wiggle_num=0, do_pyscf=self.validation, extrapolate=False, validation=False) + return batched_state(self.mol_strs[idx%len(self.mol_strs)], self.opts, self.opts.bs, wiggle_num=0, do_pyscf=self.validation, validation=False) qm9 = OnTheFlyQM9(opts, train=True) train_dataloader = DataLoader(qm9, batch_size=1, pin_memory=True, shuffle=False, drop_last=True, num_workers=5, prefetch_factor=2, collate_fn=lambda x: x[0]) From 0481471df8df7e7621977b67c3e9f84b273b7170 Mon Sep 17 00:00:00 2001 From: Alexander Mathiasen Date: Thu, 7 Dec 2023 10:09:12 +0000 Subject: [PATCH 08/22] fixed benzene padding. --- pyscf_ipu/direct/train.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/pyscf_ipu/direct/train.py b/pyscf_ipu/direct/train.py index 55f36aea..9f2af491 100644 --- a/pyscf_ipu/direct/train.py +++ b/pyscf_ipu/direct/train.py @@ -88,8 +88,12 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, pad_sparse_diff_grid=200000, ): if opts.wandb: import wandb - #pad_electrons, pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = \ - # -1, -1, -1, -1, -1, -1 + + # don't pad for individual molecules like benzene + if opts.benzene: + pad_electrons, pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = \ + -1, -1, -1, -1, -1, -1 + t0 = time.time() state = init_dft(mol_str, opts, do_pyscf=do_pyscf, pad_electrons=pad_electrons) c, w = state.grid_coords, state.grid_weights @@ -340,6 +344,22 @@ def __init__(self, opts, nao=294, train=True, num_epochs=1000): self.opts = opts self.validation = not train + self.bezene = [[ + ["C", ( 0.0000, 0.0000, 0.0000)], + ["C", ( 1.4000, 0.0000, 0.0000)], + ["C", ( 2.1000, 1.2124, 0.0000)], + ["C", ( 1.4000, 2.4249, 0.0000)], + ["C", ( 0.0000, 2.4249, 0.0000)], + ["C", (-0.7000, 1.2124, 0.0000)], + ["H", (-0.5500, -0.9526, 0.0000)], + ["H", (-0.5500, 3.3775, 0.0000)], + ["H", ( 1.9500, -0.9526, 0.0000)], + ["H", (-1.8000, 1.2124, 0.0000)], + ["H", ( 3.2000, 1.2124, 0.0000)], + ["H", ( 1.9500, 3.3775, 0.0000)] + ]] + + def __len__(self): return len(self.mol_strs)*self.num_epochs From 15c64cc9a3b445b1a3067549bce0c5e8d1eadbfc Mon Sep 17 00:00:00 2001 From: Alexander Mathiasen Date: Thu, 7 Dec 2023 14:00:56 +0000 Subject: [PATCH 09/22] added water rotation and fixed random seed bug. --- pyscf_ipu/direct/train.py | 160 +++++++++++++++++++++++--------- pyscf_ipu/direct/transformer.py | 17 +++- 2 files changed, 129 insertions(+), 48 deletions(-) diff --git a/pyscf_ipu/direct/train.py b/pyscf_ipu/direct/train.py index 9f2af491..ba246d3a 100644 --- a/pyscf_ipu/direct/train.py +++ b/pyscf_ipu/direct/train.py @@ -11,7 +11,6 @@ from transformer import transformer, transformer_init import pandas as pd - # [ ] add train loss # [ ] (algo) may get transformer to be 4x faster by fixing AO_grid stuff; # [ ] (algo) can speed up dataloader by re-using AO_grid init (and ERI etc). @@ -28,8 +27,8 @@ def T(x): return jnp.transpose(x, (0,2,1)) # Only need to recompute: L_inv, grid_AO, grid_weights, H_core, ERI and E_nuc. def dm_energy(W: BxNxK, state, normal, nn): if nn: - W = jnp.mean(jax.vmap(transformer, in_axes=(None, None, 0, 0, 0), out_axes=(0))(cfg, \ - W, state.ao_types, state.pos, state.H_core) , axis=1) + W = jax.vmap(transformer, in_axes=(None, None, 0, 0, 0), out_axes=(0))(cfg, \ + W, state.ao_types, state.pos, state.H_core) W = W @ state.init L_inv_Q: BxNxN = state.L_inv_T @ jnp.linalg.qr(W)[0] # O(B*N*K^2) FLOP O(B*N*K) FLOP/FLIO density_matrix: BxNxN = 2 * (L_inv_Q*state.mask) @ T(L_inv_Q) # O(B*N*K^2) FLOP/FLIO @@ -37,7 +36,7 @@ def dm_energy(W: BxNxK, state, normal, nn): diff_JK: BxNxN = JK(density_matrix, state, normal) # O(B*num_ERIs) FLOP O(num_ERIs) FLIO energies: B = E_xc + state.E_nuc + jnp.sum((density_matrix * (state.H_core + diff_JK/2)).reshape(W.shape[0], -1), axis=-1) energy: float = jnp.sum(energies) - return energy, (energies, E_xc, density_matrix) + return energy, (energies, E_xc, density_matrix, W) def sparse_mult(values, dm, state, gsize): in_ = dm.take(state.cols, axis=0) @@ -86,22 +85,41 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, pad_grid_AO=22000, pad_nonzero_distinct_ERI=200000, pad_sparse_diff_grid=200000, + mol_idx=42, ): if opts.wandb: import wandb - # don't pad for individual molecules like benzene - if opts.benzene: + # pad molecule if using nn. + if not opts.nn: pad_electrons, pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = \ -1, -1, -1, -1, -1, -1 + if opts.benzene and opts.nn: pad_electrons = 30 + if opts.hydrogens: + pad_diff_ERIs = 5000 + pad_distinct_ERIs = 20000 + pad_grid_AO = 2200 + pad_nonzero_distinct_ERI = 20000 + pad_sparse_diff_grid = 20000 + + if opts.waters: + pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = [a//3 for a in [ pad_diff_ERIs , pad_distinct_ERIs , pad_grid_AO , pad_nonzero_distinct_ERI , pad_sparse_diff_grid ]] + + mol = build_mol(mol_str, opts.basis) + pad_electrons = min(pad_electrons, mol.nao_nr()) + t0 = time.time() state = init_dft(mol_str, opts, do_pyscf=do_pyscf, pad_electrons=pad_electrons) c, w = state.grid_coords, state.grid_weights - np.random.seed(42) + # make sure we do different rotation; prev all worker did same rotation! + np.random.seed(mol_idx) p = np.array(mol_str[0][1]) states = [state] - #for i in tqdm(range(bs-1)): + if opts.waters: + water1_xyz = np.array([mol_str[i][1] for i in range(0,3)]) + water2_xyz = np.array([mol_str[i][1] for i in range(3,6)]) + for i in range(bs-1): if opts.benzene: x = p @@ -109,7 +127,23 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, else: x = p + np.random.normal(0, opts.wiggle_var, (3)) - mol_str[wiggle_num][1] = (x[0], x[1], x[2]) + if len(x) > 1: mol_str[wiggle_num][1] = (x[0], x[1], x[2]) + + if opts.waters: # rotate second water molecule + rotation_matrix = np.linalg.qr(np.random.normal(size=(3,3)))[0] + center = water2_xyz.mean(axis=0) + water_xyz = np.dot(water2_xyz - center, rotation_matrix) + center + + mol_str[3][1] = tuple(water_xyz[0]) + mol_str[4][1] = tuple(water_xyz[1]) + mol_str[5][1] = tuple(water_xyz[2]) + + if opts.wandb and i == 0: + from plot import create_rdkit_mol + import wandb + str = [mol_str[j][0] for j in range(len(mol_str))] + pos = np.concatenate([np.array(mol_str[j][1]).reshape(1, 3) for j in range(len(mol_str))]) + wandb.log({"%s_mol_%i"%({True: "valid", False: "train"}[validation], i): create_rdkit_mol(str, pos) }) # when profiling create fake molecule to skip waiting if i == 0 or not opts.prof: stateB = init_dft(mol_str, opts, c, w, do_pyscf=do_pyscf and i < 2, state=state, pad_electrons=pad_electrons) @@ -129,8 +163,11 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, # Merge nonzero indices and prepare (ij, kl). # rep is the number of repetitions we include in the sparse representation. - # TODO: the union1d should include all nonzero, not just first. - nonzero_indices = np.union1d(nonzero[0], nonzero[1]) + #nonzero_indices = np.union1d(nonzero[0], nonzero[1]) + union = nonzero[0] + for i in range(1, len(nonzero)): + union = np.union1d(union, nonzero[i]) + nonzero_indices = union from sparse_symmetric_ERI import get_i_j, num_repetitions_fast ij, kl = get_i_j(nonzero_indices) rep = num_repetitions_fast(ij, kl) @@ -315,15 +352,18 @@ def nanoDFT(mol_str, opts): if opts.nn: #https://arxiv.org/pdf/1706.03762.pdf see 5.3 optimizer - def custom_schedule(step_num, d_model, warmup_steps): + def custom_schedule(step_num: int, warmup_steps=4000*3): arg1 = step_num ** -0.5 arg2 = step_num * warmup_steps ** -1.5 - return d_model ** -0.5 * min(arg1, arg2) + return d_model ** -0.5 * np.min(arg1, arg2) - optimizer = optax.adam(learning_rate = opts.lr,#)lambda step: custom_schedule(step, d_model=d_model, warmup_steps=4000), - b1=0.9, b2=0.98, eps=1e-9) + # problem: loss improves, then, at some point, it starts decreasing. + # obs: happens with attention output and mlp output! + # obs: adam/sgd/adabelief ? [ ] + # obs: tiny/small/medium - adam = optax.adam(opts.lr) + adam = optax.adam(learning_rate=opts.lr, b1=0.9, b2=0.98, eps=1e-9) + #adam = optax.adabelief(opts.lr) w = params from torch.utils.data import DataLoader, Dataset @@ -344,7 +384,7 @@ def __init__(self, opts, nao=294, train=True, num_epochs=1000): self.opts = opts self.validation = not train - self.bezene = [[ + self.benzene = [ ["C", ( 0.0000, 0.0000, 0.0000)], ["C", ( 1.4000, 0.0000, 0.0000)], ["C", ( 2.1000, 1.2124, 0.0000)], @@ -357,14 +397,27 @@ def __init__(self, opts, nao=294, train=True, num_epochs=1000): ["H", (-1.8000, 1.2124, 0.0000)], ["H", ( 3.2000, 1.2124, 0.0000)], ["H", ( 1.9500, 3.3775, 0.0000)] - ]] - + ] + self.waters = [ + ["O", (-1.464, 0.099, 0.300)], + ["H", (-1.956, 0.624, -0.340)], + ["H", (-1.797, -0.799, 0.206)], + ["O", ( 1.369, 0.146, -0.395)], + ["H", ( 1.894, 0.486, 0.335)], + ["H", ( 0.451, 0.165, -0.083)] + ] + + if opts.benzene: self.mol_strs = [self.benzene] + if opts.waters: self.mol_strs = [self.waters] + + if train: self.bs = opts.bs + else: self.bs = opts.val_bs def __len__(self): return len(self.mol_strs)*self.num_epochs def __getitem__(self, idx): - return batched_state(self.mol_strs[idx%len(self.mol_strs)], self.opts, self.opts.bs, wiggle_num=0, do_pyscf=self.validation, validation=False) + return batched_state(self.mol_strs[idx%len(self.mol_strs)], self.opts, self.bs, wiggle_num=0, do_pyscf=self.validation, validation=False, mol_idx=idx) qm9 = OnTheFlyQM9(opts, train=True) train_dataloader = DataLoader(qm9, batch_size=1, pin_memory=True, shuffle=False, drop_last=True, num_workers=5, prefetch_factor=2, collate_fn=lambda x: x[0]) @@ -387,45 +440,68 @@ def __next__(self): return self.item valf = jax.jit(dm_energy, backend=opts.backend, static_argnames=("normal", 'nn')) adam_state = adam.init(w) - if opts.wandb: wandb.log({'total_params': total_params, 'batch_size': opts.bs, 'lr': opts.lr }) + if opts.wandb: + if not opts.nn: total_params = -1 + wandb.log({'total_params': total_params, 'batch_size': opts.bs, 'lr': opts.lr }) def update(w, state, adam_state): - print("rejitting... (if this is printed more than once something is wrong!)") - (val, (vals, E_xc, density_matrix)), grad = vandg(w, state, opts.normal, opts.nn) + print("jitting... (if this is printed more than once something is wrong!)") + (val, (vals, E_xc, density_matrix, W)), grad = vandg(w, state, opts.normal, opts.nn) updates, adam_state = adam.update(grad, adam_state) w = optax.apply_updates(w, updates) - return w, vals, density_matrix, adam_state + return w, vals, density_matrix, W, adam_state update = jax.jit(update, backend=opts.backend) + min_val, min_dm, mins, val_state, valid_str, step = 0, 0, np.ones(opts.bs)*1e6, None, "", 0 t0, load_time, train_time, val_time = time.time(), 0, 0, 0 # training loop (epochs are handled inside dataloader due to segfault). - for i, state in enumerate(pbar): + for iteration, state in enumerate(pbar): load_time, t0 = time.time()-t0, time.time() + # idea: [ ] do gradient accumulation last 16 iterations. for j in range(16): if j == 0: _t0 =time.time() - w, vals, density_matrix, adam_state = update(w, state, adam_state) + w, vals, density_matrix, _W, adam_state = update(w, state, adam_state) if j == 0: time_step1 = time.time()-_t0 step += 1 if not opts.nn: str = "error=" + "".join(["%.7f "%(vals[i]*HARTREE_TO_EV-state.pyscf_E[i]) for i in range(2)]) + " [eV]" + str += "pyscf=%.7f us=%.7f"%(state.pyscf_E[0]/HARTREE_TO_EV, vals[0]) pbar.set_description(str) + else: - pbar.set_description("train=".join(["%.5f"%i for i in vals[:2]]) + " " + valid_str + "time=%.1f %.1f %.1f %.1f"%(load_time, time_step1, train_time, val_time)) + #pbar.set_description("train=".join(["%.5f"%i for i in vals[:2]]) + "[Ha] lr=%5e"%custom_schedule(step) + valid_str + "time=%.1f %.1f %.1f %.1f"%(load_time, time_step1, train_time, val_time)) + pbar.set_description("train=".join(["%.5f"%i for i in vals[:2]]) + "[Ha] "+ valid_str + "time=%.1f %.1f %.1f %.1f"%(load_time, time_step1, train_time, val_time)) + + if opts.wandb: + dct = {} + for i in range(0, opts.bs): + if not opts.nn: + dct['train_l%i'%i ] = np.abs(vals[i]*HARTREE_TO_EV-state.pyscf_E[i]) + dct['train_pyscf%i'%i ] = np.abs(state.pyscf_E[i]) + dct['train_E%i'%i ] = np.abs(vals[i]*HARTREE_TO_EV) + dct['img/dm%i'%i] = wandb.Image(np.expand_dims(density_matrix[i], axis=-1)) + dct['img/W%i'%i] = wandb.Image(np.expand_dims(_W[i], axis=-1)) + wandb.log(dct) + train_time, t0 = time.time()-t0, time.time() if opts.nn:# and i % 5 == 0: if val_state is None: val_state = val_qm9[0] - _, (valid_vals, _, _) = valf(w, val_state, opts.normal, opts.nn) + _, (valid_vals, _, vdensity_matrix, vW) = valf(w, val_state, opts.normal, opts.nn) valid_str = "val_error=" + "".join(["%.7f "%(valid_vals[i]*HARTREE_TO_EV-val_state.pyscf_E[i]) for i in range(0, 3)]) + " [eV]" if opts.wandb: dct = {} for i in range(0, opts.val_bs): - dct['valid_l%i'%i ] = valid_vals[i]*HARTREE_TO_EV-val_state.pyscf_E[i] - dct["lr"] = custom_schedule(step, d_model, 4000) + dct['valid_l%i'%i ] = np.abs(valid_vals[i]*HARTREE_TO_EV-val_state.pyscf_E[i]) + dct['valid_E%i'%i ] = np.abs(valid_vals[i]*HARTREE_TO_EV) + dct['valid_pyscf%i'%i ] = np.abs(val_state.pyscf_E[i]) + dct['img/val_dm%i'%i] = wandb.Image(np.expand_dims(vdensity_matrix[i], axis=-1)) + dct['img/val_W%i'%i] = wandb.Image(np.expand_dims(vW[i], axis=-1)) + #dct["lr"] = custom_schedule(step) wandb.log(dct) valid_time, t0 = time.time()-t0, time.time() @@ -482,8 +558,8 @@ def update(w, state, adam_state): val, density_matrix = min_val, min_dm - # needs batching exit() + # needs batching V_xc = jax.grad(exchange_correlation)(density_matrix, state.grid_AO, state.grid_weights) V_xc = (V_xc + V_xc.T)/2 diff_JK = get_JK(density_matrix, state.ERI) @@ -830,6 +906,7 @@ def _cat(x,y,name): else: return np.concatenate([x,y]) + def cat(dc1, dc2, axis=0): # Use dictionary comprehension to iterate over the dataclass fields concatenated_fields = { @@ -945,7 +1022,7 @@ def callback(envs): pyscf_energies.append(envs["e_tot"]*HARTREE_TO_EV) hl_gap_hartree = np.abs(envs["mo_energy"][homo] - envs["mo_energy"][lumo]) * HARTREE_TO_EV pyscf_hlgaps.append(hl_gap_hartree) - print("\rPYSCF: ", pyscf_energies[-1] , end="") + print("\rPYSCF: ", pyscf_energies[-1] , "[eV]", end="") mf.callback = callback mf.kernel() print("") @@ -1017,6 +1094,7 @@ def reference(mol_str, opts): parser.add_argument('-skip', action="store_true", help="skip pyscf test case") # dataset + parser.add_argument('-qm9', action="store_true") parser.add_argument('-benzene', action="store_true") parser.add_argument('-hydrogens', action="store_true") parser.add_argument('-water', action="store_true") @@ -1036,7 +1114,7 @@ def reference(mol_str, opts): args_dict = vars(opts) print(args_dict) - if True: + if opts.qm9: df = pd.read_pickle("alchemy/atom_9.pickle") df = df[df["spin"] == 0] # only consider spin=0 mol_strs = df["pyscf"].values @@ -1071,16 +1149,12 @@ def reference(mol_str, opts): ]] if opts.waters: mol_strs = [[ - ["O",(-0.1858140, -1.1749469, 0.7662596)], - ["H",(-0.1285513, -0.8984365, 1.6808606)], - ["H",(-0.0582782, -0.3702550, 0.2638279)], - ["O",( 0.1747051, 1.1050002, -0.7244430)], - ["H",(-0.5650842, 1.3134964, -1.2949455)], - ["H",( 0.9282185, 1.0652990, -1.3134026)], ]] - - - - + ["O", (-1.464, 0.099, 0.300)], + ["H", (-1.956, 0.624, -0.340)], + ["H", (-1.797, -0.799, 0.206)], + ["O", ( 1.369, 0.146, -0.395)], + ["H", ( 1.894, 0.486, 0.335)], + ["H", ( 0.451, 0.165, -0.083)]]] nanoDFT_E, (nanoDFT_hlgap, mo_energy, mo_coeff, grid_coords, grid_weights, dm, H) = nanoDFT(mol_strs, opts) diff --git a/pyscf_ipu/direct/transformer.py b/pyscf_ipu/direct/transformer.py index 6d127476..9e6a5ab0 100644 --- a/pyscf_ipu/direct/transformer.py +++ b/pyscf_ipu/direct/transformer.py @@ -101,12 +101,15 @@ def transformer(cfg, params, x: jnp.ndarray, position: jnp.ndarray, H_core: jnp. embeddings = cfg.lambda_e * params.embeddings[x, :] # L x Dm + all_pairs = jnp.linalg.norm( position.reshape(1, -1, 3) - position.reshape(-1, 1, 3), axis=-1) + # Add (learned) positional encodings - x = jnp.concatenate([embeddings[:, :-3], position], -1) + x = jnp.concatenate([embeddings[:, :-3], position*10], -1) # seems positions ignored, scale larger (nn has to learn unit anyway) + #x = embeddings L, dm = x.shape # Apply the transformer layers - for layer in params.layers: + for layer_num, layer in enumerate(params.layers): # Layer-normalize embeddings t1 = vmap(standardize)(embeddings) t1 = elementwise_linear(layer.norm_self_attn, x) # L x Dm @@ -120,10 +123,12 @@ def transformer(cfg, params, x: jnp.ndarray, position: jnp.ndarray, H_core: jnp. v = jnp.transpose(qkv[:, 2*Dm:3*Dm].reshape(L, nheads, Dm//nheads), (1, 0, 2)) score = (q @ jnp.transpose(k, (0, 2, 1))) / math.sqrt(Dm) - if layer == 0: # doesn't look like it helps + # do like graphformer and append position here? + if layer_num < 6: # doesn't look like it helps score += H_core + score += all_pairs - attn = jax.nn.softmax(score, axis=1) + attn = jax.nn.softmax(score , axis=1) x = x + (attn @ v).reshape(L, Dm) # Layer-normalize embeddings @@ -138,7 +143,9 @@ def transformer(cfg, params, x: jnp.ndarray, position: jnp.ndarray, H_core: jnp. # Add this layer's contribution into embeddings x = x + t2 - return score + return score[0] # take first head + #print(score.shape, x.shape, x[:L,:L].shape) + #return x[:L, :L] #score import types From bebcaeef2997f3674355e5a3f111d994f4a8a4c8 Mon Sep 17 00:00:00 2001 From: Alexander Mathiasen Date: Thu, 7 Dec 2023 15:05:36 +0000 Subject: [PATCH 10/22] rotating water dimer solved. --- pyscf_ipu/direct/train.py | 74 +++++++++++++++++++++------------------ 1 file changed, 39 insertions(+), 35 deletions(-) diff --git a/pyscf_ipu/direct/train.py b/pyscf_ipu/direct/train.py index ba246d3a..daae41ab 100644 --- a/pyscf_ipu/direct/train.py +++ b/pyscf_ipu/direct/train.py @@ -11,13 +11,14 @@ from transformer import transformer, transformer_init import pandas as pd -# [ ] add train loss # [ ] (algo) may get transformer to be 4x faster by fixing AO_grid stuff; # [ ] (algo) can speed up dataloader by re-using AO_grid init (and ERI etc). # [ ] (opt/ml) the learning rate schedule is weird when we do multiple steps per batch # [ ] (opt/ml) validate with 10 molecules, not just 1. # [ ] add dropout/dropres/... +# [ ] + cfg, HARTREE_TO_EV, EPSILON_B3LYP, HYB_B3LYP = None, 27.2114079527, 1e-20, 0.2 def T(x): return jnp.transpose(x, (0,2,1)) @@ -121,13 +122,14 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, water2_xyz = np.array([mol_str[i][1] for i in range(3,6)]) for i in range(bs-1): - if opts.benzene: + '''if opts.benzene: x = p x[2] += np.random.normal(0, opts.wiggle_var, (1)) else: x = p + np.random.normal(0, opts.wiggle_var, (3)) - if len(x) > 1: mol_str[wiggle_num][1] = (x[0], x[1], x[2]) + if not opts.water: + mol_str[wiggle_num][1] = (x[0], x[1], x[2])''' if opts.waters: # rotate second water molecule rotation_matrix = np.linalg.qr(np.random.normal(size=(3,3)))[0] @@ -357,10 +359,8 @@ def custom_schedule(step_num: int, warmup_steps=4000*3): arg2 = step_num * warmup_steps ** -1.5 return d_model ** -0.5 * np.min(arg1, arg2) - # problem: loss improves, then, at some point, it starts decreasing. - # obs: happens with attention output and mlp output! - # obs: adam/sgd/adabelief ? [ ] - # obs: tiny/small/medium + # [ ] todo: add loss schedule. + # [ ] todo: train on qm9 instead of water (how do we rotate there?). adam = optax.adam(learning_rate=opts.lr, b1=0.9, b2=0.98, eps=1e-9) #adam = optax.adabelief(opts.lr) @@ -372,7 +372,7 @@ class OnTheFlyQM9(Dataset): # dataloader is very keen on throwing segfaults (e.g. using jnp in dataloader throws segfaul). # problem: second epoch always gives segfault. # hacky fix; make __len__ = real_length*num_epochs and __getitem__ do idx%real_num_examples - def __init__(self, opts, nao=294, train=True, num_epochs=1000): + def __init__(self, opts, nao=294, train=True, num_epochs=10**9): # only take molecules with use {CNOFH}, nao=nao and spin=0. df = pd.read_pickle("alchemy/processed_atom_9.pickle") # spin=0 and only CNOFH molecules if nao != -1: df = df[df["nao"]==nao] @@ -457,34 +457,45 @@ def update(w, state, adam_state): t0, load_time, train_time, val_time = time.time(), 0, 0, 0 # training loop (epochs are handled inside dataloader due to segfault). + states = [] for iteration, state in enumerate(pbar): load_time, t0 = time.time()-t0, time.time() # idea: [ ] do gradient accumulation last 16 iterations. - for j in range(16): + states = states[-opts.grad_acc:] + [state] + accumulated_grad = None + print(len(states)) + for j, state in enumerate(states): if j == 0: _t0 =time.time() - w, vals, density_matrix, _W, adam_state = update(w, state, adam_state) + #w, vals, density_matrix, _W, adam_state = update(w, state, adam_state) + (val, (vals, E_xc, density_matrix, _W)), grad = vandg(w, state, opts.normal, opts.nn) if j == 0: time_step1 = time.time()-_t0 step += 1 - if not opts.nn: - str = "error=" + "".join(["%.7f "%(vals[i]*HARTREE_TO_EV-state.pyscf_E[i]) for i in range(2)]) + " [eV]" - str += "pyscf=%.7f us=%.7f"%(state.pyscf_E[0]/HARTREE_TO_EV, vals[0]) - pbar.set_description(str) + accumulated_grad = grad if accumulated_grad is None else jax.tree_map(lambda x, y: x + y, accumulated_grad, grad) - else: - #pbar.set_description("train=".join(["%.5f"%i for i in vals[:2]]) + "[Ha] lr=%5e"%custom_schedule(step) + valid_str + "time=%.1f %.1f %.1f %.1f"%(load_time, time_step1, train_time, val_time)) - pbar.set_description("train=".join(["%.5f"%i for i in vals[:2]]) + "[Ha] "+ valid_str + "time=%.1f %.1f %.1f %.1f"%(load_time, time_step1, train_time, val_time)) + accumulated_grad = jax.tree_map(lambda x: x / len(states), accumulated_grad) + updates, adam_state = adam.update(accumulated_grad, adam_state) + w = optax.apply_updates(w, updates) - if opts.wandb: - dct = {} - for i in range(0, opts.bs): - if not opts.nn: - dct['train_l%i'%i ] = np.abs(vals[i]*HARTREE_TO_EV-state.pyscf_E[i]) - dct['train_pyscf%i'%i ] = np.abs(state.pyscf_E[i]) - dct['train_E%i'%i ] = np.abs(vals[i]*HARTREE_TO_EV) - dct['img/dm%i'%i] = wandb.Image(np.expand_dims(density_matrix[i], axis=-1)) - dct['img/W%i'%i] = wandb.Image(np.expand_dims(_W[i], axis=-1)) - wandb.log(dct) + if not opts.nn: + str = "error=" + "".join(["%.7f "%(vals[i]*HARTREE_TO_EV-state.pyscf_E[i]) for i in range(2)]) + " [eV]" + str += "pyscf=%.7f us=%.7f"%(state.pyscf_E[0]/HARTREE_TO_EV, vals[0]) + pbar.set_description(str) + + else: + #pbar.set_description("train=".join(["%.5f"%i for i in vals[:2]]) + "[Ha] lr=%5e"%custom_schedule(step) + valid_str + "time=%.1f %.1f %.1f %.1f"%(load_time, time_step1, train_time, val_time)) + pbar.set_description("train=".join(["%.5f"%i for i in vals[:2]]) + "[Ha] "+ valid_str + "time=%.1f %.1f %.1f %.1f"%(load_time, time_step1, train_time, val_time)) + + if opts.wandb: + dct = {} + for i in range(0, opts.bs): + if not opts.nn: + dct['train_l%i'%i ] = np.abs(vals[i]*HARTREE_TO_EV-state.pyscf_E[i]) + dct['train_pyscf%i'%i ] = np.abs(state.pyscf_E[i]) + dct['train_E%i'%i ] = np.abs(vals[i]*HARTREE_TO_EV) + dct['img/dm%i'%i] = wandb.Image(np.expand_dims(density_matrix[i], axis=-1)) + dct['img/W%i'%i] = wandb.Image(np.expand_dims(_W[i], axis=-1)) + wandb.log(dct) train_time, t0 = time.time()-t0, time.time() @@ -676,14 +687,6 @@ def pbecke_g(i, j): #print("\t", time.time()-t0) return pbecke -from functools import partial -@partial(jax.jit, backend="gpu", static_argnums=(2,3)) -def f(gm2, gp2, natm, ngrids, ix, jx): - pbecke = jnp.ones((natm, ngrids)) # [natom, ngrid] - pbecke = pbecke.at[ix].mul(gm2) - pbecke = pbecke.at[jx].mul(gp2) - return pbecke - def get_partition( mol, @@ -1086,6 +1089,7 @@ def reference(mol_str, opts): parser.add_argument('-steps', type=int, default=100000) parser.add_argument('-bs', type=int, default=8) parser.add_argument('-val_bs', type=int, default=8) + parser.add_argument('-grad_acc', type=int, default=16) parser.add_argument('-normal', action="store_true") parser.add_argument('-wandb', action="store_true") From 810be542505e0848d1bc83d30fd486d79f333fe2 Mon Sep 17 00:00:00 2001 From: Alexander Mathiasen Date: Wed, 20 Dec 2023 19:35:53 +0000 Subject: [PATCH 11/22] water(s) work after avg grad accum, qr -> eigh and gradient norm w/ optax/chainer. added alanine dipeptide w/ rotations. --- pyscf_ipu/direct/alanine.pdb | 22 ++ pyscf_ipu/direct/train.py | 342 +++++++++++++++++++------------- pyscf_ipu/direct/transformer.py | 24 ++- 3 files changed, 248 insertions(+), 140 deletions(-) create mode 100644 pyscf_ipu/direct/alanine.pdb diff --git a/pyscf_ipu/direct/alanine.pdb b/pyscf_ipu/direct/alanine.pdb new file mode 100644 index 00000000..8c9156ea --- /dev/null +++ b/pyscf_ipu/direct/alanine.pdb @@ -0,0 +1,22 @@ +ATOM 1 1HH3 ACE 1 2.000 1.000 -0.000 +ATOM 2 CH3 ACE 1 2.000 2.090 0.000 +ATOM 3 2HH3 ACE 1 1.486 2.454 0.890 +ATOM 4 3HH3 ACE 1 1.486 2.454 -0.890 +ATOM 5 C ACE 1 3.427 2.641 -0.000 +ATOM 6 O ACE 1 4.391 1.877 -0.000 +ATOM 7 N ALA 2 3.555 3.970 -0.000 +ATOM 8 H ALA 2 2.733 4.556 -0.000 +ATOM 9 CA ALA 2 4.853 4.614 -0.000 +ATOM 10 HA ALA 2 5.408 4.316 0.890 +ATOM 11 CB ALA 2 5.661 4.221 -1.232 +ATOM 12 1HB ALA 2 5.123 4.521 -2.131 +ATOM 13 2HB ALA 2 6.630 4.719 -1.206 +ATOM 14 3HB ALA 2 5.809 3.141 -1.241 +ATOM 15 C ALA 2 4.713 6.129 0.000 +ATOM 16 O ALA 2 3.601 6.653 0.000 +ATOM 17 N NME 3 5.846 6.835 0.000 +ATOM 18 H NME 3 6.737 6.359 -0.000 +ATOM 19 CH3 NME 3 5.846 8.284 0.000 +ATOM 20 1HH3 NME 3 4.819 8.648 0.000 +ATOM 21 2HH3 NME 3 6.360 8.648 0.890 +ATOM 22 3HH3 NME 3 6.360 8.648 -0.890 \ No newline at end of file diff --git a/pyscf_ipu/direct/train.py b/pyscf_ipu/direct/train.py index daae41ab..f19ab6ca 100644 --- a/pyscf_ipu/direct/train.py +++ b/pyscf_ipu/direct/train.py @@ -10,14 +10,7 @@ import time from transformer import transformer, transformer_init import pandas as pd - -# [ ] (algo) may get transformer to be 4x faster by fixing AO_grid stuff; -# [ ] (algo) can speed up dataloader by re-using AO_grid init (and ERI etc). -# [ ] (opt/ml) the learning rate schedule is weird when we do multiple steps per batch -# [ ] (opt/ml) validate with 10 molecules, not just 1. -# [ ] add dropout/dropres/... - -# [ ] +import math cfg, HARTREE_TO_EV, EPSILON_B3LYP, HYB_B3LYP = None, 27.2114079527, 1e-20, 0.2 @@ -30,13 +23,14 @@ def dm_energy(W: BxNxK, state, normal, nn): if nn: W = jax.vmap(transformer, in_axes=(None, None, 0, 0, 0), out_axes=(0))(cfg, \ W, state.ao_types, state.pos, state.H_core) - W = W @ state.init - L_inv_Q: BxNxN = state.L_inv_T @ jnp.linalg.qr(W)[0] # O(B*N*K^2) FLOP O(B*N*K) FLOP/FLIO + # using eigh + H_core + L_inv made training more stable. + # todo: investigate whether subset is sufficient (e.g. H_core+qr) or we need all. + L_inv_Q: BxNxN = state.L_inv_T @ jnp.linalg.eigh(state.L_inv @ (state.H_core + W) @ state.L_inv_T)[1] # O(B*N*K^2) FLOP O(B*N*K) FLOP/FLIO density_matrix: BxNxN = 2 * (L_inv_Q*state.mask) @ T(L_inv_Q) # O(B*N*K^2) FLOP/FLIO E_xc: B = exchange_correlation(density_matrix, state, normal) # O(B*gsize*N^2) FLOP O(gsize*N^2) FLIO diff_JK: BxNxN = JK(density_matrix, state, normal) # O(B*num_ERIs) FLOP O(num_ERIs) FLIO energies: B = E_xc + state.E_nuc + jnp.sum((density_matrix * (state.H_core + diff_JK/2)).reshape(W.shape[0], -1), axis=-1) - energy: float = jnp.sum(energies) + energy: float = jnp.sum(energies) return energy, (energies, E_xc, density_matrix, W) def sparse_mult(values, dm, state, gsize): @@ -80,10 +74,10 @@ def nao(atom, basis): def batched_state(mol_str, opts, bs, wiggle_num=0, do_pyscf=True, validation=False, - pad_electrons=40, + pad_electrons=45, pad_diff_ERIs=50000, - pad_distinct_ERIs=200000, - pad_grid_AO=22000, + pad_distinct_ERIs=120000, + pad_grid_AO=25000, pad_nonzero_distinct_ERI=200000, pad_sparse_diff_grid=200000, mol_idx=42, @@ -103,6 +97,23 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, pad_nonzero_distinct_ERI = 20000 pad_sparse_diff_grid = 20000 + if opts.qm9: + pad_electrons=45 + pad_diff_ERIs=120000 + pad_distinct_ERIs=400000 + pad_grid_AO=50000 + pad_nonzero_distinct_ERI=400000 + pad_sparse_diff_grid=400000 + + if opts.alanine: + multiplier = 1.2 + pad_electrons=70 + pad_diff_ERIs=int(180000*1.2) + pad_distinct_ERIs=int(600000*1.2) + pad_grid_AO=int(70000*1.2) + pad_nonzero_distinct_ERI=int(600000*1.2) + pad_sparse_diff_grid=int(1000000*1.2) + if opts.waters: pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = [a//3 for a in [ pad_diff_ERIs , pad_distinct_ERIs , pad_grid_AO , pad_nonzero_distinct_ERI , pad_sparse_diff_grid ]] @@ -113,7 +124,7 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, state = init_dft(mol_str, opts, do_pyscf=do_pyscf, pad_electrons=pad_electrons) c, w = state.grid_coords, state.grid_weights - # make sure we do different rotation; prev all worker did same rotation! + # Set seed to ensure different rotation; initially all workers did same rotation! np.random.seed(mol_idx) p = np.array(mol_str[0][1]) states = [state] @@ -121,17 +132,43 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, water1_xyz = np.array([mol_str[i][1] for i in range(0,3)]) water2_xyz = np.array([mol_str[i][1] for i in range(3,6)]) - for i in range(bs-1): - '''if opts.benzene: - x = p - x[2] += np.random.normal(0, opts.wiggle_var, (1)) - else: - x = p + np.random.normal(0, opts.wiggle_var, (3)) + if opts.qm9: + atoms = np.array([mol_str[i][1] for i in range(0,3)]) + + for iteration in range(bs-1): + + if opts.alanine: + from rdkit import Chem + from rdkit.Chem import AllChem + pdb_file = 'alanine.pdb' + molecule = Chem.MolFromPDBFile(pdb_file, removeHs=False) + # tried reading from pdb_block; caused parallel dataloader pickle to break. + AllChem.EmbedMolecule(molecule) + AllChem.UFFOptimizeMolecule(molecule) + phi_atoms = [4, 6, 8, 14] # indices for phi dihedral + psi_atoms = [6, 8, 14, 16] # indices for psi dihedral''' + + def xyz(atom): return np.array([atom.x, atom.y, atom.z]).reshape(1, 3) + def get_atom_positions(mol): + conf = mol.GetConformer() + return np.concatenate([xyz(conf.GetAtomPosition(i)) for i in range(mol.GetNumAtoms())], axis=0) - if not opts.water: - mol_str[wiggle_num][1] = (x[0], x[1], x[2])''' + str = [mol_str[j][0] for j in range(len(mol_str))] + pos = np.concatenate([np.array(mol_str[j][1]).reshape(1, 3) for j in range(len(mol_str))]) - if opts.waters: # rotate second water molecule + # todo: explore larger space than 10 degree on a single bond! + angle = float(np.random.uniform(0, 10, 1)) - 5 + AllChem.SetDihedralDeg(molecule.GetConformer(), *psi_atoms, angle) + pos = get_atom_positions(molecule) + + for j in range(22): mol_str[j][1] = tuple(pos[j]) + + if iteration == 0: + from plot import create_rdkit_mol + import wandb + wandb.log({"mol_valid=%s_angle=%f"%(validation, angle): create_rdkit_mol(str, pos[:21]) }) + + if opts.waters: # todo: rotate both water molecules and draw x=phi, y=psi. rotation_matrix = np.linalg.qr(np.random.normal(size=(3,3)))[0] center = water2_xyz.mean(axis=0) water_xyz = np.dot(water2_xyz - center, rotation_matrix) + center @@ -140,15 +177,34 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, mol_str[4][1] = tuple(water_xyz[1]) mol_str[5][1] = tuple(water_xyz[2]) - if opts.wandb and i == 0: + if opts.wandb and iteration == 0: from plot import create_rdkit_mol import wandb str = [mol_str[j][0] for j in range(len(mol_str))] pos = np.concatenate([np.array(mol_str[j][1]).reshape(1, 3) for j in range(len(mol_str))]) - wandb.log({"%s_mol_%i"%({True: "valid", False: "train"}[validation], i): create_rdkit_mol(str, pos) }) + wandb.log({"%s_mol_%i"%({True: "valid", False: "train"}[validation], iteration): create_rdkit_mol(str, pos) }) + + elif opts.qm9: + # todo: find dihedral to rotate over similar to alanine dipeptide. + + # rotate first three atoms around their center of mass + # this may break molecule; may need to do this around a bond or smth + #rotation_matrix = np.linalg.qr(np.random.normal(size=(3,3)))[0] + #center = atoms.mean(axis=0) + #rotated_atoms = np.dot(atoms - center, rotation_matrix) + center + mol_str[0][1] = tuple(atoms[0] + np.random.normal(0, opts.wiggle_var, (3))) + mol_str[1][1] = tuple(atoms[1] + np.random.normal(0, opts.wiggle_var, (3))) + mol_str[2][1] = tuple(atoms[2] + np.random.normal(0, opts.wiggle_var, (3))) + + if opts.wandb and iteration == 0: + from plot import create_rdkit_mol + import wandb + str = [mol_str[j][0] for j in range(len(mol_str))] + pos = np.concatenate([np.array(mol_str[j][1]).reshape(1, 3) for j in range(len(mol_str))]) + wandb.log({"%s_mol_%i"%({True: "valid", False: "train"}[validation], iteration): create_rdkit_mol(str, pos) }) # when profiling create fake molecule to skip waiting - if i == 0 or not opts.prof: stateB = init_dft(mol_str, opts, c, w, do_pyscf=do_pyscf and i < 2, state=state, pad_electrons=pad_electrons) + if iteration == 0 or not opts.prof: stateB = init_dft(mol_str, opts, c, w, do_pyscf=do_pyscf and iteration < 3, state=state, pad_electrons=pad_electrons) states.append(stateB) @@ -229,7 +285,7 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, # pad ERIs with 0 and indices with -1 so they point to 0. assert diff_indxs.shape[1] == diff_ERIs.shape[2] pad = pad_diff_ERIs - diff_indxs.shape[1] - assert pad > 0 + assert pad > 0, (pad_diff_ERIs, diff_indxs.shape[1]) state.indxs = np.pad(diff_indxs, ((0,0), (0, pad), (0, 0)), 'constant', constant_values=(-1)) state.diffs_ERI = np.pad(diff_ERIs, ((0,0), (0, 0), (0, pad))) # pad zeros @@ -247,7 +303,7 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, assert state.sparse_diffs_grid_AO.shape[1] == state.rows.shape[0] assert state.sparse_diffs_grid_AO.shape[1] == state.cols.shape[0] pad = pad_sparse_diff_grid - state.rows.shape[0] - assert pad >= 0 + assert pad >= 0, (pad_sparse_diff_grid, state.rows.shape[0]) state.rows = np.pad(state.rows, (0,pad)) state.cols = np.pad(state.cols, (0,pad)) state.sparse_diffs_grid_AO = np.pad(state.sparse_diffs_grid_AO, ((0,0),(0,pad))) @@ -272,7 +328,7 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, assert state.grid_AO.shape[2] == state.main_grid_AO.shape[1] assert state.grid_AO.shape[2] == state.diffs_grid_AO.shape[2] pad = pad_grid_AO - state.grid_AO.shape[2] - assert pad > 0 + assert pad > 0, (pad_grid_AO, state.grid_AO.shape[2]) state.grid_AO = np.pad(state.grid_AO, ((0,0),(0,0), (0,pad), (0,0))) state.grid_weights = np.pad(state.grid_weights, ((0,0),(0,pad))) state.grid_coords = np.pad(state.grid_coords, ((0,0),(0,pad),(0,0))) @@ -304,7 +360,8 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, if opts.wandb: wandb.log({"pad_grid_AO": pad/state.grid_AO.shape[2]}) - return state + import copy + return copy.deepcopy(state) def nanoDFT(mol_str, opts): @@ -315,7 +372,10 @@ def nanoDFT(mol_str, opts): if opts.wandb: import wandb - wandb.init(project='ndft') + if opts.alanine: + wandb.init(project='ndft_alanine') + else: + wandb.init(project='ndft') rnd_key = jax.random.PRNGKey(42) n_vocab = nao("C", opts.basis) + nao("N", opts.basis) + \ @@ -328,19 +388,36 @@ def nanoDFT(mol_str, opts): DeiT-Ti N/A 192 3 12 5M 224 2536 DeiT-S N/A 384 6 12 22M 224 940 DeiT-B ViT-B 768 12 12 86M 224 292 + Parameters Layers dmodel + 117M 12 768 + 345M 24 1024 + 762M 36 1280 + 1542M 48 1600 ''' if opts.tiny: # 5M d_model= 192 n_heads = 6 n_layers = 12 - elif opts.small: + if opts.small: d_model= 384 n_heads = 6 n_layers = 12 - elif opts.base: + if opts.base: d_model= 768 n_heads = 12 n_layers = 12 + if opts.medium: + d_model= 1024 + n_heads = 12 + n_layers = 24 + if opts.large: + d_model= 1280 + n_heads = 12 + n_layers = 36 + if opts.xlarge: + d_model= 1600 + n_heads = 12 + n_layers = 48 if opts.nn: rnd_key, cfg, params, total_params = transformer_init( @@ -354,16 +431,28 @@ def nanoDFT(mol_str, opts): if opts.nn: #https://arxiv.org/pdf/1706.03762.pdf see 5.3 optimizer - def custom_schedule(step_num: int, warmup_steps=4000*3): - arg1 = step_num ** -0.5 - arg2 = step_num * warmup_steps ** -1.5 - return d_model ** -0.5 * np.min(arg1, arg2) - - # [ ] todo: add loss schedule. - # [ ] todo: train on qm9 instead of water (how do we rotate there?). + + # try to mimic karpathy as closely as possible ;) + # https://github.com/karpathy/nanoGPT/blob/master/train.py + # still differs on + # [ ] weight initialization + + def custom_schedule(it, learning_rate=opts.lr, min_lr=opts.lr/10, warmup_iters=2000, lr_decay_iters=600000): + if it < warmup_iters: return learning_rate * it / warmup_iters + if it > lr_decay_iters: return min_lr + decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) + return min_lr + coeff * (learning_rate - min_lr) + + adam = optax.chain( + optax.clip_by_global_norm(1), + optax.scale_by_adam(b1=0.9, b2=0.95, eps=1e-12), + optax.add_decayed_weights(0.1),#, configure_decay_mask(params)), + optax.scale_by_schedule(custom_schedule), + optax.scale(-1), + ) - adam = optax.adam(learning_rate=opts.lr, b1=0.9, b2=0.98, eps=1e-9) - #adam = optax.adabelief(opts.lr) w = params from torch.utils.data import DataLoader, Dataset @@ -409,6 +498,7 @@ def __init__(self, opts, nao=294, train=True, num_epochs=10**9): if opts.benzene: self.mol_strs = [self.benzene] if opts.waters: self.mol_strs = [self.waters] + if opts.alanine: self.mol_strs = mol_str if train: self.bs = opts.bs else: self.bs = opts.val_bs @@ -420,7 +510,10 @@ def __getitem__(self, idx): return batched_state(self.mol_strs[idx%len(self.mol_strs)], self.opts, self.bs, wiggle_num=0, do_pyscf=self.validation, validation=False, mol_idx=idx) qm9 = OnTheFlyQM9(opts, train=True) - train_dataloader = DataLoader(qm9, batch_size=1, pin_memory=True, shuffle=False, drop_last=True, num_workers=5, prefetch_factor=2, collate_fn=lambda x: x[0]) + if opts.workers != 0: + train_dataloader = DataLoader(qm9, batch_size=1, pin_memory=True, shuffle=False, drop_last=True, num_workers=opts.workers, prefetch_factor=2, collate_fn=lambda x: x[0]) + else: + train_dataloader = DataLoader(qm9, batch_size=1, pin_memory=True, shuffle=False, drop_last=True, num_workers=opts.workers, collate_fn=lambda x: x[0]) pbar = tqdm(train_dataloader) val_qm9 = OnTheFlyQM9(opts, train=False) @@ -444,128 +537,80 @@ def __next__(self): return self.item if not opts.nn: total_params = -1 wandb.log({'total_params': total_params, 'batch_size': opts.bs, 'lr': opts.lr }) - def update(w, state, adam_state): - print("jitting... (if this is printed more than once something is wrong!)") - (val, (vals, E_xc, density_matrix, W)), grad = vandg(w, state, opts.normal, opts.nn) - updates, adam_state = adam.update(grad, adam_state) - w = optax.apply_updates(w, updates) - return w, vals, density_matrix, W, adam_state - update = jax.jit(update, backend=opts.backend) - min_val, min_dm, mins, val_state, valid_str, step = 0, 0, np.ones(opts.bs)*1e6, None, "", 0 - t0, load_time, train_time, val_time = time.time(), 0, 0, 0 + t0, load_time, train_time, val_time, plot_time = time.time(), 0, 0, 0, 0 - # training loop (epochs are handled inside dataloader due to segfault). states = [] for iteration, state in enumerate(pbar): + if iteration == 0: summary(state) + dct = {} + dct["iteraton"] = iteration load_time, t0 = time.time()-t0, time.time() - # idea: [ ] do gradient accumulation last 16 iterations. states = states[-opts.grad_acc:] + [state] accumulated_grad = None - print(len(states)) + + if len(states) < 50: print(len(states)) + for j, state in enumerate(states): if j == 0: _t0 =time.time() - #w, vals, density_matrix, _W, adam_state = update(w, state, adam_state) (val, (vals, E_xc, density_matrix, _W)), grad = vandg(w, state, opts.normal, opts.nn) if j == 0: time_step1 = time.time()-_t0 - step += 1 accumulated_grad = grad if accumulated_grad is None else jax.tree_map(lambda x, y: x + y, accumulated_grad, grad) - accumulated_grad = jax.tree_map(lambda x: x / len(states), accumulated_grad) - updates, adam_state = adam.update(accumulated_grad, adam_state) + # scale by global batch size. + global_batch_size = len(states)*opts.bs + wandb.log({"global_batch_size": global_batch_size}) + accumulated_grad = jax.tree_map(lambda x: x / global_batch_size, accumulated_grad) + + # plot grad norm + if iteration % 10 == 0: + for k,v in accumulated_grad.items(): dct[k + "_norm"] = np.linalg.norm(v .reshape(-1) ) + + updates, adam_state = adam.update(accumulated_grad, adam_state, params) w = optax.apply_updates(w, updates) + train_time, t0 = time.time()-t0, time.time() if not opts.nn: str = "error=" + "".join(["%.7f "%(vals[i]*HARTREE_TO_EV-state.pyscf_E[i]) for i in range(2)]) + " [eV]" str += "pyscf=%.7f us=%.7f"%(state.pyscf_E[0]/HARTREE_TO_EV, vals[0]) - pbar.set_description(str) - else: - #pbar.set_description("train=".join(["%.5f"%i for i in vals[:2]]) + "[Ha] lr=%5e"%custom_schedule(step) + valid_str + "time=%.1f %.1f %.1f %.1f"%(load_time, time_step1, train_time, val_time)) - pbar.set_description("train=".join(["%.5f"%i for i in vals[:2]]) + "[Ha] "+ valid_str + "time=%.1f %.1f %.1f %.1f"%(load_time, time_step1, train_time, val_time)) - - if opts.wandb: - dct = {} - for i in range(0, opts.bs): + pbar.set_description("train=".join(["%.5f"%i for i in vals[:2]]) + "[Ha] "+ valid_str + "time=%.1f %.1f %.1f %.1f %.1f"%(load_time, time_step1, train_time, val_time, plot_time)) + + if opts.wandb: + dct["time_load"] = load_time + dct["time_step1"] = time_step1 + dct["time_train"] = train_time + dct["time_val"] = val_time + plot_iteration = iteration % 10 == 0 + for i in range(0, 2): if not opts.nn: dct['train_l%i'%i ] = np.abs(vals[i]*HARTREE_TO_EV-state.pyscf_E[i]) dct['train_pyscf%i'%i ] = np.abs(state.pyscf_E[i]) dct['train_E%i'%i ] = np.abs(vals[i]*HARTREE_TO_EV) - dct['img/dm%i'%i] = wandb.Image(np.expand_dims(density_matrix[i], axis=-1)) - dct['img/W%i'%i] = wandb.Image(np.expand_dims(_W[i], axis=-1)) - wandb.log(dct) + if plot_iteration: + dct['img/dm%i'%i] = wandb.Image(np.expand_dims(density_matrix[i], axis=-1)) + dct['img/W%i'%i] = wandb.Image(np.expand_dims(_W[i], axis=-1)) - - train_time, t0 = time.time()-t0, time.time() + plot_time, t0 = time.time()-t0, time.time() if opts.nn:# and i % 5 == 0: if val_state is None: val_state = val_qm9[0] _, (valid_vals, _, vdensity_matrix, vW) = valf(w, val_state, opts.normal, opts.nn) - valid_str = "val_error=" + "".join(["%.7f "%(valid_vals[i]*HARTREE_TO_EV-val_state.pyscf_E[i]) for i in range(0, 3)]) + " [eV]" + lr = custom_schedule(iteration) + valid_str = "lr=%.3e"%lr + "val_error=" + "".join(["%.4f "%(valid_vals[i]*HARTREE_TO_EV-val_state.pyscf_E[i]) for i in range(0, 3)]) + " [eV]" if opts.wandb: - dct = {} for i in range(0, opts.val_bs): dct['valid_l%i'%i ] = np.abs(valid_vals[i]*HARTREE_TO_EV-val_state.pyscf_E[i]) dct['valid_E%i'%i ] = np.abs(valid_vals[i]*HARTREE_TO_EV) dct['valid_pyscf%i'%i ] = np.abs(val_state.pyscf_E[i]) dct['img/val_dm%i'%i] = wandb.Image(np.expand_dims(vdensity_matrix[i], axis=-1)) dct['img/val_W%i'%i] = wandb.Image(np.expand_dims(vW[i], axis=-1)) - #dct["lr"] = custom_schedule(step) - wandb.log(dct) + dct["scheduled_lr"] = custom_schedule(iteration) + wandb.log(dct) valid_time, t0 = time.time()-t0, time.time() - continue - - # move to train loop; this is done inside dataloader multithreaded - '''if opts.wandb and i < 6: - from plot import create_rdkit_mol - import wandb - str = [mol_str[j][0] for j in range(len(mol_str))] - pos = np.concatenate([np.array(mol_str[j][1]).reshape(1, 3) for j in range(len(mol_str))]) - wandb.log({"%s_mol_%i"%({True: "valid", False: "train"}[validation], i): create_rdkit_mol(str, pos) })''' - - # valid - if i % 10 == 0 and opts.nn: - _, (valid_vals, _, _) = valf(w, val_state, opts.normal, opts.nn) - print("validation:") - str = "error=" + "".join(["%.7f "%(valid_vals[i]*HARTREE_TO_EV-val_state.pyscf_E[i]) for i in range(1, opts.val_bs)]) + " [eV]" - print(str) - print() - if opts.wandb: - dct = {} - for i in range(1, opts.val_bs): - dct['valid_l%i'%i ] = valid_vals[i]*HARTREE_TO_EV-val_state.pyscf_E[i] - wandb.log(dct) - - if opts.nn and opts.wandb and i > 0: - current_lr = custom_schedule(i, d_model, 4000) - wandb.log({'lr': current_lr}) - - if opts.bs == 1: pbar.set_description("error=%.7f [eV] (%.7f %.7f) "%(np.mean(val*HARTREE_TO_EV-state.pyscf_E), val*HARTREE_TO_EV, state.pyscf_E)) - else: - if opts.wandb: - wandb.log( - {'l1': vals[0]*HARTREE_TO_EV-state.pyscf_E[0], - 'l2': vals[1]*HARTREE_TO_EV-state.pyscf_E[1]}) - - str = "error=" + "".join(["%.7f "%(vals[i]*HARTREE_TO_EV-state.pyscf_E[i]) for i in range(2)]) + " [eV]" - #str += "E_xc=" + "".join(["%.7f "%(E_xc[i]*HARTREE_TO_EV) for i in range(opts.bs)]) + " [eV]" - try: - mins = np.minimum(mins, np.abs(vals*HARTREE_TO_EV - state.pyscf_E[:, 0])) - if np.max(mins) < 1e-5: - break - str += " best=" + "".join(["%.7f "%(mins[i]) for i in range(2)]) + " [eV]" - except: - pass - pbar.set_description(str) - if i == 0: print("") - - if val < min_val: - min_val = val - min_dm = density_matrix - val, density_matrix = min_val, min_dm @@ -790,7 +835,7 @@ def init_dft(mol_str, opts, _coords=None, _weights=None, first=False, do_pyscf=T init = np.eye(N)[:, :n_electrons_half] mask = np.ones((1, n_electrons_half)) else: - assert pad_electrons > n_electrons_half + assert pad_electrons > n_electrons_half, (pad_electrons, n_electrons_half) init = np.eye(N)[:, :pad_electrons] mask = np.zeros((1, pad_electrons)) mask[:, :n_electrons_half] = 1 @@ -848,7 +893,7 @@ def e(x): return np.expand_dims(x, axis=0) else: raise Exception() ao_types = np.array(types) pos = np.concatenate(pos) - + state = IterationState( diffs_ERI = np.zeros((1,1)), main_grid_AO = np.zeros((1,1)), @@ -934,8 +979,6 @@ def cats(dcs): # Create a new dataclass instance with the concatenated fields return IterationState(**concatenated_fields) - - def grad_elec(weight, grid_AO, eri, s1, h1aos, natm, aoslices, mask, mo_energy, mo_coeff, mol, dm, H): # Electronic part of RHF/RKS gradients dm0 = 2 * (mo_coeff*mask) @ mo_coeff.T # (N, N) = (66, 66) for C6H6. @@ -1025,7 +1068,7 @@ def callback(envs): pyscf_energies.append(envs["e_tot"]*HARTREE_TO_EV) hl_gap_hartree = np.abs(envs["mo_energy"][homo] - envs["mo_energy"][lumo]) * HARTREE_TO_EV pyscf_hlgaps.append(hl_gap_hartree) - print("\rPYSCF: ", pyscf_energies[-1] , "[eV]", end="") + print("PYSCF: ", pyscf_energies[-1] , "[eV]") mf.callback = callback mf.kernel() print("") @@ -1103,8 +1146,11 @@ def reference(mol_str, opts): parser.add_argument('-hydrogens', action="store_true") parser.add_argument('-water', action="store_true") parser.add_argument('-waters', action="store_true") + parser.add_argument('-alanine', action="store_true") parser.add_argument('-states', type=int, default=1) - parser.add_argument('-wiggle_var', type=float, default=0.3, help="wiggle N(0, wiggle_var)") + parser.add_argument('-workers', type=int, default=5) # set to 0 for faster pyscf precompute. + # do noise schedule, start small slowly increase + parser.add_argument('-wiggle_var', type=float, default=0.05, help="wiggle N(0, wiggle_var), bondlength=1.5/30") parser.add_argument('-eri_threshold', type=float, default=1e-10, help="loss function threshold only") # models @@ -1112,6 +1158,9 @@ def reference(mol_str, opts): parser.add_argument('-tiny', action="store_true") parser.add_argument('-small', action="store_true") parser.add_argument('-base', action="store_true") + parser.add_argument('-medium', action="store_true") + parser.add_argument('-large', action="store_true") + parser.add_argument('-xlarge', action="store_true") opts = parser.parse_args() if opts.tiny or opts.small or opts.base: opts.nn = True @@ -1160,6 +1209,33 @@ def reference(mol_str, opts): ["H", ( 1.894, 0.486, 0.335)], ["H", ( 0.451, 0.165, -0.083)]]] + elif opts.alanine: + mol_strs = [[ + ["H", ( 2.000 , 1.000, -0.000)], + ["C", ( 2.000 , 2.090, 0.000)], + ["H", ( 1.486 , 2.454, 0.890)], + ["H", ( 1.486 , 2.454, -0.890)], + ["C", ( 3.427 , 2.641, -0.000)], + ["O", ( 4.391 , 1.877, -0.000)], + ["N", ( 3.555 , 3.970, -0.000)], + ["H", ( 2.733 , 4.556, -0.000)], + ["C", ( 4.853 , 4.614, -0.000)], # carbon alpha + ["H", ( 5.408 , 4.316, 0.890)], # hydrogne attached to carbon alpha + ["C", ( 5.661 , 4.221, -1.232)], # carbon beta + ["H", ( 5.123 , 4.521, -2.131)], # hydrogens attached to carbon beta + ["H", ( 6.630 , 4.719, -1.206)], # hydrogens attached to carbon beta + ["H", ( 5.809 , 3.141, -1.241)], # hydrogens attached to carbon beta + ["C", ( 4.713 , 6.129, 0.000)], + ["O", ( 3.601 , 6.653, 0.000)], + ["N", ( 5.846 , 6.835, 0.000)], + ["H", ( 6.737 , 6.359, -0.000)], + ["C", ( 5.846 , 8.284, 0.000)], + ["H", ( 4.819 , 8.648, 0.000)], + ["H", ( 6.360 , 8.648, 0.890)], + ["H", ( 6.360 , 8.648, -0.890)], + ]] + + nanoDFT_E, (nanoDFT_hlgap, mo_energy, mo_coeff, grid_coords, grid_weights, dm, H) = nanoDFT(mol_strs, opts) exit() diff --git a/pyscf_ipu/direct/transformer.py b/pyscf_ipu/direct/transformer.py index 9e6a5ab0..f4f9d693 100644 --- a/pyscf_ipu/direct/transformer.py +++ b/pyscf_ipu/direct/transformer.py @@ -12,6 +12,7 @@ def rand(rng, f, shape, **kwargs): return rng, f(rng1, shape, **kwargs) def linear_init_uniform(rng: jax.random.KeyArray, in_features: int, out_features: int): + # todo: init as kaparthy params = ParamsDict() rnd_range = 1 / in_features**0.5 rng, params.weight = rand( rng, jax.random.uniform, (in_features, out_features), minval=-rnd_range, maxval=rnd_range,) @@ -57,6 +58,10 @@ def transformer_init( total_params += np.prod(params.embeddings.shape) print("%26s %26s %26s"%("params.embeddings",params.embeddings.shape, np.prod(params.embeddings.shape))) + rng, params.project_positions, shape = linear_init_uniform(rng, 12, d_model) + total_params += np.prod(shape) + print("%26s %26s %26s"%("params.project_positions",shape, np.prod(shape))) + # For transformer layers params.layers = [] for i in range(n_layers): @@ -101,10 +106,16 @@ def transformer(cfg, params, x: jnp.ndarray, position: jnp.ndarray, H_core: jnp. embeddings = cfg.lambda_e * params.embeddings[x, :] # L x Dm - all_pairs = jnp.linalg.norm( position.reshape(1, -1, 3) - position.reshape(-1, 1, 3), axis=-1) + all_pairs = jnp.linalg.norm(position.reshape(1, -1, 3) - position.reshape(-1, 1, 3), axis=-1) + + # inspired by 3d point cloud transformers; + # nspired by andrew: use trigonometric functions as feature transformations + position = jnp.concatenate([position, jnp.cos(position), jnp.sin(position), jnp.tanh(position)], axis=1) #(N,3) -> (N,12) + positions = linear(params.project_positions, position) # L x Dm*4 # Add (learned) positional encodings - x = jnp.concatenate([embeddings[:, :-3], position*10], -1) # seems positions ignored, scale larger (nn has to learn unit anyway) + #x = jnp.concatenate([embeddings[:, :-3], position], -1) + x = embeddings + positions #x = embeddings L, dm = x.shape @@ -124,9 +135,9 @@ def transformer(cfg, params, x: jnp.ndarray, position: jnp.ndarray, H_core: jnp. score = (q @ jnp.transpose(k, (0, 2, 1))) / math.sqrt(Dm) # do like graphformer and append position here? - if layer_num < 6: # doesn't look like it helps - score += H_core - score += all_pairs + #if layer_num < 6: # doesn't look like it helps + # score += H_core + # score += all_pairs attn = jax.nn.softmax(score , axis=1) x = x + (attn @ v).reshape(L, Dm) @@ -144,8 +155,6 @@ def transformer(cfg, params, x: jnp.ndarray, position: jnp.ndarray, H_core: jnp. x = x + t2 return score[0] # take first head - #print(score.shape, x.shape, x[:L,:L].shape) - #return x[:L, :L] #score import types @@ -197,3 +206,4 @@ def labels_aux(cls, path, obj): def items(self, path = ''): yield from self.labels_aux(path, self) + From 7c0813fc576c5b3062b5a354147f4673b0c39cf9 Mon Sep 17 00:00:00 2001 From: Alexander Mathiasen Date: Wed, 20 Dec 2023 20:02:04 +0000 Subject: [PATCH 12/22] Removed FF optimization of alanine. Plotting parameters. Added 300M 700M and 1.5B transformer. --- pyscf_ipu/direct/train.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/pyscf_ipu/direct/train.py b/pyscf_ipu/direct/train.py index f19ab6ca..b84f0af3 100644 --- a/pyscf_ipu/direct/train.py +++ b/pyscf_ipu/direct/train.py @@ -144,7 +144,7 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, molecule = Chem.MolFromPDBFile(pdb_file, removeHs=False) # tried reading from pdb_block; caused parallel dataloader pickle to break. AllChem.EmbedMolecule(molecule) - AllChem.UFFOptimizeMolecule(molecule) + #AllChem.UFFOptimizeMolecule(molecule) phi_atoms = [4, 6, 8, 14] # indices for phi dihedral psi_atoms = [6, 8, 14, 16] # indices for psi dihedral''' @@ -163,10 +163,10 @@ def get_atom_positions(mol): for j in range(22): mol_str[j][1] = tuple(pos[j]) - if iteration == 0: + if iteration == 0 and opts.wandb: from plot import create_rdkit_mol import wandb - wandb.log({"mol_valid=%s_angle=%f"%(validation, angle): create_rdkit_mol(str, pos[:21]) }) + wandb.log({"mol_valid=%s_angle=%f"%(validation, angle): create_rdkit_mol(str, pos[:22]) }) if opts.waters: # todo: rotate both water molecules and draw x=phi, y=psi. rotation_matrix = np.linalg.qr(np.random.normal(size=(3,3)))[0] @@ -408,7 +408,7 @@ def nanoDFT(mol_str, opts): n_layers = 12 if opts.medium: d_model= 1024 - n_heads = 12 + n_heads = 16 n_layers = 24 if opts.large: d_model= 1280 @@ -429,6 +429,8 @@ def nanoDFT(mol_str, opts): d_ff =d_model*4, ) + if opts.wandb: wandb.log({"total_parameters": total_params }) + if opts.nn: #https://arxiv.org/pdf/1706.03762.pdf see 5.3 optimizer @@ -561,8 +563,8 @@ def __next__(self): return self.item # scale by global batch size. global_batch_size = len(states)*opts.bs - wandb.log({"global_batch_size": global_batch_size}) accumulated_grad = jax.tree_map(lambda x: x / global_batch_size, accumulated_grad) + if opts.wandb: wandb.log({"global_batch_size": global_batch_size}) # plot grad norm if iteration % 10 == 0: From e53bf79cf01e2e57887723eeeb532b75ff22b993 Mon Sep 17 00:00:00 2001 From: Alexander Mathiasen Date: Wed, 20 Dec 2023 20:06:07 +0000 Subject: [PATCH 13/22] added total param count and changed molecule naming for alanine. --- pyscf_ipu/direct/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyscf_ipu/direct/train.py b/pyscf_ipu/direct/train.py index b84f0af3..bec7aa19 100644 --- a/pyscf_ipu/direct/train.py +++ b/pyscf_ipu/direct/train.py @@ -166,7 +166,7 @@ def get_atom_positions(mol): if iteration == 0 and opts.wandb: from plot import create_rdkit_mol import wandb - wandb.log({"mol_valid=%s_angle=%f"%(validation, angle): create_rdkit_mol(str, pos[:22]) }) + wandb.log({"mol_valid=%s"%validation: create_rdkit_mol(str, pos) }) if opts.waters: # todo: rotate both water molecules and draw x=phi, y=psi. rotation_matrix = np.linalg.qr(np.random.normal(size=(3,3)))[0] From ca781318bf5dedbd1c8be600329c1f1c1b0421d8 Mon Sep 17 00:00:00 2001 From: Alexander Mathiasen Date: Thu, 21 Dec 2023 12:30:46 +0000 Subject: [PATCH 14/22] forgot to add plot code --- pyscf_ipu/direct/plot.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 pyscf_ipu/direct/plot.py diff --git a/pyscf_ipu/direct/plot.py b/pyscf_ipu/direct/plot.py new file mode 100644 index 00000000..430dad3a --- /dev/null +++ b/pyscf_ipu/direct/plot.py @@ -0,0 +1,21 @@ +import wandb +from rdkit import Chem +import rdkit +import rdkit.Chem +import rdkit.Chem.AllChem +from rdkit.Geometry.rdGeometry import Point3D +from rdkit.Chem import AllChem +import numpy as np + +def create_rdkit_mol(atom_types, atom_positions): + mol = Chem.RWMol() + for atom_type in atom_types: + atom = Chem.Atom(atom_type) + mol.AddAtom(atom) + conf = Chem.Conformer(len(atom_types)) + for i, pos in enumerate(atom_positions): + if isinstance(pos, np.ndarray): pos = pos.tolist() + point = Point3D(*pos) + conf.SetAtomPosition(i, point) + mol.AddConformer(conf) + return wandb.Molecule.from_rdkit(mol, convert_to_3d_and_optimize=False) From 2002e4a78e24c20979043c02d456429c389a62a4 Mon Sep 17 00:00:00 2001 From: Alexander Mathiasen Date: Thu, 21 Dec 2023 13:21:02 +0000 Subject: [PATCH 15/22] rotate both dipeptide angles; batch rotation can be controlled with -rotate_deg. --- pyscf_ipu/direct/train.py | 69 ++++++++++++++++++++++++--------------- 1 file changed, 43 insertions(+), 26 deletions(-) diff --git a/pyscf_ipu/direct/train.py b/pyscf_ipu/direct/train.py index bec7aa19..c3825387 100644 --- a/pyscf_ipu/direct/train.py +++ b/pyscf_ipu/direct/train.py @@ -88,7 +88,6 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, if not opts.nn: pad_electrons, pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = \ -1, -1, -1, -1, -1, -1 - if opts.benzene and opts.nn: pad_electrons = 30 if opts.hydrogens: pad_diff_ERIs = 5000 @@ -96,7 +95,6 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, pad_grid_AO = 2200 pad_nonzero_distinct_ERI = 20000 pad_sparse_diff_grid = 20000 - if opts.qm9: pad_electrons=45 pad_diff_ERIs=120000 @@ -104,15 +102,15 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, pad_grid_AO=50000 pad_nonzero_distinct_ERI=400000 pad_sparse_diff_grid=400000 - if opts.alanine: + # todo: (adam) the ERI padding may change when rotating molecule more! multiplier = 1.2 - pad_electrons=70 - pad_diff_ERIs=int(180000*1.2) - pad_distinct_ERIs=int(600000*1.2) - pad_grid_AO=int(70000*1.2) - pad_nonzero_distinct_ERI=int(600000*1.2) - pad_sparse_diff_grid=int(1000000*1.2) + pad_electrons = 70 + pad_diff_ERIs = int(180000*1.2) + pad_distinct_ERIs = int(600000) + pad_grid_AO = int(70000*1.2/4) + pad_nonzero_distinct_ERI = int(600000*1.2) + pad_sparse_diff_grid =int(1000000*1.2) if opts.waters: pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = [a//3 for a in [ pad_diff_ERIs , pad_distinct_ERIs , pad_grid_AO , pad_nonzero_distinct_ERI , pad_sparse_diff_grid ]] @@ -120,14 +118,8 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, mol = build_mol(mol_str, opts.basis) pad_electrons = min(pad_electrons, mol.nao_nr()) - t0 = time.time() - state = init_dft(mol_str, opts, do_pyscf=do_pyscf, pad_electrons=pad_electrons) - c, w = state.grid_coords, state.grid_weights - # Set seed to ensure different rotation; initially all workers did same rotation! np.random.seed(mol_idx) - p = np.array(mol_str[0][1]) - states = [state] if opts.waters: water1_xyz = np.array([mol_str[i][1] for i in range(0,3)]) water2_xyz = np.array([mol_str[i][1] for i in range(3,6)]) @@ -135,7 +127,12 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, if opts.qm9: atoms = np.array([mol_str[i][1] for i in range(0,3)]) - for iteration in range(bs-1): + if opts.alanine: + phi, psi = [float(a) for a in np.random.uniform(0, 360, 2)] + angles = [] + + states = [] + for iteration in range(bs): if opts.alanine: from rdkit import Chem @@ -156,12 +153,16 @@ def get_atom_positions(mol): str = [mol_str[j][0] for j in range(len(mol_str))] pos = np.concatenate([np.array(mol_str[j][1]).reshape(1, 3) for j in range(len(mol_str))]) - # todo: explore larger space than 10 degree on a single bond! - angle = float(np.random.uniform(0, 10, 1)) - 5 - AllChem.SetDihedralDeg(molecule.GetConformer(), *psi_atoms, angle) + # todo: save=wandb.log({"pair": angle1, angle2, NN_energy ) (rotation, NN_energy) for train/val molecule (for val also save PySCF energy) + # only saving angles (angle not necessarily paired up with energy) + AllChem.SetDihedralDeg(molecule.GetConformer(), *phi_atoms, phi) + angle = psi + float(np.random.uniform(0, opts.rotate_deg, 1)) # perhaps add 45 and mod 360? + angle = angle % 360 + AllChem.SetDihedralDeg(molecule.GetConformer(), *psi_atoms, angle ) pos = get_atom_positions(molecule) + angles.append((phi, angle)) - for j in range(22): mol_str[j][1] = tuple(pos[j]) + for j in range(len(mol_str)): mol_str[j][1] = tuple(pos[j]) if iteration == 0 and opts.wandb: from plot import create_rdkit_mol @@ -203,11 +204,21 @@ def get_atom_positions(mol): pos = np.concatenate([np.array(mol_str[j][1]).reshape(1, 3) for j in range(len(mol_str))]) wandb.log({"%s_mol_%i"%({True: "valid", False: "train"}[validation], iteration): create_rdkit_mol(str, pos) }) - # when profiling create fake molecule to skip waiting - if iteration == 0 or not opts.prof: stateB = init_dft(mol_str, opts, c, w, do_pyscf=do_pyscf and iteration < 3, state=state, pad_electrons=pad_electrons) - - states.append(stateB) - + if iteration == 0: + state = init_dft(mol_str, opts, do_pyscf=do_pyscf, pad_electrons=pad_electrons) + c, w = state.grid_coords, state.grid_weights + elif iteration <= 1 or not opts.prof: # when profiling create fake molecule to skip waiting + state = init_dft(mol_str, opts, c, w, do_pyscf=do_pyscf and iteration < 3, state=state, pad_electrons=pad_electrons) + + states.append(state) + + # If we add energy here we get plot basically! + if opts.alanine and opts.wandb: + for phi, psi in angles: + if not validation: + wandb.log({"phi_train": phi , "psi_train": psi}) + else: + wandb.log({"phi_valid": phi, "psi_valid": psi}) state = cats(states) N = state.N[0] @@ -323,6 +334,7 @@ def get_atom_positions(mol): if opts.wandb: wandb.log({"pad_distinct_ERIs": pad/state.nonzero_distinct_ERI.shape[2]}) if pad_grid_AO != -1: + prev_size = state.grid_AO.shape[2] assert state.grid_AO.shape[2] == state.grid_weights.shape[1] assert state.grid_AO.shape[2] == state.grid_coords.shape[1] assert state.grid_AO.shape[2] == state.main_grid_AO.shape[1] @@ -335,7 +347,11 @@ def get_atom_positions(mol): state.main_grid_AO = np.pad(state.main_grid_AO, ((0,0),(0,pad),(0,0))) state.diffs_grid_AO = np.pad(state.diffs_grid_AO, ((0,0),(0,0),(0,pad),(0,0))) - if opts.wandb: wandb.log({"pad_grid_AO": pad/state.grid_AO.shape[2]}) + if opts.wandb: + wandb.log({"pad_grid_AO": pad/state.grid_AO.shape[2], + "pad_grid_AO_prev": prev_size, + "pad_grid_AO_pad": pad, + "pad_grid_AO_target": pad_grid_AO}) indxs = np.abs(state.nonzero_distinct_ERI ) > 1e-9 @@ -1154,6 +1170,7 @@ def reference(mol_str, opts): # do noise schedule, start small slowly increase parser.add_argument('-wiggle_var', type=float, default=0.05, help="wiggle N(0, wiggle_var), bondlength=1.5/30") parser.add_argument('-eri_threshold', type=float, default=1e-10, help="loss function threshold only") + parser.add_argument('-rotate_deg', type=float, default=90, help="how many degrees to rotate") # models parser.add_argument('-nn', action="store_true", help="train nn, defaults to GD") From d18c698ef808cea791dbb275245229426501e45e Mon Sep 17 00:00:00 2001 From: Alexander Mathiasen Date: Thu, 21 Dec 2023 18:56:03 +0000 Subject: [PATCH 16/22] work. --- pyscf_ipu/direct/train.py | 116 ++++++++++++++++++++++---------- pyscf_ipu/direct/transformer.py | 22 +++++- 2 files changed, 101 insertions(+), 37 deletions(-) diff --git a/pyscf_ipu/direct/train.py b/pyscf_ipu/direct/train.py index c3825387..92f9f8ca 100644 --- a/pyscf_ipu/direct/train.py +++ b/pyscf_ipu/direct/train.py @@ -11,6 +11,7 @@ from transformer import transformer, transformer_init import pandas as pd import math +from functools import partial cfg, HARTREE_TO_EV, EPSILON_B3LYP, HYB_B3LYP = None, 27.2114079527, 1e-20, 0.2 @@ -22,9 +23,9 @@ def T(x): return jnp.transpose(x, (0,2,1)) def dm_energy(W: BxNxK, state, normal, nn): if nn: W = jax.vmap(transformer, in_axes=(None, None, 0, 0, 0), out_axes=(0))(cfg, \ - W, state.ao_types, state.pos, state.H_core) - # using eigh + H_core + L_inv made training more stable. - # todo: investigate whether subset is sufficient (e.g. H_core+qr) or we need all. + W, state.ao_types, state.pos.astype(jnp.float32), state.H_core.astype(jnp.float32)) + + W = W.astype(jnp.float64) L_inv_Q: BxNxN = state.L_inv_T @ jnp.linalg.eigh(state.L_inv @ (state.H_core + W) @ state.L_inv_T)[1] # O(B*N*K^2) FLOP O(B*N*K) FLOP/FLIO density_matrix: BxNxN = 2 * (L_inv_Q*state.mask) @ T(L_inv_Q) # O(B*N*K^2) FLOP/FLIO E_xc: B = exchange_correlation(density_matrix, state, normal) # O(B*gsize*N^2) FLOP O(gsize*N^2) FLIO @@ -88,6 +89,9 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, if not opts.nn: pad_electrons, pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = \ -1, -1, -1, -1, -1, -1 + + max_pad_electrons, max_pad_diff_ERIs, max_pad_distinct_ERIs, max_pad_grid_AO, max_pad_nonzero_distinct_ERI, max_pad_sparse_diff_grid = \ + -1, -1, -1, -1, -1, -1 if opts.benzene and opts.nn: pad_electrons = 30 if opts.hydrogens: pad_diff_ERIs = 5000 @@ -104,13 +108,13 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, pad_sparse_diff_grid=400000 if opts.alanine: # todo: (adam) the ERI padding may change when rotating molecule more! - multiplier = 1.2 pad_electrons = 70 - pad_diff_ERIs = int(180000*1.2) - pad_distinct_ERIs = int(600000) - pad_grid_AO = int(70000*1.2/4) - pad_nonzero_distinct_ERI = int(600000*1.2) - pad_sparse_diff_grid =int(1000000*1.2) + # padding is estimated/printed when running; copy those numbers to the list below. + padding_estimate = [ 210745, 219043, 18084, 193830, 1105268] + # add 10% + padding_estimate = [int(a*1.1) for a in padding_estimate] + # name variables correctly + pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = padding_estimate if opts.waters: pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = [a//3 for a in [ pad_diff_ERIs , pad_distinct_ERIs , pad_grid_AO , pad_nonzero_distinct_ERI , pad_sparse_diff_grid ]] @@ -154,7 +158,7 @@ def get_atom_positions(mol): pos = np.concatenate([np.array(mol_str[j][1]).reshape(1, 3) for j in range(len(mol_str))]) # todo: save=wandb.log({"pair": angle1, angle2, NN_energy ) (rotation, NN_energy) for train/val molecule (for val also save PySCF energy) - # only saving angles (angle not necessarily paired up with energy) + # only saving angles (angle not paired up with energy) AllChem.SetDihedralDeg(molecule.GetConformer(), *phi_atoms, phi) angle = psi + float(np.random.uniform(0, opts.rotate_deg, 1)) # perhaps add 45 and mod 360? angle = angle % 360 @@ -188,8 +192,8 @@ def get_atom_positions(mol): elif opts.qm9: # todo: find dihedral to rotate over similar to alanine dipeptide. - # rotate first three atoms around their center of mass - # this may break molecule; may need to do this around a bond or smth + # broken; rotate first three atoms around their center of mass + # this breaks molecule; should use dihedral angle as done with the dipeptide. #rotation_matrix = np.linalg.qr(np.random.normal(size=(3,3)))[0] #center = atoms.mean(axis=0) #rotated_atoms = np.dot(atoms - center, rotation_matrix) + center @@ -213,6 +217,8 @@ def get_atom_positions(mol): states.append(state) # If we add energy here we get plot basically! + # todo: save and store in training loop, then we can match with energy + # can't get to work in wandb, but can just use download api and the plot. if opts.alanine and opts.wandb: for phi, psi in angles: if not validation: @@ -270,7 +276,6 @@ def get_atom_positions(mol): main_grid_AO = state.grid_AO[:1] diffs_grid_AO = main_grid_AO - state.grid_AO rows, cols = np.nonzero(np.max(diffs_grid_AO[:, 0]!=0, axis=0)) - sparse_diffs_grid_AO = diffs_grid_AO[:, 0, rows,cols] # use the same sparsity pattern across a batch. @@ -293,6 +298,7 @@ def get_atom_positions(mol): state.indxs=diff_indxs state.diffs_ERI=diff_ERIs else: + max_pad_diff_ERIs = diff_ERIs.shape[2] # pad ERIs with 0 and indices with -1 so they point to 0. assert diff_indxs.shape[1] == diff_ERIs.shape[2] pad = pad_diff_ERIs - diff_indxs.shape[1] @@ -311,6 +317,7 @@ def get_atom_positions(mol): state.diffs_grid_AO = diffs_grid_AO if pad_sparse_diff_grid != -1: + max_pad_sparse_diff_grid = state.rows.shape[0] assert state.sparse_diffs_grid_AO.shape[1] == state.rows.shape[0] assert state.sparse_diffs_grid_AO.shape[1] == state.cols.shape[0] pad = pad_sparse_diff_grid - state.rows.shape[0] @@ -325,6 +332,7 @@ def get_atom_positions(mol): state.nonzero_distinct_ERI = state.nonzero_distinct_ERI[:1] state.nonzero_indices = np.expand_dims(state.nonzero_indices, axis=0) if pad_distinct_ERIs != -1: + max_pad_distinct_ERIs = state.nonzero_distinct_ERI.shape[2] assert state.nonzero_distinct_ERI.shape[2] == state.nonzero_indices.shape[2] pad = pad_distinct_ERIs - state.nonzero_distinct_ERI.shape[2] assert pad > 0, (pad_distinct_ERIs, state.nonzero_distinct_ERI.shape[2]) @@ -334,6 +342,8 @@ def get_atom_positions(mol): if opts.wandb: wandb.log({"pad_distinct_ERIs": pad/state.nonzero_distinct_ERI.shape[2]}) if pad_grid_AO != -1: + max_pad_grid_AO = state.grid_AO.shape[2] + prev_size = state.grid_AO.shape[2] assert state.grid_AO.shape[2] == state.grid_weights.shape[1] assert state.grid_AO.shape[2] == state.grid_coords.shape[1] @@ -367,6 +377,8 @@ def get_atom_positions(mol): state.nonzero_indices = state.nonzero_indices.reshape(1, batches, -1, 4) if pad_nonzero_distinct_ERI != -1: + max_pad_nonzero_distinct_ERI = state.nonzero_distinct_ERI.shape[2] + assert state.nonzero_distinct_ERI.shape[2] == state.nonzero_indices.shape[2] pad = pad_nonzero_distinct_ERI - state.nonzero_distinct_ERI.shape[2] assert pad >= 0, (pad_nonzero_distinct_ERI, state.nonzero_distinct_ERI.shape[2]) @@ -375,9 +387,12 @@ def get_atom_positions(mol): if opts.wandb: wandb.log({"pad_grid_AO": pad/state.grid_AO.shape[2]}) - - import copy - return copy.deepcopy(state) + #import copy + #return copy.deepcopy(state) + state.pad_sizes = np.array([ + max_pad_diff_ERIs, max_pad_distinct_ERIs, max_pad_grid_AO, + max_pad_nonzero_distinct_ERI, max_pad_sparse_diff_grid]) + return state def nanoDFT(mol_str, opts): @@ -428,11 +443,11 @@ def nanoDFT(mol_str, opts): n_layers = 24 if opts.large: d_model= 1280 - n_heads = 12 + n_heads = 16 n_layers = 36 if opts.xlarge: d_model= 1600 - n_heads = 12 + n_heads = 25 n_layers = 48 if opts.nn: @@ -444,6 +459,7 @@ def nanoDFT(mol_str, opts): n_heads =n_heads, d_ff =d_model*4, ) + params = params.to_float32() if opts.wandb: wandb.log({"total_parameters": total_params }) @@ -456,12 +472,25 @@ def nanoDFT(mol_str, opts): # [ ] weight initialization def custom_schedule(it, learning_rate=opts.lr, min_lr=opts.lr/10, warmup_iters=2000, lr_decay_iters=600000): - if it < warmup_iters: return learning_rate * it / warmup_iters + #return learning_rate * it / warmup_iters # to allow jax jit? + # allow jax jit + '''if it < warmup_iters: return learning_rate * it / warmup_iters if it > lr_decay_iters: return min_lr decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) assert 0 <= decay_ratio <= 1 coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) - return min_lr + coeff * (learning_rate - min_lr) + return min_lr + coeff * (learning_rate - min_lr)''' + #if it < warmup_iters: return learning_rate * it / warmup_iters + cond1 = (it < warmup_iters) * learning_rate * it / warmup_iters + cond2 = (it > lr_decay_iters) * min_lr + + decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) + coeff = 0.5 * (1.0 + jnp.cos(jnp.pi * decay_ratio)) + cond3 = (it >= warmup_iters) * (it <= lr_decay_iters) * (min_lr + coeff * (learning_rate - min_lr)) + return cond1 + cond2 + cond3 + + + adam = optax.chain( optax.clip_by_global_norm(1), @@ -550,51 +579,67 @@ def __next__(self): return self.item vandg = jax.jit(jax.value_and_grad(dm_energy, has_aux=True), backend=opts.backend, static_argnames=("normal", 'nn')) valf = jax.jit(dm_energy, backend=opts.backend, static_argnames=("normal", 'nn')) adam_state = adam.init(w) + w, adam_state = jax.device_put(w), jax.device_put(adam_state) + + @partial(jax.jit, backend=opts.backend) + def update(w, adam_state, accumulated_grad): + accumulated_grad = jax.tree_map(lambda x: x / global_batch_size, accumulated_grad) + updates, adam_state = adam.update(accumulated_grad, adam_state, w) + w = optax.apply_updates(w, updates) + return w, adam_state if opts.wandb: if not opts.nn: total_params = -1 wandb.log({'total_params': total_params, 'batch_size': opts.bs, 'lr': opts.lr }) - min_val, min_dm, mins, val_state, valid_str, step = 0, 0, np.ones(opts.bs)*1e6, None, "", 0 t0, load_time, train_time, val_time, plot_time = time.time(), 0, 0, 0, 0 - states = [] + paddings = [] + states = [] for iteration, state in enumerate(pbar): if iteration == 0: summary(state) + state = jax.device_put(state) + + # estimate max padding. + if iteration < 100: + paddings.append(state.pad_sizes.reshape(1, -1)) + _paddings = np.concatenate(paddings, axis=0) + print(np.max(_paddings, 0)) + dct = {} dct["iteraton"] = iteration load_time, t0 = time.time()-t0, time.time() - states = states[-opts.grad_acc:] + [state] + states = states[-opts.grad_acc+1:] + [state] accumulated_grad = None if len(states) < 50: print(len(states)) for j, state in enumerate(states): + print(".\t", end="", flush=True) if j == 0: _t0 =time.time() (val, (vals, E_xc, density_matrix, _W)), grad = vandg(w, state, opts.normal, opts.nn) + print(",", end="", flush=True) if j == 0: time_step1 = time.time()-_t0 accumulated_grad = grad if accumulated_grad is None else jax.tree_map(lambda x, y: x + y, accumulated_grad, grad) + train_time, t0 = time.time()-t0, time.time() # scale by global batch size. global_batch_size = len(states)*opts.bs - accumulated_grad = jax.tree_map(lambda x: x / global_batch_size, accumulated_grad) if opts.wandb: wandb.log({"global_batch_size": global_batch_size}) + w, adam_state = update(w, adam_state, accumulated_grad) # plot grad norm - if iteration % 10 == 0: - for k,v in accumulated_grad.items(): dct[k + "_norm"] = np.linalg.norm(v .reshape(-1) ) - - updates, adam_state = adam.update(accumulated_grad, adam_state, params) - w = optax.apply_updates(w, updates) - train_time, t0 = time.time()-t0, time.time() + #if iteration % 10 == 0: + # for k,v in accumulated_grad.items(): dct[k + "_norm"] = np.linalg.norm(v .reshape(-1) ) + update_time, t0 = time.time()-t0, time.time() if not opts.nn: str = "error=" + "".join(["%.7f "%(vals[i]*HARTREE_TO_EV-state.pyscf_E[i]) for i in range(2)]) + " [eV]" str += "pyscf=%.7f us=%.7f"%(state.pyscf_E[0]/HARTREE_TO_EV, vals[0]) else: - pbar.set_description("train=".join(["%.5f"%i for i in vals[:2]]) + "[Ha] "+ valid_str + "time=%.1f %.1f %.1f %.1f %.1f"%(load_time, time_step1, train_time, val_time, plot_time)) + pbar.set_description("train=".join(["%.2f"%i for i in vals[:1]]) + "[Ha] "+ valid_str + "time=%.1f %.1f %.1f %.1f %.1f %.1f"%(load_time, time_step1, train_time, update_time, val_time, plot_time)) if opts.wandb: dct["time_load"] = load_time @@ -613,8 +658,8 @@ def __next__(self): return self.item plot_time, t0 = time.time()-t0, time.time() - if opts.nn:# and i % 5 == 0: - if val_state is None: val_state = val_qm9[0] + if opts.nn:# and iteration % 10 == 0: + if val_state is None: val_state = jax.device_put(val_qm9[0]) _, (valid_vals, _, vdensity_matrix, vW) = valf(w, val_state, opts.normal, opts.nn) lr = custom_schedule(iteration) valid_str = "lr=%.3e"%lr + "val_error=" + "".join(["%.4f "%(valid_vals[i]*HARTREE_TO_EV-val_state.pyscf_E[i]) for i in range(0, 3)]) + " [eV]" @@ -627,7 +672,7 @@ def __next__(self): return self.item dct['img/val_W%i'%i] = wandb.Image(np.expand_dims(vW[i], axis=-1)) dct["scheduled_lr"] = custom_schedule(iteration) - wandb.log(dct) + if opts.wandb: wandb.log(dct) valid_time, t0 = time.time()-t0, time.time() val, density_matrix = min_val, min_dm @@ -670,6 +715,7 @@ class IterationState: cols: np.array pos: np.array ao_types: np.array + pad_sizes: np.array from pyscf.data.elements import charge as elements_proton from pyscf.dft import gen_grid, radi @@ -911,6 +957,7 @@ def e(x): return np.expand_dims(x, axis=0) else: raise Exception() ao_types = np.array(types) pos = np.concatenate(pos) + pad_sizes = np.zeros(1) state = IterationState( diffs_ERI = np.zeros((1,1)), @@ -936,6 +983,7 @@ def e(x): return np.expand_dims(x, axis=0) pyscf_E=e(pyscf_E[-1:]), N=e(mol.nao_nr()), mask=e(mask), + pad_sizes=e(pad_sizes), ) diff --git a/pyscf_ipu/direct/transformer.py b/pyscf_ipu/direct/transformer.py index f4f9d693..df6a162a 100644 --- a/pyscf_ipu/direct/transformer.py +++ b/pyscf_ipu/direct/transformer.py @@ -119,10 +119,9 @@ def transformer(cfg, params, x: jnp.ndarray, position: jnp.ndarray, H_core: jnp. #x = embeddings L, dm = x.shape - # Apply the transformer layers - for layer_num, layer in enumerate(params.layers): + def block(x, layer_num, layers): # Layer-normalize embeddings - t1 = vmap(standardize)(embeddings) + x = vmap(standardize)(x) t1 = elementwise_linear(layer.norm_self_attn, x) # L x Dm L, Dm = t1.shape @@ -153,7 +152,14 @@ def transformer(cfg, params, x: jnp.ndarray, position: jnp.ndarray, H_core: jnp. # Add this layer's contribution into embeddings x = x + t2 + return x, score + + # Apply the transformer layers + # todo: cut jit time by making this jax.lax.foriloop + for layer_num, layer in enumerate(params.layers): + x, score = jax.checkpoint(block)(x, layer_num, layer) + return score[0] # take first head @@ -206,4 +212,14 @@ def labels_aux(cls, path, obj): def items(self, path = ''): yield from self.labels_aux(path, self) + def to_float32(self): + def convert_to_float32(x): + if isinstance(x, jnp.ndarray) and x.dtype == jnp.float64: + return x.astype(jnp.float32) + return x + + # Create a new ParamsDict instance with converted arrays + new_dict = jax.tree_map(convert_to_float32, self.__dict__) + return ParamsDict(**new_dict) + self.__dict__ = jax.tree_map(convert_to_float32, self.__dict__) From 0c13c5120ddada422ad6ec4affe2089ccaed90c2 Mon Sep 17 00:00:00 2001 From: Alexander Mathiasen Date: Sat, 30 Dec 2023 12:20:19 +0000 Subject: [PATCH 17/22] update each step no accumulate. make wandb step = num train step by removing wandb.log calls. --- pyscf_ipu/direct/train.py | 65 +++++++++++++++------------------ pyscf_ipu/direct/transformer.py | 5 +-- 2 files changed, 32 insertions(+), 38 deletions(-) diff --git a/pyscf_ipu/direct/train.py b/pyscf_ipu/direct/train.py index 92f9f8ca..a8ebe928 100644 --- a/pyscf_ipu/direct/train.py +++ b/pyscf_ipu/direct/train.py @@ -83,8 +83,6 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, pad_sparse_diff_grid=200000, mol_idx=42, ): - if opts.wandb: import wandb - # pad molecule if using nn. if not opts.nn: pad_electrons, pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = \ @@ -110,7 +108,7 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, # todo: (adam) the ERI padding may change when rotating molecule more! pad_electrons = 70 # padding is estimated/printed when running; copy those numbers to the list below. - padding_estimate = [ 210745, 219043, 18084, 193830, 1105268] + padding_estimate = [210745, 219043, 18084, 193830, 1105268] # add 10% padding_estimate = [int(a*1.1) for a in padding_estimate] # name variables correctly @@ -168,10 +166,10 @@ def get_atom_positions(mol): for j in range(len(mol_str)): mol_str[j][1] = tuple(pos[j]) - if iteration == 0 and opts.wandb: + '''if iteration == 0 and opts.wandb: from plot import create_rdkit_mol import wandb - wandb.log({"mol_valid=%s"%validation: create_rdkit_mol(str, pos) }) + wandb.log({"mol_valid=%s"%validation: create_rdkit_mol(str, pos) })''' if opts.waters: # todo: rotate both water molecules and draw x=phi, y=psi. rotation_matrix = np.linalg.qr(np.random.normal(size=(3,3)))[0] @@ -182,12 +180,12 @@ def get_atom_positions(mol): mol_str[4][1] = tuple(water_xyz[1]) mol_str[5][1] = tuple(water_xyz[2]) - if opts.wandb and iteration == 0: + '''if opts.wandb and iteration == 0: from plot import create_rdkit_mol import wandb str = [mol_str[j][0] for j in range(len(mol_str))] pos = np.concatenate([np.array(mol_str[j][1]).reshape(1, 3) for j in range(len(mol_str))]) - wandb.log({"%s_mol_%i"%({True: "valid", False: "train"}[validation], iteration): create_rdkit_mol(str, pos) }) + wandb.log({"%s_mol_%i"%({True: "valid", False: "train"}[validation], iteration): create_rdkit_mol(str, pos) })''' elif opts.qm9: # todo: find dihedral to rotate over similar to alanine dipeptide. @@ -201,12 +199,12 @@ def get_atom_positions(mol): mol_str[1][1] = tuple(atoms[1] + np.random.normal(0, opts.wiggle_var, (3))) mol_str[2][1] = tuple(atoms[2] + np.random.normal(0, opts.wiggle_var, (3))) - if opts.wandb and iteration == 0: + '''if opts.wandb and iteration == 0: from plot import create_rdkit_mol import wandb str = [mol_str[j][0] for j in range(len(mol_str))] pos = np.concatenate([np.array(mol_str[j][1]).reshape(1, 3) for j in range(len(mol_str))]) - wandb.log({"%s_mol_%i"%({True: "valid", False: "train"}[validation], iteration): create_rdkit_mol(str, pos) }) + wandb.log({"%s_mol_%i"%({True: "valid", False: "train"}[validation], iteration): create_rdkit_mol(str, pos) })''' if iteration == 0: state = init_dft(mol_str, opts, do_pyscf=do_pyscf, pad_electrons=pad_electrons) @@ -219,12 +217,12 @@ def get_atom_positions(mol): # If we add energy here we get plot basically! # todo: save and store in training loop, then we can match with energy # can't get to work in wandb, but can just use download api and the plot. - if opts.alanine and opts.wandb: + '''if opts.alanine and opts.wandb: for phi, psi in angles: if not validation: wandb.log({"phi_train": phi , "psi_train": psi}) else: - wandb.log({"phi_valid": phi, "psi_valid": psi}) + wandb.log({"phi_valid": phi, "psi_valid": psi})''' state = cats(states) N = state.N[0] @@ -306,7 +304,7 @@ def get_atom_positions(mol): state.indxs = np.pad(diff_indxs, ((0,0), (0, pad), (0, 0)), 'constant', constant_values=(-1)) state.diffs_ERI = np.pad(diff_ERIs, ((0,0), (0, 0), (0, pad))) # pad zeros - if opts.wandb: wandb.log({"pad_diff_ERIs": pad/diff_ERIs.shape[2]}) + #if opts.wandb: wandb.log({"pad_diff_ERIs": pad/diff_ERIs.shape[2]}) state.rows=rows state.cols=cols @@ -326,7 +324,7 @@ def get_atom_positions(mol): state.cols = np.pad(state.cols, (0,pad)) state.sparse_diffs_grid_AO = np.pad(state.sparse_diffs_grid_AO, ((0,0),(0,pad))) - if opts.wandb: wandb.log({"pad_sparse_diff_grid": pad/state.sparse_diffs_grid_AO.shape[1]}) + #if opts.wandb: wandb.log({"pad_sparse_diff_grid": pad/state.sparse_diffs_grid_AO.shape[1]}) #state.grid_AO = state.grid_AO[:1] state.nonzero_distinct_ERI = state.nonzero_distinct_ERI[:1] @@ -339,7 +337,7 @@ def get_atom_positions(mol): state.nonzero_indices = np.pad(state.nonzero_indices, ((0,0), (0,0), (0, pad), (0,0)), 'constant', constant_values=(-1)) state.nonzero_distinct_ERI = np.pad(state.nonzero_distinct_ERI, ((0,0), (0,0), (0, pad))) # pad zeros - if opts.wandb: wandb.log({"pad_distinct_ERIs": pad/state.nonzero_distinct_ERI.shape[2]}) + #if opts.wandb: wandb.log({"pad_distinct_ERIs": pad/state.nonzero_distinct_ERI.shape[2]}) if pad_grid_AO != -1: max_pad_grid_AO = state.grid_AO.shape[2] @@ -357,11 +355,11 @@ def get_atom_positions(mol): state.main_grid_AO = np.pad(state.main_grid_AO, ((0,0),(0,pad),(0,0))) state.diffs_grid_AO = np.pad(state.diffs_grid_AO, ((0,0),(0,0),(0,pad),(0,0))) - if opts.wandb: - wandb.log({"pad_grid_AO": pad/state.grid_AO.shape[2], - "pad_grid_AO_prev": prev_size, - "pad_grid_AO_pad": pad, - "pad_grid_AO_target": pad_grid_AO}) + #if opts.wandb: + # wandb.log({"pad_grid_AO": pad/state.grid_AO.shape[2], + # "pad_grid_AO_prev": prev_size, + # "pad_grid_AO_pad": pad, + # "pad_grid_AO_target": pad_grid_AO}) indxs = np.abs(state.nonzero_distinct_ERI ) > 1e-9 @@ -385,10 +383,8 @@ def get_atom_positions(mol): state.nonzero_distinct_ERI = np.pad(state.nonzero_distinct_ERI, ((0,0),(0,0),(0,pad))) state.nonzero_indices = np.pad(state.nonzero_indices, ((0,0),(0,0),(0,pad), (0,0)), 'constant', constant_values=(-1)) - if opts.wandb: wandb.log({"pad_grid_AO": pad/state.grid_AO.shape[2]}) + #if opts.wandb: wandb.log({"pad_grid_AO": pad/state.grid_AO.shape[2]}) - #import copy - #return copy.deepcopy(state) state.pad_sizes = np.array([ max_pad_diff_ERIs, max_pad_distinct_ERIs, max_pad_grid_AO, max_pad_nonzero_distinct_ERI, max_pad_sparse_diff_grid]) @@ -461,7 +457,6 @@ def nanoDFT(mol_str, opts): ) params = params.to_float32() - if opts.wandb: wandb.log({"total_parameters": total_params }) if opts.nn: #https://arxiv.org/pdf/1706.03762.pdf see 5.3 optimizer @@ -583,7 +578,7 @@ def __next__(self): return self.item @partial(jax.jit, backend=opts.backend) def update(w, adam_state, accumulated_grad): - accumulated_grad = jax.tree_map(lambda x: x / global_batch_size, accumulated_grad) + accumulated_grad = jax.tree_map(lambda x: x / opts.bs, accumulated_grad) updates, adam_state = adam.update(accumulated_grad, adam_state, w) w = optax.apply_updates(w, updates) return w, adam_state @@ -612,7 +607,6 @@ def update(w, adam_state, accumulated_grad): load_time, t0 = time.time()-t0, time.time() states = states[-opts.grad_acc+1:] + [state] - accumulated_grad = None if len(states) < 50: print(len(states)) @@ -622,13 +616,14 @@ def update(w, adam_state, accumulated_grad): (val, (vals, E_xc, density_matrix, _W)), grad = vandg(w, state, opts.normal, opts.nn) print(",", end="", flush=True) if j == 0: time_step1 = time.time()-_t0 - accumulated_grad = grad if accumulated_grad is None else jax.tree_map(lambda x, y: x + y, accumulated_grad, grad) - train_time, t0 = time.time()-t0, time.time() - # scale by global batch size. + w, adam_state = update(w, adam_state, grad) + + # todo: rename global_batch_size = len(states)*opts.bs - if opts.wandb: wandb.log({"global_batch_size": global_batch_size}) - w, adam_state = update(w, adam_state, accumulated_grad) + if opts.wandb: dct["global_batch_size"] = global_batch_size + + train_time, t0 = time.time()-t0, time.time() # plot grad norm #if iteration % 10 == 0: @@ -658,7 +653,7 @@ def update(w, adam_state, accumulated_grad): plot_time, t0 = time.time()-t0, time.time() - if opts.nn:# and iteration % 10 == 0: + if opts.nn and (iteration < 1000 or iteration % 10 == 0): if val_state is None: val_state = jax.device_put(val_qm9[0]) _, (valid_vals, _, vdensity_matrix, vW) = valf(w, val_state, opts.normal, opts.nn) lr = custom_schedule(iteration) @@ -1225,11 +1220,11 @@ def reference(mol_str, opts): parser.add_argument('-tiny', action="store_true") parser.add_argument('-small', action="store_true") parser.add_argument('-base', action="store_true") - parser.add_argument('-medium', action="store_true") - parser.add_argument('-large', action="store_true") - parser.add_argument('-xlarge', action="store_true") + parser.add_argument('-medium', action="store_true") + parser.add_argument('-large', action="store_true") + parser.add_argument('-xlarge', action="store_true") opts = parser.parse_args() - if opts.tiny or opts.small or opts.base: opts.nn = True + if opts.tiny or opts.small or opts.base or opts.large or opts.xlarge: opts.nn = True args_dict = vars(opts) print(args_dict) diff --git a/pyscf_ipu/direct/transformer.py b/pyscf_ipu/direct/transformer.py index df6a162a..b2439da7 100644 --- a/pyscf_ipu/direct/transformer.py +++ b/pyscf_ipu/direct/transformer.py @@ -154,15 +154,14 @@ def block(x, layer_num, layers): x = x + t2 return x, score - # Apply the transformer layers # todo: cut jit time by making this jax.lax.foriloop for layer_num, layer in enumerate(params.layers): - x, score = jax.checkpoint(block)(x, layer_num, layer) + if layer_num % 2 == 0: x, score = jax.checkpoint(block)(x, layer_num, layer) + else: x, score = block(x, layer_num, layer) return score[0] # take first head - import types import json import jax From d14fd42e90b22ef155fd029c982f186c79ee6f2d Mon Sep 17 00:00:00 2001 From: Alexander Mathiasen Date: Tue, 2 Jan 2024 18:13:24 +0000 Subject: [PATCH 18/22] work. --- pyscf_ipu/direct/sparse_symmetric_ERI.py | 212 +++++++++++++- pyscf_ipu/direct/train.py | 335 ++++++++++++++++++----- pyscf_ipu/direct/transformer.py | 166 +++++++++-- 3 files changed, 602 insertions(+), 111 deletions(-) diff --git a/pyscf_ipu/direct/sparse_symmetric_ERI.py b/pyscf_ipu/direct/sparse_symmetric_ERI.py index f7646c5d..24ba84f6 100644 --- a/pyscf_ipu/direct/sparse_symmetric_ERI.py +++ b/pyscf_ipu/direct/sparse_symmetric_ERI.py @@ -2,6 +2,7 @@ import pyscf import numpy as np import jax +jax.config.update('jax_enable_x64', True) import jax.numpy as jnp from functools import partial from icecream import ic @@ -12,10 +13,16 @@ def get_i_j(val): j = (((val - i) - (i**2 - val))//2) return i, j -def ijkl(value, symmetry, N, f): - i, j, k, l = value[0].astype(np.uint32), value[1].astype(np.uint32), value[2].astype(np.uint32), value[3].astype(np.uint32) +def _ijkl(value, symmetry, N, f): + #i, j, k, l = value[0].astype(np.uint32), value[1].astype(np.uint32), value[2].astype(np.uint32), value[3].astype(np.uint32) + i, j, k, l = value[0], value[1], value[2], value[3] + return f(i,j,k,l,symmetry,N) +ijkl = jax.vmap(_ijkl, in_axes=(0, None, None, None)) + +def np_ijkl(value, symmetry, N, f): + #i, j, k, l = value[0].astype(np.uint32), value[1].astype(np.uint32), value[2].astype(np.uint32), value[3].astype(np.uint32) + i, j, k, l = value[:, 0], value[:, 1], value[:, 2], value[:, 3] return f(i,j,k,l,symmetry,N) -ijkl = jax.vmap(ijkl, in_axes=(0, None, None, None)) def num_repetitions_fast(ij, kl): @@ -67,7 +74,7 @@ def _indices_func(i, j, k, l, symmetry, N): elif symmetry == 31: return l * N + i -def sparse_symmetric_einsum(nonzero_distinct_ERI, nonzero_indices, dm): +def sparse_symmetric_einsum(nonzero_distinct_ERI, nonzero_indices, dm, foriloop): dm = dm.reshape(-1) diff_JK = jnp.zeros(dm.shape) N = int(np.sqrt(dm.shape[0])) @@ -80,7 +87,9 @@ def sparse_symmetric_einsum(nonzero_distinct_ERI, nonzero_indices, dm): update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) + Z = jnp.zeros((N**2,), dtype=dm.dtype) + # todo: how much faster if we precompute dm/ss indices? def iteration(symmetry, vals): diff_JK = vals is_K_matrix = (symmetry >= 8) @@ -89,18 +98,17 @@ def sequentialized_iter(i, vals): # Generalized J/K computation: does J when symmetry is in range(0,8) and K when symmetry is in range(8,16) # Trade-off: Using one function leads to smaller always-live memory. diff_JK = vals - indices = nonzero_indices[i].astype(np.int32) + indices = nonzero_indices[i]#.astype(np.int32) # eris = nonzero_distinct_ERI[i] dm_indices = ijkl(indices, symmetry+is_K_matrix*8, N, indices_func).reshape(-1, 1) #dm_values = jnp.take(dm, dm_indices, axis=0)[:, 0] # for our special case the 50 lines of code reduces to the one line below. dm_values = jax.lax.gather(dm, dm_indices, dimension_numbers=dnums, slice_sizes=(1,), mode=jax.lax.GatherScatterMode.FILL_OR_DROP) - dm_values = dm_values * eris ss_indices = ijkl(indices, symmetry+8+is_K_matrix*8, N, indices_func) .reshape(-1,1) # diff_JK = diff_JK + jax.lax.segment_sum( ...) # for our special case the 100 lines of code reduces to the one line below. - diff_JK = diff_JK + jax.lax.scatter_add(jnp.zeros((N**2,)), + diff_JK = diff_JK + jax.lax.scatter_add(Z, ss_indices, dm_values, scatter_dnums, indices_are_sorted=True, unique_indices=False, mode=jax.lax.GatherScatterMode.FILL_OR_DROP)\ *(-HYB_B3LYP/2)**is_K_matrix @@ -108,15 +116,193 @@ def sequentialized_iter(i, vals): return diff_JK batches = nonzero_indices.shape[0] - diff_JK = jax.lax.fori_loop(0, batches, sequentialized_iter, diff_JK) - #for i in range(batches): - # diff_JK = sequentialized_iter(i, diff_JK) + + # forloop makes training slower but compile time faster. + if foriloop: + diff_JK = jax.lax.fori_loop(0, batches, sequentialized_iter, diff_JK) + else: + for i in range(batches): + diff_JK = sequentialized_iter(i, diff_JK) return diff_JK - diff_JK = jax.lax.fori_loop(0, 16, iteration, diff_JK) - #for i in range(0, 16): - # diff_JK = iteration(i, diff_JK) + if foriloop: + diff_JK = jax.lax.fori_loop(0, 16, iteration, diff_JK) + else: + for i in range(0, 16): + diff_JK = iteration(i, diff_JK) #diff_JK = jax.lax.fori_loop(0, 16, iteration, diff_JK) #return jax.lax.psum(diff_JK, axis_name="p") return diff_JK.reshape(N, N) + +def sparse_einsum(nonzero_distinct_ERI, precomputed_indices, dm, foriloop): + dm = dm.reshape(-1) + diff_JK = jnp.zeros(dm.shape) + N = int(np.sqrt(dm.shape[0])) + + dnums = jax.lax.GatherDimensionNumbers( offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) + scatter_dnums = jax.lax.ScatterDimensionNumbers( update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) + Z = jnp.zeros((N**2,), dtype=dm.dtype) + + def iteration(symmetry, vals): + diff_JK = vals + is_K_matrix = (symmetry >= 8) + + def sequentialized_iter(i, vals): + # Generalized J/K computation: does J when symmetry is in range(0,8) and K when symmetry is in range(8,16) + # Trade-off: Using one function leads to smaller always-live memory. + diff_JK = vals + eris = nonzero_distinct_ERI[i] + + #dm_values = jnp.take(dm, dm_indices, axis=0)[:, 0] # for our special case the 50 lines of code reduces to the one line below. + dm_indices = precomputed_indices[symmetry, i, 0] + ss_indices = precomputed_indices[symmetry, i, 1] + dm_values = jax.lax.gather(dm, dm_indices, dimension_numbers=dnums, slice_sizes=(1,), mode=jax.lax.GatherScatterMode.FILL_OR_DROP) + dm_values = dm_values * eris + + #ss_indices = ijkl(indices, symmetry+8+is_K_matrix*8, N, indices_func) .reshape(-1,1) + # diff_JK = diff_JK + jax.lax.segment_sum( ...) # for our special case the 100 lines of code reduces to the one line below. + diff_JK = diff_JK + jax.lax.scatter_add(Z, ss_indices, dm_values, + scatter_dnums, indices_are_sorted=True, unique_indices=False, mode=jax.lax.GatherScatterMode.FILL_OR_DROP)\ + *(-HYB_B3LYP/2)**is_K_matrix + + return diff_JK + + batches = precomputed_indices.shape[1] + + # forloop makes training slower but compile time faster. + if foriloop: + diff_JK = jax.lax.fori_loop(0, batches, sequentialized_iter, diff_JK) + else: + for i in range(batches): + diff_JK = sequentialized_iter(i, diff_JK) + return diff_JK + + if foriloop: + diff_JK = jax.lax.fori_loop(0, 16, iteration, diff_JK) + else: + for i in range(0, 16): + diff_JK = iteration(i, diff_JK) + #diff_JK = jax.lax.fori_loop(0, 16, iteration, diff_JK) + #return jax.lax.psum(diff_JK, axis_name="p") + return diff_JK.reshape(N, N) + + + +def precompute_indices(nonzero_indices, N): + + def iteration(symmetry): + is_K_matrix = (symmetry >= 8) + + def sequentialized_iter(i): + # Generalized J/K computation: does J when symmetry is in range(0,8) and K when symmetry is in range(8,16) + # Trade-off: Using one function leads to smaller always-live memory. + indices = nonzero_indices[i] + dm_indices = np_ijkl(indices, symmetry+is_K_matrix*8, N, _indices_func).reshape(-1, 1) + ss_indices = np_ijkl(indices, symmetry+8+is_K_matrix*8, N, _indices_func) .reshape(-1,1) + + return dm_indices, ss_indices + + batches = nonzero_indices.shape[0] + + # forloop makes training slower but compile time faster. + _indices = [None for _ in range(batches)] + for i in range(batches): + _indices[i] = sequentialized_iter(i) + return _indices + + _indices = [None for _ in range(16)] + for i in range(0, 16): + _indices[i] = iteration(i) + return np.array(_indices ) + +if __name__ == "__main__": + import time + import argparse + parser = argparse.ArgumentParser(prog='', description='', epilog='') + parser.add_argument('-backend', default="cpu"), + parser.add_argument('-natm', default=3), + parser.add_argument('-test', action="store_true") + parser.add_argument('-prof', action="store_true") + parser.add_argument('-batches', default=5) + parser.add_argument('-skip', action="store_true") + + args = parser.parse_args() + backend = args.backend + + natm = int(args.natm) + nipu = 1 + + start = time.time() + + mol = pyscf.gto.Mole(atom="".join(f"C 0 {1.54*j} {1.54*i};" for i in range(natm) for j in range(natm))) + #mol = pyscf.gto.Mole(atom="".join(f"C 0 {15.4*j} {15.4*i};" for i in range(1) for j in range(75))) + mol.build() + N = mol.nao_nr() + print("N %i"%mol.nao_nr()) + print("NxN:", (N**2, N**2)) + print("Naive operations: ", N**4*2/10**9, "[Giga]") + if not args.skip: dense_ERI = mol.intor("int2e_sph", aosym="s1") + distinct_ERI = mol.intor("int2e_sph", aosym="s8") + #distinct_ERI[np.abs(distinct_ERI)<1e-9] = 0 # zero out stuff + dm = pyscf.scf.hf.init_guess_by_minao(mol) + scale = HYB_B3LYP/2 + if not args.skip: + J = np.einsum("ijkl,ji->kl", dense_ERI, dm) + K = np.einsum("ijkl,jk->il", dense_ERI, dm) + truth = J - K / 2 * HYB_B3LYP + + nonzero_indices = np.nonzero(distinct_ERI)[0].astype(np.uint64) + nonzero_distinct_ERI = distinct_ERI[nonzero_indices]#.astype(np.float32) + print("Nonzero Operations:", nonzero_indices.size*8*2/10**9, "[Giga]") + ij, kl = get_i_j(nonzero_indices) + rep = num_repetitions_fast(ij, kl) + nonzero_distinct_ERI = nonzero_distinct_ERI / rep + dm = dm.reshape(-1) + diff_JK = np.zeros(dm.shape) + + batches = int(args.batches) + remainder = nonzero_indices.shape[0] % (nipu*batches) + + if remainder != 0: + print(nipu*batches-remainder, ij.shape) + ij = np.pad(ij, ((0,nipu*batches-remainder))) + kl = np.pad(kl, ((0,nipu*batches-remainder))) + nonzero_distinct_ERI = np.pad(nonzero_distinct_ERI, (0,nipu*batches-remainder)) + + ij = ij.reshape(nipu, batches, -1) + kl = kl.reshape(nipu, batches, -1) + nonzero_distinct_ERI = nonzero_distinct_ERI.reshape(nipu, batches, -1) + + i, j = get_i_j(ij.reshape(-1)) + k, l = get_i_j(kl.reshape(-1)) + nonzero_indices = np.vstack([i,j,k,l]).T.reshape(nipu, batches, -1, 4).astype(np.int32) + #nonzero_indices = jax.lax.bitcast_convert_type(nonzero_indices, np.float16) + + #diff_JK = jax.pmap(sparse_symmetric_einsum, in_axes=(0,0,None,None), static_broadcasted_argnums=(3,), backend=backend, axis_name="p")(nonzero_distinct_ERI, nonzero_indices, dm, args.backend) + diff_JK = jax.jit(sparse_symmetric_einsum, static_argnums=(3,), backend=backend)(nonzero_distinct_ERI[0], nonzero_indices[0], dm, args.backend) + #diff_JK = jax.jit(sparse_symmetric_einsum, backend=backend, static_argnums=(3,))(nonzero_distinct_ERI[0], nonzero_indices[0], dm, False) + + indices = precompute_indices(nonzero_indices[0], N) + print(np.max(indices)) # this is just N**2! + indices = indices.astype(np.int16) + print(np.max(indices)) + print(nonzero_distinct_ERI.nbytes/10**9, nonzero_indices.nbytes/10**9, indices.nbytes/10**9) + print(nonzero_distinct_ERI.shape, nonzero_indices.shape, indices.shape) + print(np.max(indices)) + + _diff_JK = jax.jit(sparse_einsum, static_argnums=(3,), backend=backend)(nonzero_distinct_ERI[0], indices, dm, args.backend) + + + if args.skip: + exit() + + diff_JK = diff_JK.reshape(N, N) + print(diff_JK.reshape(-1)[::51]) + print(truth.reshape(-1)[::51]) + print(np.max(np.abs(diff_JK.reshape(-1) - truth.reshape(-1)))) + print(np.max(np.abs(_diff_JK.reshape(-1) - truth.reshape(-1)))) + assert np.allclose(diff_JK, truth, atol=1e-6) + assert np.allclose(_diff_JK, truth, atol=1e-6) + print("PASSED!") + \ No newline at end of file diff --git a/pyscf_ipu/direct/train.py b/pyscf_ipu/direct/train.py index a8ebe928..30773376 100644 --- a/pyscf_ipu/direct/train.py +++ b/pyscf_ipu/direct/train.py @@ -1,3 +1,5 @@ +import os +os.environ['OMP_NUM_THREADS'] = '8' import jax jax.config.update('jax_enable_x64', True) import jax.numpy as jnp @@ -12,6 +14,7 @@ import pandas as pd import math from functools import partial +import pickle cfg, HARTREE_TO_EV, EPSILON_B3LYP, HYB_B3LYP = None, 27.2114079527, 1e-20, 0.2 @@ -26,10 +29,11 @@ def dm_energy(W: BxNxK, state, normal, nn): W, state.ao_types, state.pos.astype(jnp.float32), state.H_core.astype(jnp.float32)) W = W.astype(jnp.float64) + # we can interpret state.H_core + W as hamiltonian, and predict hlgap from these! L_inv_Q: BxNxN = state.L_inv_T @ jnp.linalg.eigh(state.L_inv @ (state.H_core + W) @ state.L_inv_T)[1] # O(B*N*K^2) FLOP O(B*N*K) FLOP/FLIO density_matrix: BxNxN = 2 * (L_inv_Q*state.mask) @ T(L_inv_Q) # O(B*N*K^2) FLOP/FLIO - E_xc: B = exchange_correlation(density_matrix, state, normal) # O(B*gsize*N^2) FLOP O(gsize*N^2) FLIO - diff_JK: BxNxN = JK(density_matrix, state, normal) # O(B*num_ERIs) FLOP O(num_ERIs) FLIO + E_xc: B = exchange_correlation(density_matrix, state, normal, opts.xc_f32) # O(B*gsize*N^2) FLOP O(gsize*N^2) FLIO + diff_JK: BxNxN = JK(density_matrix, state, normal, opts.foriloop, opts.eri_f32) # O(B*num_ERIs) FLOP O(num_ERIs) FLIO energies: B = E_xc + state.E_nuc + jnp.sum((density_matrix * (state.H_core + diff_JK/2)).reshape(W.shape[0], -1), axis=-1) energy: float = jnp.sum(energies) return energy, (energies, E_xc, density_matrix, W) @@ -39,34 +43,68 @@ def sparse_mult(values, dm, state, gsize): prod = in_*values[:, None] return jax.ops.segment_sum(prod, state.rows, gsize) -def exchange_correlation(density_matrix: BxNxN, state, normal): +def exchange_correlation(density_matrix: BxNxN, state, normal, xc_f32): _, _, gsize, N = state.grid_AO.shape B = density_matrix.shape[0] if normal: grid_AO_dm = (state.grid_AO[:, 0] @ density_matrix) # (B,gsize,N) @ (B, N, N) = O(B gsize N^2) - rho = jnp.sum(grid_AO_dm * state.grid_AO, axis=3) # (B,1,gsize,N) * (B,4,gsize,N) = O(B gsize N) + rho = jnp.sum(grid_AO_dm.reshape(B, 1, gsize, N) * state.grid_AO, axis=3) # (B,1,gsize,N) * (B,4,gsize,N) = O(B gsize N) else: - main: BxGsizexN = state.main_grid_AO @ density_matrix # (1, gsize, N) @ (B, N, N) = O(B gsize N^2) FLOPs and O(gsize*N + N^2 +B * gsize * N) FLIOs - correction: BxGsizexN = jax.vmap(sparse_mult, in_axes=(0,0,None, None))(state.sparse_diffs_grid_AO, density_matrix, state, gsize) - rho_a = jnp.einsum("bpij,bqij->bpi", state.grid_AO, main.reshape(B,1,gsize,N)) - rho_b = jnp.einsum("bpij,bqij->bpi", state.grid_AO, correction.reshape(B,1,gsize,N)) - rho = rho_a - rho_b + if xc_f32: density_matrix.astype(jnp.float32) + if False: + main: BxGsizexN = state.main_grid_AO @ density_matrix # (1, gsize, N) @ (B, N, N) = O(B gsize N^2) FLOPs and O(gsize*N + N^2 +B * gsize * N) FLIOs + correction: BxGsizexN = jax.vmap(sparse_mult, in_axes=(0,0,None, None))(state.sparse_diffs_grid_AO, density_matrix, state, gsize) + rho_a = jnp.einsum("bpij,bqij->bpi", state.grid_AO, main.reshape(B,1,gsize,N)) + rho_b = jnp.einsum("bpij,bqij->bpi", state.grid_AO, correction.reshape(B,1,gsize,N)) + rho = rho_a - rho_b + else: + grid_AO_dm = (state.grid_AO[:, 0] @ density_matrix) # (B,gsize,N) @ (B, N, N) = O(B gsize N^2) + rho = jnp.sum(grid_AO_dm.reshape(B, 1, gsize, N) * state.grid_AO, axis=3) # (B,1,gsize,N) * (B,4,gsize,N) = O(B gsize N) + rho = rho.astype(jnp.float64) + E_xc = jax.vmap(_b3lyp, in_axes=(0, None))(rho, EPSILON_B3LYP).reshape(B, gsize) E_xc = jnp.sum(rho[:, 0] * state.grid_weights * E_xc, axis=-1).reshape(B) return E_xc -def JK(density_matrix, state, normal): +def JK(density_matrix, state, normal, jax_foriloop, eri_f32): if normal: J = jnp.einsum('bijkl,bji->bkl', state.ERI, density_matrix) K = jnp.einsum('bijkl,bjk->bil', state.ERI, density_matrix) diff_JK = J - K / 2 * HYB_B3LYP else: - from sparse_symmetric_ERI import sparse_symmetric_einsum - diff_JK: BxNxN = jax.vmap(sparse_symmetric_einsum, in_axes=(None, None, 0))(state.nonzero_distinct_ERI[0], state.nonzero_indices[0], density_matrix) - diff_JK: BxNxN = diff_JK - jax.vmap(sparse_symmetric_einsum, in_axes=(0, None, 0))(state.diffs_ERI, state.indxs, density_matrix) + from sparse_symmetric_ERI import sparse_symmetric_einsum, sparse_einsum + + if eri_f32: density_matrix = density_matrix.astype(jnp.float32) + + '''diff_JK: BxNxN = jax.vmap(sparse_symmetric_einsum, in_axes=(None, None, 0, None))( + state.nonzero_distinct_ERI[0], + state.nonzero_indices[0], + density_matrix, + jax_foriloop + ) + diff_JK: BxNxN = diff_JK - jax.vmap(sparse_symmetric_einsum, in_axes=(0, None, 0, None))( + state.diffs_ERI, + state.indxs, + density_matrix, + jax_foriloop + )''' + + diff_JK: BxNxN = jax.vmap(sparse_einsum, in_axes=(None, None, 0, None))( + state.nonzero_distinct_ERI[0], + state.precomputed_nonzero_indices, + density_matrix, + jax_foriloop + ) + correction = jax.vmap(sparse_einsum, in_axes=(0, None, 0, None))( + state.diffs_ERI, + state.precomputed_indxs, + density_matrix, + jax_foriloop + ) + diff_JK: BxNxN = diff_JK - correction - return diff_JK + return diff_JK.astype(jnp.float64) def nao(atom, basis): m = pyscf.gto.Mole(atom='%s 0 0 0; %s 0 0 1;'%(atom, atom), basis=basis) @@ -74,7 +112,8 @@ def nao(atom, basis): return m.nao_nr()//2 def batched_state(mol_str, opts, bs, wiggle_num=0, - do_pyscf=True, validation=False, + do_pyscf=True, validation=False, + extrapolate=False, pad_electrons=45, pad_diff_ERIs=50000, pad_distinct_ERIs=120000, @@ -97,23 +136,34 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, pad_grid_AO = 2200 pad_nonzero_distinct_ERI = 20000 pad_sparse_diff_grid = 20000 + if opts.qm9: - pad_electrons=45 - pad_diff_ERIs=120000 + pad_electrons=60 + '''pad_diff_ERIs=120000 pad_distinct_ERIs=400000 pad_grid_AO=50000 pad_nonzero_distinct_ERI=400000 - pad_sparse_diff_grid=400000 + pad_sparse_diff_grid=400000''' + #padding_estimate = [37426, 149710, 17010, 140122, 138369] + padding_estimate = [48330, 163222, 17034, 159361, 139505] + padding_estimate = [int(a*1.1) for a in padding_estimate] + pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = padding_estimate + if opts.alanine: # todo: (adam) the ERI padding may change when rotating molecule more! pad_electrons = 70 # padding is estimated/printed when running; copy those numbers to the list below. - padding_estimate = [210745, 219043, 18084, 193830, 1105268] + padding_estimate = [210745, 219043, 18084, 193830, 1105268] + # 213973 218912 18084 195723 1105847] # add 10% padding_estimate = [int(a*1.1) for a in padding_estimate] # name variables correctly pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = padding_estimate + pad_diff_ERIs *= int(8/opts.eri_bs) + pad_distinct_ERIs *= int(8/opts.eri_bs) + pad_nonzero_distinct_ERI *= int(8/opts.eri_bs) + if opts.waters: pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = [a//3 for a in [ pad_diff_ERIs , pad_distinct_ERIs , pad_grid_AO , pad_nonzero_distinct_ERI , pad_sparse_diff_grid ]] @@ -122,15 +172,24 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, # Set seed to ensure different rotation; initially all workers did same rotation! np.random.seed(mol_idx) + if opts.waters: water1_xyz = np.array([mol_str[i][1] for i in range(0,3)]) water2_xyz = np.array([mol_str[i][1] for i in range(3,6)]) if opts.qm9: atoms = np.array([mol_str[i][1] for i in range(0,3)]) + # pick random atom to permute (of the first 9 heavy ones) + atom_num = int(np.random.uniform(0, 8)) if opts.alanine: - phi, psi = [float(a) for a in np.random.uniform(0, 360, 2)] + # train on [-180, 180], validate [-180, 180] extrapolate [-360, 360]\[180, -180] + # todo: draw picture (in training loop) + if extrapolate: + phi, psi = [float(a) for a in np.random.uniform(180, 360, 2)] + else: + phi, psi = [float(a) for a in np.random.uniform(0, 180, 2)] + angles = [] states = [] @@ -159,7 +218,13 @@ def get_atom_positions(mol): # only saving angles (angle not paired up with energy) AllChem.SetDihedralDeg(molecule.GetConformer(), *phi_atoms, phi) angle = psi + float(np.random.uniform(0, opts.rotate_deg, 1)) # perhaps add 45 and mod 360? - angle = angle % 360 + + # todo: check math whether val/extra/train have uniform distribution on their respective domains. + if extrapolate: # make sure angle is in [] + angle = angle % 180 + 180 # angle should be in [180, 360] + else: + angle = angle % 180 # angle should be [0, 180] + AllChem.SetDihedralDeg(molecule.GetConformer(), *psi_atoms, angle ) pos = get_atom_positions(molecule) angles.append((phi, angle)) @@ -189,15 +254,21 @@ def get_atom_positions(mol): elif opts.qm9: # todo: find dihedral to rotate over similar to alanine dipeptide. - # broken; rotate first three atoms around their center of mass # this breaks molecule; should use dihedral angle as done with the dipeptide. #rotation_matrix = np.linalg.qr(np.random.normal(size=(3,3)))[0] #center = atoms.mean(axis=0) #rotated_atoms = np.dot(atoms - center, rotation_matrix) + center - mol_str[0][1] = tuple(atoms[0] + np.random.normal(0, opts.wiggle_var, (3))) - mol_str[1][1] = tuple(atoms[1] + np.random.normal(0, opts.wiggle_var, (3))) - mol_str[2][1] = tuple(atoms[2] + np.random.normal(0, opts.wiggle_var, (3))) + + # for extrapolation, do even more. + + if iteration == 0 and (validation or extrapolate): + pass + else: + #mol_str[0][1] = tuple(atoms[0] + np.random.normal(0, opts.wiggle_var, (3))) + mol_str[atom_num][1] = tuple(atoms[atom_num] + np.random.normal(0, opts.wiggle_var, (3))) + #mol_str[1][1] = tuple(atoms[1] + np.random.normal(0, opts.wiggle_var, (3))) + #mol_str[2][1] = tuple(atoms[2] + np.random.normal(0, opts.wiggle_var, (3))) '''if opts.wandb and iteration == 0: from plot import create_rdkit_mol @@ -245,7 +316,7 @@ def get_atom_positions(mol): ij, kl = get_i_j(nonzero_indices) rep = num_repetitions_fast(ij, kl) - batches = 8 + batches = opts.eri_bs es = [] for e,i in zip(state.nonzero_distinct_ERI, state.nonzero_indices): nonzero_distinct_ERI = e[nonzero_indices] / rep @@ -265,9 +336,13 @@ def get_atom_positions(mol): j = np.pad(j, ((0,batches-remainder))) k = np.pad(k, ((0,batches-remainder))) l = np.pad(l, ((0,batches-remainder))) - nonzero_indices = np.vstack([i,j,k,l]).T.reshape(batches, -1, 4).astype(np.int16) + nonzero_indices = np.vstack([i,j,k,l]).T.reshape(batches, -1, 4).astype(np.int32) # todo: use int16 or int32 here? + state.nonzero_indices = nonzero_indices - state.nonzero_indices = nonzero_indices + # batching (w/ same sparsity pattern across batch) allows precomputing all {ss,dm}_indices instead of computing in sparse_sym_eri every iteration. + # function below does this. + # todo: consider removing, didn't get expecting 3x (only 5%; not sure if additional memory/complication justifies). + from sparse_symmetric_ERI import precompute_indices if opts.normal: diff_state = None else: @@ -292,9 +367,12 @@ def get_atom_positions(mol): diff_ERIs = diff_ERIs.reshape(bs, batches, -1) diff_indxs = diff_indxs.reshape(batches, -1, 4) + precomputed_indxs = precompute_indices(diff_indxs, N).astype(np.int16) + if pad_diff_ERIs == -1: state.indxs=diff_indxs state.diffs_ERI=diff_ERIs + assert False, "deal with precomputed_indxs; only added in else branch below" else: max_pad_diff_ERIs = diff_ERIs.shape[2] # pad ERIs with 0 and indices with -1 so they point to 0. @@ -303,6 +381,8 @@ def get_atom_positions(mol): assert pad > 0, (pad_diff_ERIs, diff_indxs.shape[1]) state.indxs = np.pad(diff_indxs, ((0,0), (0, pad), (0, 0)), 'constant', constant_values=(-1)) state.diffs_ERI = np.pad(diff_ERIs, ((0,0), (0, 0), (0, pad))) # pad zeros + #print(diff_indxs.shape, precomputed_indxs.shape) + state.precomputed_indxs = np.pad(precomputed_indxs, ((0,0), (0,0),(0,0), (0, pad), (0,0)), 'constant', constant_values=(-1)) #if opts.wandb: wandb.log({"pad_diff_ERIs": pad/diff_ERIs.shape[2]}) @@ -312,7 +392,7 @@ def get_atom_positions(mol): state.main_grid_AO=main_grid_AO[:1, 0] state.sparse_diffs_grid_AO = sparse_diffs_grid_AO - state.diffs_grid_AO = diffs_grid_AO + #state.diffs_grid_AO = diffs_grid_AO # this isn't used for energy eval if pad_sparse_diff_grid != -1: max_pad_sparse_diff_grid = state.rows.shape[0] @@ -328,7 +408,10 @@ def get_atom_positions(mol): #state.grid_AO = state.grid_AO[:1] state.nonzero_distinct_ERI = state.nonzero_distinct_ERI[:1] + state.nonzero_indices = np.expand_dims(state.nonzero_indices, axis=0) + + # todo: looks like we're padding, then looking for zeros, then padding; this can be simplified. if pad_distinct_ERIs != -1: max_pad_distinct_ERIs = state.nonzero_distinct_ERI.shape[2] assert state.nonzero_distinct_ERI.shape[2] == state.nonzero_indices.shape[2] @@ -346,14 +429,14 @@ def get_atom_positions(mol): assert state.grid_AO.shape[2] == state.grid_weights.shape[1] assert state.grid_AO.shape[2] == state.grid_coords.shape[1] assert state.grid_AO.shape[2] == state.main_grid_AO.shape[1] - assert state.grid_AO.shape[2] == state.diffs_grid_AO.shape[2] + #assert state.grid_AO.shape[2] == state.diffs_grid_AO.shape[2] pad = pad_grid_AO - state.grid_AO.shape[2] assert pad > 0, (pad_grid_AO, state.grid_AO.shape[2]) state.grid_AO = np.pad(state.grid_AO, ((0,0),(0,0), (0,pad), (0,0))) state.grid_weights = np.pad(state.grid_weights, ((0,0),(0,pad))) state.grid_coords = np.pad(state.grid_coords, ((0,0),(0,pad),(0,0))) state.main_grid_AO = np.pad(state.main_grid_AO, ((0,0),(0,pad),(0,0))) - state.diffs_grid_AO = np.pad(state.diffs_grid_AO, ((0,0),(0,0),(0,pad),(0,0))) + #state.diffs_grid_AO = np.pad(state.diffs_grid_AO, ((0,0),(0,0),(0,pad),(0,0))) #if opts.wandb: # wandb.log({"pad_grid_AO": pad/state.grid_AO.shape[2], @@ -362,7 +445,8 @@ def get_atom_positions(mol): # "pad_grid_AO_target": pad_grid_AO}) - indxs = np.abs(state.nonzero_distinct_ERI ) > 1e-9 + # todo: make this into a variable we can control from commandline. + indxs = np.abs(state.nonzero_distinct_ERI ) > opts.eri_threshold #1e-9 state.nonzero_distinct_ERI = state.nonzero_distinct_ERI[indxs] state.nonzero_indices = state.nonzero_indices[indxs] remainder = state.nonzero_indices.shape[0] % batches @@ -374,6 +458,9 @@ def get_atom_positions(mol): state.nonzero_distinct_ERI = state.nonzero_distinct_ERI.reshape(1, batches, -1) state.nonzero_indices = state.nonzero_indices.reshape(1, batches, -1, 4) + precomputed_nonzero_indices = precompute_indices(state.nonzero_indices[0], N).astype(np.int16) + #print(state.nonzero_indices.shape, precomputed_nonzero_indices.shape) + if pad_nonzero_distinct_ERI != -1: max_pad_nonzero_distinct_ERI = state.nonzero_distinct_ERI.shape[2] @@ -383,11 +470,24 @@ def get_atom_positions(mol): state.nonzero_distinct_ERI = np.pad(state.nonzero_distinct_ERI, ((0,0),(0,0),(0,pad))) state.nonzero_indices = np.pad(state.nonzero_indices, ((0,0),(0,0),(0,pad), (0,0)), 'constant', constant_values=(-1)) + state.precomputed_nonzero_indices = np.pad(precomputed_nonzero_indices, ((0,0), (0,0), (0,0), (0, pad),(0,0)), 'constant', constant_values=(-1)) + #print(state.precomputed_nonzero_indices.shape, state.nonzero_indices.shape) + #if opts.wandb: wandb.log({"pad_grid_AO": pad/state.grid_AO.shape[2]}) state.pad_sizes = np.array([ max_pad_diff_ERIs, max_pad_distinct_ERIs, max_pad_grid_AO, max_pad_nonzero_distinct_ERI, max_pad_sparse_diff_grid]) + + if opts.eri_f32: + state.nonzero_distinct_ERI = state.nonzero_distinct_ERI.astype(jnp.float32) + state.diffs_ERI = state.diffs_ERI.astype(jnp.float32) + + if opts.xc_f32: + state.main_grid_AO = state.main_grid_AO.astype(jnp.float32) + state.grid_AO = state.grid_AO.astype(jnp.float32) + state.sparse_diffs_grid_AO = state.sparse_diffs_grid_AO.astype(jnp.float32) + return state @@ -400,9 +500,15 @@ def nanoDFT(mol_str, opts): if opts.wandb: import wandb if opts.alanine: - wandb.init(project='ndft_alanine') + run = wandb.init(project='ndft_alanine') + elif opts.qm9: + run = wandb.init(project='ndft_qm9') else: - wandb.init(project='ndft') + run = wandb.init(project='ndft') + opts.name = run.name + + else: + opts.name = "%i"%time.time() rnd_key = jax.random.PRNGKey(42) n_vocab = nao("C", opts.basis) + nao("N", opts.basis) + \ @@ -457,6 +563,11 @@ def nanoDFT(mol_str, opts): ) params = params.to_float32() + if opts.resume: + print("loading checkpoint") + params = pickle.load(open("checkpoints/%s_model.pickle"%opts.resume, "rb")) + print("done loading. ") + if opts.nn: #https://arxiv.org/pdf/1706.03762.pdf see 5.3 optimizer @@ -466,11 +577,13 @@ def nanoDFT(mol_str, opts): # still differs on # [ ] weight initialization - def custom_schedule(it, learning_rate=opts.lr, min_lr=opts.lr/10, warmup_iters=2000, lr_decay_iters=600000): + def custom_schedule(it, learning_rate=opts.lr, min_lr=opts.lr/10, warmup_iters=2000, lr_decay_iters=600000): # 600k/30 = 20k; so hit mi #return learning_rate * it / warmup_iters # to allow jax jit? # allow jax jit - '''if it < warmup_iters: return learning_rate * it / warmup_iters - if it > lr_decay_iters: return min_lr + '''if it < warmup_iters: return learning_rate * it / warmup_iters # linearly increase until hit warmup iters. + if it > lr_decay_iters: return min_lr # after decay (600k iterations) go to 10x lower + + # in between, decay learning rate using this function; this is from 2k steps to 600k steps decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) assert 0 <= decay_ratio <= 1 coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) @@ -484,9 +597,6 @@ def custom_schedule(it, learning_rate=opts.lr, min_lr=opts.lr/10, warmup_iters=2 cond3 = (it >= warmup_iters) * (it <= lr_decay_iters) * (min_lr + coeff * (learning_rate - min_lr)) return cond1 + cond2 + cond3 - - - adam = optax.chain( optax.clip_by_global_norm(1), optax.scale_by_adam(b1=0.9, b2=0.95, eps=1e-12), @@ -497,23 +607,29 @@ def custom_schedule(it, learning_rate=opts.lr, min_lr=opts.lr/10, warmup_iters=2 w = params + from torch.utils.data import DataLoader, Dataset class OnTheFlyQM9(Dataset): # prepares dft tensors with pyscf "on the fly". # dataloader is very keen on throwing segfaults (e.g. using jnp in dataloader throws segfaul). # problem: second epoch always gives segfault. # hacky fix; make __len__ = real_length*num_epochs and __getitem__ do idx%real_num_examples - def __init__(self, opts, nao=294, train=True, num_epochs=10**9): + def __init__(self, opts, nao=294, train=True, num_epochs=10**9, extrapolate=False): # only take molecules with use {CNOFH}, nao=nao and spin=0. df = pd.read_pickle("alchemy/processed_atom_9.pickle") # spin=0 and only CNOFH molecules if nao != -1: df = df[df["nao"]==nao] - df = df.sample(frac=1).reset_index(drop=True) + # df.sample is not deterministic; moved to pre-processing, so file is shuffled already. + # this shuffling is important, because it makes the last 10 samples iid (used for validation) + #df = df.sample(frac=1).reset_index(drop=True) # is this deterministic? if train: self.mol_strs = df["pyscf"].values[:-10] else: self.mol_strs = df["pyscf"].values[-10:] + #print(df["pyscf"].) # todo: print smile strings + self.num_epochs = num_epochs self.opts = opts self.validation = not train + self.extrapolate = extrapolate self.benzene = [ ["C", ( 0.0000, 0.0000, 0.0000)], @@ -539,7 +655,7 @@ def __init__(self, opts, nao=294, train=True, num_epochs=10**9): ] if opts.benzene: self.mol_strs = [self.benzene] - if opts.waters: self.mol_strs = [self.waters] + if opts.waters: self.mol_strs = [self.waters] if opts.alanine: self.mol_strs = mol_str if train: self.bs = opts.bs @@ -549,16 +665,26 @@ def __len__(self): return len(self.mol_strs)*self.num_epochs def __getitem__(self, idx): - return batched_state(self.mol_strs[idx%len(self.mol_strs)], self.opts, self.bs, wiggle_num=0, do_pyscf=self.validation, validation=False, mol_idx=idx) - + return batched_state(self.mol_strs[idx%len(self.mol_strs)], self.opts, self.bs, \ + wiggle_num=0, do_pyscf=self.validation or self.extrapolate, validation=False, \ + extrapolate=self.extrapolate, mol_idx=idx) + + val_qm9 = OnTheFlyQM9(opts, train=False) + ext_qm9 = OnTheFlyQM9(opts, extrapolate=True) + + # parallel dataloader bug; precompute here is not slow but causes dataloader later to die. + # run once to quickly precompute. + if opts.precompute: + val_state = val_qm9[0] + ext_state = ext_qm9[0] + exit() + qm9 = OnTheFlyQM9(opts, train=True) - if opts.workers != 0: - train_dataloader = DataLoader(qm9, batch_size=1, pin_memory=True, shuffle=False, drop_last=True, num_workers=opts.workers, prefetch_factor=2, collate_fn=lambda x: x[0]) - else: - train_dataloader = DataLoader(qm9, batch_size=1, pin_memory=True, shuffle=False, drop_last=True, num_workers=opts.workers, collate_fn=lambda x: x[0]) + if opts.workers != 0: train_dataloader = DataLoader(qm9, batch_size=1, pin_memory=True, shuffle=False, drop_last=True, num_workers=opts.workers, prefetch_factor=2, collate_fn=lambda x: x[0]) + else: train_dataloader = DataLoader(qm9, batch_size=1, pin_memory=True, shuffle=False, drop_last=True, num_workers=opts.workers, collate_fn=lambda x: x[0]) pbar = tqdm(train_dataloader) - val_qm9 = OnTheFlyQM9(opts, train=False) + else: states = [batched_state(mol_str[0], opts, opts.bs, do_pyscf=True)] + [batched_state(mol_str[i], opts, opts.bs, do_pyscf=False) for i in range(opts.states-1)] class DummyIterator: @@ -574,8 +700,15 @@ def __next__(self): return self.item vandg = jax.jit(jax.value_and_grad(dm_energy, has_aux=True), backend=opts.backend, static_argnames=("normal", 'nn')) valf = jax.jit(dm_energy, backend=opts.backend, static_argnames=("normal", 'nn')) adam_state = adam.init(w) + + if opts.resume: + print("loading adam state") + adam_state = pickle.load(open("checkpoints/%s_adam_state.pickle"%opts.resume, "rb")) + print("done") + w, adam_state = jax.device_put(w), jax.device_put(adam_state) + @partial(jax.jit, backend=opts.backend) def update(w, adam_state, accumulated_grad): accumulated_grad = jax.tree_map(lambda x: x / opts.bs, accumulated_grad) @@ -587,7 +720,7 @@ def update(w, adam_state, accumulated_grad): if not opts.nn: total_params = -1 wandb.log({'total_params': total_params, 'batch_size': opts.bs, 'lr': opts.lr }) - min_val, min_dm, mins, val_state, valid_str, step = 0, 0, np.ones(opts.bs)*1e6, None, "", 0 + min_val, min_dm, mins, valid_str, step, val_state, ext_state = 0, 0, np.ones(opts.bs)*1e6, "", 0, None, None t0, load_time, train_time, val_time, plot_time = time.time(), 0, 0, 0, 0 paddings = [] @@ -596,7 +729,7 @@ def update(w, adam_state, accumulated_grad): if iteration == 0: summary(state) state = jax.device_put(state) - # estimate max padding. + # Estimate max padding. if iteration < 100: paddings.append(state.pad_sizes.reshape(1, -1)) _paddings = np.concatenate(paddings, axis=0) @@ -604,19 +737,41 @@ def update(w, adam_state, accumulated_grad): dct = {} dct["iteraton"] = iteration + + states.append(state) + if len(states) > opts.mol_repeats: states.pop(0) + + load_time, t0 = time.time()-t0, time.time() + if opts.checkpoint != -1 and iteration % opts.checkpoint == 0: # and iteration > 0: + t0 = time.time() + try: + name = opts.name.replace("-", "_") + path_model = "checkpoints/%s_%i_model.pickle"%(name, iteration) + path_adam = "checkpoints/%s_%i_adam_state.pickle"%(name, iteration) + print("trying to checkpoint to %s and %s"%(path_model, path_adam)) + pickle.dump(jax.device_get(w), open(path_model, "wb")) + pickle.dump(jax.device_get(adam_state), open(path_adam, "wb")) + print("done!") + print("\t-resume \"%s\""%(path_model.replace("_model.pickle", ""))) + except: + print("fail!") + pass + print("tried saving model took %fs"%(time.time()-t0)) + save_time, t0 = time.time()-t0, time.time() + - states = states[-opts.grad_acc+1:] + [state] if len(states) < 50: print(len(states)) for j, state in enumerate(states): - print(".\t", end="", flush=True) + print(". ", end="", flush=True) if j == 0: _t0 =time.time() (val, (vals, E_xc, density_matrix, _W)), grad = vandg(w, state, opts.normal, opts.nn) print(",", end="", flush=True) if j == 0: time_step1 = time.time()-_t0 + # todo: have hyper parameter that accumulates gradient or takes step? w, adam_state = update(w, adam_state, grad) # todo: rename @@ -651,13 +806,22 @@ def update(w, adam_state, accumulated_grad): dct['img/dm%i'%i] = wandb.Image(np.expand_dims(density_matrix[i], axis=-1)) dct['img/W%i'%i] = wandb.Image(np.expand_dims(_W[i], axis=-1)) + step = adam_state[1].count + plot_time, t0 = time.time()-t0, time.time() - if opts.nn and (iteration < 1000 or iteration % 10 == 0): + + + # TODO: Plot molecules and val/ext angles. + if opts.nn and (iteration < 250 or iteration % 10 == 0): if val_state is None: val_state = jax.device_put(val_qm9[0]) _, (valid_vals, _, vdensity_matrix, vW) = valf(w, val_state, opts.normal, opts.nn) - lr = custom_schedule(iteration) - valid_str = "lr=%.3e"%lr + "val_error=" + "".join(["%.4f "%(valid_vals[i]*HARTREE_TO_EV-val_state.pyscf_E[i]) for i in range(0, 3)]) + " [eV]" + if ext_state is None: ext_state = jax.device_put(ext_qm9[0]) + _, (ext_vals, _, edensity_matrix, eW) = valf(w, ext_state, opts.normal, opts.nn) + + lr = custom_schedule(step) + valid_str = "lr=%.3e"%lr + "val=" + "".join(["%.4f "%(valid_vals[i]*HARTREE_TO_EV-val_state.pyscf_E[i]) for i in range(0, 3)]) + " [eV]" + valid_str += "ext=" + "".join(["%.4f "%(ext_vals[i]*HARTREE_TO_EV-ext_state.pyscf_E[i]) for i in range(0, 3)]) + " [eV]" if opts.wandb: for i in range(0, opts.val_bs): dct['valid_l%i'%i ] = np.abs(valid_vals[i]*HARTREE_TO_EV-val_state.pyscf_E[i]) @@ -665,10 +829,20 @@ def update(w, adam_state, accumulated_grad): dct['valid_pyscf%i'%i ] = np.abs(val_state.pyscf_E[i]) dct['img/val_dm%i'%i] = wandb.Image(np.expand_dims(vdensity_matrix[i], axis=-1)) dct['img/val_W%i'%i] = wandb.Image(np.expand_dims(vW[i], axis=-1)) - dct["scheduled_lr"] = custom_schedule(iteration) - if opts.wandb: wandb.log(dct) - valid_time, t0 = time.time()-t0, time.time() + dct['ext_l%i'%i ] = np.abs(ext_vals[i]*HARTREE_TO_EV-ext_state.pyscf_E[i]) + dct['ext_E%i'%i ] = np.abs(ext_vals[i]*HARTREE_TO_EV) + dct['ext_pyscf%i'%i ] = np.abs(ext_state.pyscf_E[i]) + dct['img/ext_dm%i'%i] = wandb.Image(np.expand_dims(edensity_matrix[i], axis=-1)) + dct['img/ext_W%i'%i] = wandb.Image(np.expand_dims(eW[i], axis=-1)) + + dct["scheduled_lr"] = lr + + + if opts.wandb: + dct["step"] = step + wandb.log(dct) + val_time, t0 = time.time()-t0, time.time() val, density_matrix = min_val, min_dm @@ -711,6 +885,8 @@ class IterationState: pos: np.array ao_types: np.array pad_sizes: np.array + precomputed_nonzero_indices: np.array + precomputed_indxs: np.array from pyscf.data.elements import charge as elements_proton from pyscf.dft import gen_grid, radi @@ -979,6 +1155,8 @@ def e(x): return np.expand_dims(x, axis=0) N=e(mol.nao_nr()), mask=e(mask), pad_sizes=e(pad_sizes), + precomputed_nonzero_indices=np.zeros((1,1)), + precomputed_indxs=np.zeros((1,1)), ) @@ -992,17 +1170,17 @@ def summary(state): for field_name, field_def in state.__dataclass_fields__.items(): field_value = getattr(state, field_name) try: - print("%20s %20s %20s"%(field_name,getattr(field_value, 'shape', None), getattr(field_value, "nbytes", None)/10**9)) + print("%35s %24s %20s"%(field_name,getattr(field_value, 'shape', None), getattr(field_value, "nbytes", None)/10**9)) total += getattr(field_value, "nbytes", None)/10**9 except: try: - print("%20s %20s %20s"%(field_name,getattr(field_value[0], 'shape', None), getattr(field_value[0], "nbytes", None)/10**9)) + print("%35s %25s %20s"%(field_name,getattr(field_value[0], 'shape', None), getattr(field_value[0], "nbytes", None)/10**9)) total += getattr(field_value, "nbytes", None)/10**9 except: print("BROKE FOR ", field_name) - print("%20s %20s %20s"%("-", "total", total)) + print("%35s %25s %20s"%("-", "total", total)) try: print(state.pyscf_E[:, -1]) except: @@ -1125,11 +1303,12 @@ def pyscf_reference(mol_str, opts): pyscf_hlgaps = [] lumo = mol.nelectron//2 homo = lumo - 1 + t0 = time.time() def callback(envs): pyscf_energies.append(envs["e_tot"]*HARTREE_TO_EV) hl_gap_hartree = np.abs(envs["mo_energy"][homo] - envs["mo_energy"][lumo]) * HARTREE_TO_EV pyscf_hlgaps.append(hl_gap_hartree) - print("PYSCF: ", pyscf_energies[-1] , "[eV]") + print("PYSCF: ", pyscf_energies[-1], "[eV]", time.time()-t0) mf.callback = callback mf.kernel() print("") @@ -1188,16 +1367,22 @@ def reference(mol_str, opts): parser.add_argument('-level', type=int, default=0) # GD options - parser.add_argument('-backend', type=str, default="cpu") - parser.add_argument('-lr', type=float, default=2.5e-4) - parser.add_argument('-steps', type=int, default=100000) - parser.add_argument('-bs', type=int, default=8) - parser.add_argument('-val_bs', type=int, default=8) - parser.add_argument('-grad_acc', type=int, default=16) + parser.add_argument('-backend', type=str, default="cpu") + parser.add_argument('-lr', type=float, default=2.5e-4) + parser.add_argument('-steps', type=int, default=100000) + parser.add_argument('-bs', type=int, default=8) + parser.add_argument('-val_bs', type=int, default=8) + parser.add_argument('-mol_repeats', type=int, default=16) # How many time to optimize wrt each molecule. + + # energy computation speedups + parser.add_argument('-foriloop', action="store_true") # whether to use jax.lax.foriloop for sparse_symmetric_eri (faster compile time but slower training. ) + parser.add_argument('-xc_f32', action="store_true") + parser.add_argument('-eri_f32', action="store_true") + parser.add_argument('-eri_bs', type=int, default=8) parser.add_argument('-normal', action="store_true") parser.add_argument('-wandb', action="store_true") - parser.add_argument('-prof', action="store_true") + parser.add_argument('-prof', action="store_true") parser.add_argument('-visualize', action="store_true") parser.add_argument('-skip', action="store_true", help="skip pyscf test case") @@ -1209,7 +1394,8 @@ def reference(mol_str, opts): parser.add_argument('-waters', action="store_true") parser.add_argument('-alanine', action="store_true") parser.add_argument('-states', type=int, default=1) - parser.add_argument('-workers', type=int, default=5) # set to 0 for faster pyscf precompute. + parser.add_argument('-workers', type=int, default=5) + parser.add_argument('-precompute', action="store_true") # precompute labels; only run once for data{set/augmentation}. # do noise schedule, start small slowly increase parser.add_argument('-wiggle_var', type=float, default=0.05, help="wiggle N(0, wiggle_var), bondlength=1.5/30") parser.add_argument('-eri_threshold', type=float, default=1e-10, help="loss function threshold only") @@ -1223,6 +1409,9 @@ def reference(mol_str, opts): parser.add_argument('-medium', action="store_true") parser.add_argument('-large', action="store_true") parser.add_argument('-xlarge', action="store_true") + + parser.add_argument("-checkpoint", default=-1, type=int, help="which iteration to save model (default -1 = no saving)") # checkpoint model + parser.add_argument("-resume", default="", help="path to checkpoint pickle file") # checkpoint model opts = parser.parse_args() if opts.tiny or opts.small or opts.base or opts.large or opts.xlarge: opts.nn = True @@ -1272,7 +1461,7 @@ def reference(mol_str, opts): ["H", ( 0.451, 0.165, -0.083)]]] elif opts.alanine: - mol_strs = [[ + mol_strs = [[ # 22 atoms (12 hydrogens) => 10 heavy atoms (i.e. larger than QM9). ["H", ( 2.000 , 1.000, -0.000)], ["C", ( 2.000 , 2.090, 0.000)], ["H", ( 1.486 , 2.454, 0.890)], diff --git a/pyscf_ipu/direct/transformer.py b/pyscf_ipu/direct/transformer.py index b2439da7..9cc2b9d0 100644 --- a/pyscf_ipu/direct/transformer.py +++ b/pyscf_ipu/direct/transformer.py @@ -111,26 +111,23 @@ def transformer(cfg, params, x: jnp.ndarray, position: jnp.ndarray, H_core: jnp. # inspired by 3d point cloud transformers; # nspired by andrew: use trigonometric functions as feature transformations position = jnp.concatenate([position, jnp.cos(position), jnp.sin(position), jnp.tanh(position)], axis=1) #(N,3) -> (N,12) - positions = linear(params.project_positions, position) # L x Dm*4 + positions = linear(params.project_positions, position) # L x Dm # Add (learned) positional encodings - #x = jnp.concatenate([embeddings[:, :-3], position], -1) - x = embeddings + positions - #x = embeddings - L, dm = x.shape - - def block(x, layer_num, layers): - # Layer-normalize embeddings - x = vmap(standardize)(x) - t1 = elementwise_linear(layer.norm_self_attn, x) # L x Dm - - L, Dm = t1.shape - nheads = cfg.heads - qkv = linear(layer.kqv, t1)#.reshape(L, Dm, 3) - #q, k, v = [qkv[:, :, i].reshape(nheads, L, Dm//nheads) for i in range(3)] - q = jnp.transpose(qkv[:, 0*Dm:1*Dm].reshape(L, nheads, Dm//nheads), (1, 0, 2)) - k = jnp.transpose(qkv[:, 1*Dm:2*Dm].reshape(L, nheads, Dm//nheads), (1, 0, 2)) - v = jnp.transpose(qkv[:, 2*Dm:3*Dm].reshape(L, nheads, Dm//nheads), (1, 0, 2)) + x = embeddings + positions # L x Dm + L, Dm = x.shape + nheads = cfg.heads + + def block(x, layer_num, layer): + # Layer-normalize + t1 = vmap(standardize)(x) # L x Dm + t1 = elementwise_linear(layer.norm_self_attn, t1) # L x Dm + + qkv = linear(layer.kqv, t1) + q,k,v = jnp.split(qkv, 3, axis=1) + q = jnp.transpose(q.reshape(L, nheads, Dm//nheads), (1, 0, 2)) + k = jnp.transpose(k.reshape(L, nheads, Dm//nheads), (1, 0, 2)) + v = jnp.transpose(v.reshape(L, nheads, Dm//nheads), (1, 0, 2)) score = (q @ jnp.transpose(k, (0, 2, 1))) / math.sqrt(Dm) # do like graphformer and append position here? @@ -141,26 +138,28 @@ def block(x, layer_num, layers): attn = jax.nn.softmax(score , axis=1) x = x + (attn @ v).reshape(L, Dm) - # Layer-normalize embeddings - t2 = vmap(standardize)(embeddings) - t2 = elementwise_linear(layer.norm_ff, x) # L x Dm + # Layer-normalize + t2 = vmap(standardize)(x) + t2 = elementwise_linear(layer.norm_ff, t2) # L x Dm # Feedforward fully connected t2 = linear(layer.ffn1, t2) # L x Dm*4 t2 = jax.nn.gelu(t2) t2 = linear(layer.ffn2, t2) # L x Dm - # Add this layer's contribution into embeddings + # Residual connection x = x + t2 return x, score # Apply the transformer layers # todo: cut jit time by making this jax.lax.foriloop for layer_num, layer in enumerate(params.layers): - if layer_num % 2 == 0: x, score = jax.checkpoint(block)(x, layer_num, layer) - else: x, score = block(x, layer_num, layer) + x, score = jax.checkpoint(block)(x, layer_num, layer) - return score[0] # take first head + # todo: if this isn't symmetric eigh gives imaginary eigenvalues? (bad) + M = score[0] # take first attention head + #M = (M + M.T)/2 # make symmetric! + return M import types import json @@ -222,3 +221,120 @@ def convert_to_float32(x): return ParamsDict(**new_dict) self.__dict__ = jax.tree_map(convert_to_float32, self.__dict__) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + # DFT options + parser.add_argument('-basis', type=str, default="sto3g") + parser.add_argument('-level', type=int, default=0) + + # GD options + parser.add_argument('-backend', type=str, default="cpu") + parser.add_argument('-lr', type=float, default=2.5e-4) + parser.add_argument('-steps', type=int, default=100000) + parser.add_argument('-bs', type=int, default=8) + parser.add_argument('-val_bs', type=int, default=8) + parser.add_argument('-mol_repeats', type=int, default=16) # How many time to optimize wrt each molecule. + + # energy computation speedups + parser.add_argument('-foriloop', action="store_true") # whether to use jax.lax.foriloop for sparse_symmetric_eri (faster compile time but slower training. ) + parser.add_argument('-xc_f32', action="store_true") + parser.add_argument('-eri_f32', action="store_true") + parser.add_argument('-eri_bs', type=int, default=8) + + parser.add_argument('-normal', action="store_true") + parser.add_argument('-wandb', action="store_true") + parser.add_argument('-prof', action="store_true") + parser.add_argument('-visualize', action="store_true") + parser.add_argument('-skip', action="store_true", help="skip pyscf test case") + + # dataset + parser.add_argument('-qm9', action="store_true") + parser.add_argument('-benzene', action="store_true") + parser.add_argument('-hydrogens', action="store_true") + parser.add_argument('-water', action="store_true") + parser.add_argument('-waters', action="store_true") + parser.add_argument('-alanine', action="store_true") + parser.add_argument('-states', type=int, default=1) + parser.add_argument('-workers', type=int, default=5) + parser.add_argument('-precompute', action="store_true") # precompute labels; only run once for data{set/augmentation}. + # do noise schedule, start small slowly increase + parser.add_argument('-wiggle_var', type=float, default=0.05, help="wiggle N(0, wiggle_var), bondlength=1.5/30") + parser.add_argument('-eri_threshold', type=float, default=1e-10, help="loss function threshold only") + parser.add_argument('-rotate_deg', type=float, default=90, help="how many degrees to rotate") + + # models + parser.add_argument('-nn', action="store_true", help="train nn, defaults to GD") + parser.add_argument('-tiny', action="store_true") + parser.add_argument('-small', action="store_true") + parser.add_argument('-base', action="store_true") + parser.add_argument('-medium', action="store_true") + parser.add_argument('-large', action="store_true") + parser.add_argument('-xlarge', action="store_true") + opts = parser.parse_args() + + # initialize model + # transformer tiny 5M + d_model= 192 + n_heads = 6 + n_layers = 12 + + from train import nao + rnd_key = jax.random.PRNGKey(42) + n_vocab = nao("C", opts.basis) + nao("N", opts.basis) + \ + nao("O", opts.basis) + nao("F", opts.basis) + \ + nao("H", opts.basis) + + rnd_key, cfg, params, total_params = transformer_init( + rnd_key, + n_vocab, + d_model =d_model, + n_layers=n_layers, + n_heads =n_heads, + d_ff =d_model*4, + ) + + + # compute dummy output + from train import batched_state, summary + opts.alanine = True + alanine = [ ["H", ( 2.000 , 1.000, -0.000)], ["C", ( 2.000 , 2.090, 0.000)], ["H", ( 1.486 , 2.454, 0.890)], ["H", ( 1.486 , 2.454, -0.890)], + ["C", ( 3.427 , 2.641, -0.000)], ["O", ( 4.391 , 1.877, -0.000)], ["N", ( 3.555 , 3.970, -0.000)], ["H", ( 2.733 , 4.556, -0.000)], + ["C", ( 4.853 , 4.614, -0.000)], ["H", ( 5.408 , 4.316, 0.890)], ["C", ( 5.661 , 4.221, -1.232)], ["H", ( 5.123 , 4.521, -2.131)], + ["H", ( 6.630 , 4.719, -1.206)], ["H", ( 5.809 , 3.141, -1.241)], ["C", ( 4.713 , 6.129, 0.000)], ["O", ( 3.601 , 6.653, 0.000)], + ["N", ( 5.846 , 6.835, 0.000)], ["H", ( 6.737 , 6.359, -0.000)], ["C", ( 5.846 , 8.284, 0.000)], ["H", ( 4.819 , 8.648, 0.000)], + ["H", ( 6.360 , 8.648, 0.890)], ["H", ( 6.360 , 8.648, -0.890)], ] + state = batched_state(alanine, opts, opts.bs, \ + wiggle_num=0, do_pyscf=False, validation=False, \ + extrapolate=False, mol_idx=0) + summary(state) + + output = jax.jit(jax.vmap(transformer, in_axes=(None, None, 0, 0, 0), out_axes=(0)), + static_argnums=(0,), + backend="cpu")(cfg, \ + params, state.ao_types, state.pos.astype(jnp.float32), state.H_core.astype(jnp.float32)) + + + print(np.sum(output)) # 162.58726108305348 + + + # store model + import pickle + pickle.dump(params, open("checkpoints/example.pickle", "wb")) + + # reload model + new_params = pickle.load(open("checkpoints/example.pickle", "rb")) + + # check that output remains the same + new_output = jax.jit(jax.vmap(transformer, in_axes=(None, None, 0, 0, 0), out_axes=(0)), + static_argnums=(0,), + backend="cpu")(cfg, \ + new_params, state.ao_types, state.pos.astype(jnp.float32), state.H_core.astype(jnp.float32)) + + assert np.allclose(output, new_output) + print("TEST CASE PASSED!") + + + \ No newline at end of file From 9a016c6a65bb7d14e870a2e06c46be448fb8d333 Mon Sep 17 00:00:00 2001 From: Alexander Mathiasen Date: Sat, 6 Jan 2024 09:29:48 +0000 Subject: [PATCH 19/22] work. --- pyscf_ipu/direct/train.py | 201 ++++++++++++++++++++++---------- pyscf_ipu/direct/transformer.py | 71 ++++++----- 2 files changed, 186 insertions(+), 86 deletions(-) diff --git a/pyscf_ipu/direct/train.py b/pyscf_ipu/direct/train.py index 30773376..c94dc837 100644 --- a/pyscf_ipu/direct/train.py +++ b/pyscf_ipu/direct/train.py @@ -15,6 +15,8 @@ import math from functools import partial import pickle +import random +random.seed(42) cfg, HARTREE_TO_EV, EPSILON_B3LYP, HYB_B3LYP = None, 27.2114079527, 1e-20, 0.2 @@ -25,8 +27,8 @@ def T(x): return jnp.transpose(x, (0,2,1)) # Only need to recompute: L_inv, grid_AO, grid_weights, H_core, ERI and E_nuc. def dm_energy(W: BxNxK, state, normal, nn): if nn: - W = jax.vmap(transformer, in_axes=(None, None, 0, 0, 0), out_axes=(0))(cfg, \ - W, state.ao_types, state.pos.astype(jnp.float32), state.H_core.astype(jnp.float32)) + W = jax.vmap(transformer, in_axes=(None, None, 0, 0, 0, 0), out_axes=(0))(cfg, \ + W, state.ao_types, state.pos.astype(jnp.float32), state.H_core.astype(jnp.float32), state.L_inv.astype(jnp.float32)) W = W.astype(jnp.float64) # we can interpret state.H_core + W as hamiltonian, and predict hlgap from these! @@ -54,6 +56,7 @@ def exchange_correlation(density_matrix: BxNxN, state, normal, xc_f32): if False: main: BxGsizexN = state.main_grid_AO @ density_matrix # (1, gsize, N) @ (B, N, N) = O(B gsize N^2) FLOPs and O(gsize*N + N^2 +B * gsize * N) FLIOs correction: BxGsizexN = jax.vmap(sparse_mult, in_axes=(0,0,None, None))(state.sparse_diffs_grid_AO, density_matrix, state, gsize) + # todo: remove state.grid_AO w/ sparsity tricks => reduce memory 10x. rho_a = jnp.einsum("bpij,bqij->bpi", state.grid_AO, main.reshape(B,1,gsize,N)) rho_b = jnp.einsum("bpij,bqij->bpi", state.grid_AO, correction.reshape(B,1,gsize,N)) rho = rho_a - rho_b @@ -122,6 +125,8 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, pad_sparse_diff_grid=200000, mol_idx=42, ): + + do_print = False # pad molecule if using nn. if not opts.nn: pad_electrons, pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = \ @@ -145,10 +150,23 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, pad_nonzero_distinct_ERI=400000 pad_sparse_diff_grid=400000''' #padding_estimate = [37426, 149710, 17010, 140122, 138369] - padding_estimate = [48330, 163222, 17034, 159361, 139505] - padding_estimate = [int(a*1.1) for a in padding_estimate] + + # idea: + # train w/o atom pertubation until convergence, then 1, then 2, then 3. + + padding_estimate = [48330, 163222, 17034, 159361, 139505] + + if opts.nperturb == 2 or opts.nperturb == 1: padding_estimate = [int(a*2.1) for a in padding_estimate] + else: padding_estimate = [int(a*1.1) for a in padding_estimate] + + pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = padding_estimate + if opts.basis == "def2-svp": + # disable padding temporarily + max_pad_electrons, max_pad_diff_ERIs, max_pad_distinct_ERIs, max_pad_grid_AO, max_pad_nonzero_distinct_ERI, max_pad_sparse_diff_grid = \ + -1, -1, -1, -1, -1, -1 + if opts.alanine: # todo: (adam) the ERI padding may change when rotating molecule more! pad_electrons = 70 @@ -178,9 +196,10 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, water2_xyz = np.array([mol_str[i][1] for i in range(3,6)]) if opts.qm9: - atoms = np.array([mol_str[i][1] for i in range(0,3)]) - # pick random atom to permute (of the first 9 heavy ones) - atom_num = int(np.random.uniform(0, 8)) + # pick random atom to perturb (of the first 9 heavy ones) + atom_num1, atom_num2, atom_num3 = random.sample(range(9), 3) + atoms = np.array([mol_str[i][1] for i in range(0,10)]) + atom_type = [mol_str[i][0] for i in range(0,10)] if opts.alanine: # train on [-180, 180], validate [-180, 180] extrapolate [-360, 360]\[180, -180] @@ -192,6 +211,28 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, angles = [] + # Combinatorics of atom substitutions (n_electrons has to be even). + # P := single atom modification, n_electrons remain even (C-> O, O->C, N->F, F->N) + # B1, B2 := single atom modification, n_electrons becomes odd (C->{N,F}, O->{N,F}, N->{O,C}, F->{O,C}) + I = {"C":"C", "O":"O", "F":"F", "N":"N"} + P = {"C": "O", "O":"C", "F":"N", "N":"F"} + B1 = {"C":"N", "O":"N", "F":"O", "N":"O"} + B2 = {"C":"F", "O":"F", "F":"C", "N":"C"} + + if opts.nperturb == 3: + allowed_pertubations=[ + (I,I,I), (I,I,P), (I,P,I), (I,P,P), (P,I,I), (P,I,P), (P,P,I), (P,P,P), + (I,B1,B2), (B1,I,B2), (B1,B2,I), (I,B2,B1), (B2,I,B1), (B2,B1,I), (P,B1,B2), (B1,P,B2), (B1,B2,P), (P,B2,B1), (B2,P,B1), (B2,B1,P), + (I,B1,B1), (B1,I,B1), (B1,B1,I), (I,B1,B1), (B1,I,B1), (B1,B1,I), (P,B1,B1), (B1,P,B1), (B1,B1,P), (P,B1,B1), (B1,P,B1), (B1,B1,P), + (I,B2,B2), (B2,I,B2), (B2,B2,I), (I,B2,B2), (B2,I,B2), (B2,B2,I), (P,B2,B2), (B2,P,B2), (B2,B2,P), (P,B2,B2), (B2,P,B2), (B2,B2,P), + ] + if opts.nperturb == 2: + allowed_pertubations=[ (I, I, I), (I, I, P), (I, P, I), (I, P, P), (I, B1, B1), (I, B1, B2), (I, B2, B1), (I, B2, B2), ] + if opts.nperturb == 1: + allowed_pertubations=[ (I, I, I), (I, I, P), ]*10 + + if opts.nperturb: random.shuffle(allowed_pertubations) + states = [] for iteration in range(bs): @@ -253,50 +294,43 @@ def get_atom_positions(mol): wandb.log({"%s_mol_%i"%({True: "valid", False: "train"}[validation], iteration): create_rdkit_mol(str, pos) })''' elif opts.qm9: - # todo: find dihedral to rotate over similar to alanine dipeptide. - # broken; rotate first three atoms around their center of mass - # this breaks molecule; should use dihedral angle as done with the dipeptide. - #rotation_matrix = np.linalg.qr(np.random.normal(size=(3,3)))[0] - #center = atoms.mean(axis=0) - #rotated_atoms = np.dot(atoms - center, rotation_matrix) + center + if iteration == 0 and (validation or extrapolate): pass + else: + s = mol_str[atom_num1][0]+ mol_str[atom_num2][0]+ mol_str[atom_num3][0] - # for extrapolation, do even more. + # for small stuff, wiggle a single atom a bit. + if opts.nperturb <= 1: mol_str[atom_num3][1] = tuple(atoms[atom_num3] + np.random.normal(0, opts.wiggle_var, (3))) - if iteration == 0 and (validation or extrapolate): - pass - else: - #mol_str[0][1] = tuple(atoms[0] + np.random.normal(0, opts.wiggle_var, (3))) - mol_str[atom_num][1] = tuple(atoms[atom_num] + np.random.normal(0, opts.wiggle_var, (3))) - #mol_str[1][1] = tuple(atoms[1] + np.random.normal(0, opts.wiggle_var, (3))) - #mol_str[2][1] = tuple(atoms[2] + np.random.normal(0, opts.wiggle_var, (3))) + # assuming all atoms are non-hydrogen; if wrong don't do any pertubations. + if opts.nperturb > 0 and mol_str[atom_num1][0] != "H" and mol_str[atom_num2][0] != "H" and mol_str[atom_num3][0] != "H": + A,B,C = allowed_pertubations[iteration] + mol_str[atom_num1][0] = A[atom_type[atom_num1]] + mol_str[atom_num2][0] = B[atom_type[atom_num2]] + mol_str[atom_num3][0] = C[atom_type[atom_num3]] - '''if opts.wandb and iteration == 0: - from plot import create_rdkit_mol - import wandb - str = [mol_str[j][0] for j in range(len(mol_str))] - pos = np.concatenate([np.array(mol_str[j][1]).reshape(1, 3) for j in range(len(mol_str))]) - wandb.log({"%s_mol_%i"%({True: "valid", False: "train"}[validation], iteration): create_rdkit_mol(str, pos) })''' + s += "->" + mol_str[atom_num1][0]+ mol_str[atom_num2][0]+ mol_str[atom_num3][0] + if do_print: print(iteration, s) + + # -nperturb 2 is a lot slower; changing self.prune=False roughly doubles AO matrices. if iteration == 0: state = init_dft(mol_str, opts, do_pyscf=do_pyscf, pad_electrons=pad_electrons) c, w = state.grid_coords, state.grid_weights elif iteration <= 1 or not opts.prof: # when profiling create fake molecule to skip waiting state = init_dft(mol_str, opts, c, w, do_pyscf=do_pyscf and iteration < 3, state=state, pad_electrons=pad_electrons) + states.append(state) - # If we add energy here we get plot basically! - # todo: save and store in training loop, then we can match with energy - # can't get to work in wandb, but can just use download api and the plot. - '''if opts.alanine and opts.wandb: - for phi, psi in angles: - if not validation: - wandb.log({"phi_train": phi , "psi_train": psi}) - else: - wandb.log({"phi_valid": phi, "psi_valid": psi})''' + if do_print: print("cat states") + state = cats(states) N = state.N[0] + if do_print: print("get sparsity") + + #_nonzero = None + # Compute ERI sparsity. nonzero = [] for e, i in zip(state.nonzero_distinct_ERI, state.nonzero_indices): @@ -305,19 +339,26 @@ def get_atom_positions(mol): e[indxs] = 0 nonzero.append(np.nonzero(e)[0]) + #_nonzero = indxs if _nonzero is None else np.logical_or(_nonzero, indxs) + + # Merge nonzero indices and prepare (ij, kl). # rep is the number of repetitions we include in the sparse representation. - #nonzero_indices = np.union1d(nonzero[0], nonzero[1]) + if do_print: print("union") # bottleneck for def2-svp. todo: fix above logical_or trick to do faster. union = nonzero[0] - for i in range(1, len(nonzero)): + for i in range(1, len(nonzero)): # this takes 12s/it for def2-svp. union = np.union1d(union, nonzero[i]) nonzero_indices = union + #exit() + + from sparse_symmetric_ERI import get_i_j, num_repetitions_fast ij, kl = get_i_j(nonzero_indices) rep = num_repetitions_fast(ij, kl) batches = opts.eri_bs es = [] + if do_print: print("pad") for e,i in zip(state.nonzero_distinct_ERI, state.nonzero_indices): nonzero_distinct_ERI = e[nonzero_indices] / rep remainder = nonzero_indices.shape[0] % (batches) @@ -328,6 +369,7 @@ def get_atom_positions(mol): state.nonzero_distinct_ERI = np.concatenate([np.expand_dims(a, axis=0) for a in es]) + if do_print: print("pad ijkl indices") i, j = get_i_j(ij.reshape(-1)) k, l = get_i_j(kl.reshape(-1)) @@ -342,6 +384,7 @@ def get_atom_positions(mol): # batching (w/ same sparsity pattern across batch) allows precomputing all {ss,dm}_indices instead of computing in sparse_sym_eri every iteration. # function below does this. # todo: consider removing, didn't get expecting 3x (only 5%; not sure if additional memory/complication justifies). + if do_print: print("precompute indices. ") from sparse_symmetric_ERI import precompute_indices if opts.normal: diff_state = None @@ -507,6 +550,8 @@ def nanoDFT(mol_str, opts): run = wandb.init(project='ndft') opts.name = run.name + wandb.log(vars(opts)) + else: opts.name = "%i"%time.time() @@ -578,20 +623,8 @@ def nanoDFT(mol_str, opts): # [ ] weight initialization def custom_schedule(it, learning_rate=opts.lr, min_lr=opts.lr/10, warmup_iters=2000, lr_decay_iters=600000): # 600k/30 = 20k; so hit mi - #return learning_rate * it / warmup_iters # to allow jax jit? - # allow jax jit - '''if it < warmup_iters: return learning_rate * it / warmup_iters # linearly increase until hit warmup iters. - if it > lr_decay_iters: return min_lr # after decay (600k iterations) go to 10x lower - - # in between, decay learning rate using this function; this is from 2k steps to 600k steps - decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) - assert 0 <= decay_ratio <= 1 - coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) - return min_lr + coeff * (learning_rate - min_lr)''' - #if it < warmup_iters: return learning_rate * it / warmup_iters cond1 = (it < warmup_iters) * learning_rate * it / warmup_iters cond2 = (it > lr_decay_iters) * min_lr - decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) coeff = 0.5 * (1.0 + jnp.cos(jnp.pi * decay_ratio)) cond3 = (it >= warmup_iters) * (it <= lr_decay_iters) * (min_lr + coeff * (learning_rate - min_lr)) @@ -600,7 +633,7 @@ def custom_schedule(it, learning_rate=opts.lr, min_lr=opts.lr/10, warmup_iters=2 adam = optax.chain( optax.clip_by_global_norm(1), optax.scale_by_adam(b1=0.9, b2=0.95, eps=1e-12), - optax.add_decayed_weights(0.1),#, configure_decay_mask(params)), + optax.add_decayed_weights(0.1), optax.scale_by_schedule(custom_schedule), optax.scale(-1), ) @@ -615,9 +648,18 @@ class OnTheFlyQM9(Dataset): # problem: second epoch always gives segfault. # hacky fix; make __len__ = real_length*num_epochs and __getitem__ do idx%real_num_examples def __init__(self, opts, nao=294, train=True, num_epochs=10**9, extrapolate=False): - # only take molecules with use {CNOFH}, nao=nao and spin=0. - df = pd.read_pickle("alchemy/processed_atom_9.pickle") # spin=0 and only CNOFH molecules - if nao != -1: df = df[df["nao"]==nao] + + if opts.qh9: + df = pd.read_pickle("qh9/qh9stable_processed_shuffled.pickle") + if nao != -1: df = df[df["N_sto3g"]==55] + print(df.shape) + else: + # only take molecules with use {CNOFH}, nao=nao and spin=0. + #qm9 + df = pd.read_pickle("alchemy/processed_atom_9.pickle") # spin=0 and only CNOFH molecules + if nao != -1: df = df[df["nao"]==nao] + + # df.sample is not deterministic; moved to pre-processing, so file is shuffled already. # this shuffling is important, because it makes the last 10 samples iid (used for validation) #df = df.sample(frac=1).reset_index(drop=True) # is this deterministic? @@ -684,6 +726,17 @@ def __getitem__(self, idx): else: train_dataloader = DataLoader(qm9, batch_size=1, pin_memory=True, shuffle=False, drop_last=True, num_workers=opts.workers, collate_fn=lambda x: x[0]) pbar = tqdm(train_dataloader) + if opts.test_dataloader: + + t0 = time.time() + for iteration, state in enumerate(pbar): + if iteration == 0: summary(state) + print(time.time()-t0) + t0 = time.time() + print(state.pad_sizes.reshape(1, -1)) + + exit() + else: states = [batched_state(mol_str[0], opts, opts.bs, do_pyscf=True)] + [batched_state(mol_str[i], opts, opts.bs, do_pyscf=False) for i in range(opts.states-1)] @@ -711,7 +764,8 @@ def __next__(self): return self.item @partial(jax.jit, backend=opts.backend) def update(w, adam_state, accumulated_grad): - accumulated_grad = jax.tree_map(lambda x: x / opts.bs, accumulated_grad) + if opts.grad_acc: accumulated_grad = jax.tree_map(lambda x: x / (opts.bs * opts.mol_repeats), accumulated_grad) + else: accumulated_grad = jax.tree_map(lambda x: x / opts.bs, accumulated_grad) updates, adam_state = adam.update(accumulated_grad, adam_state, w) w = optax.apply_updates(w, updates) return w, adam_state @@ -722,9 +776,11 @@ def update(w, adam_state, accumulated_grad): min_val, min_dm, mins, valid_str, step, val_state, ext_state = 0, 0, np.ones(opts.bs)*1e6, "", 0, None, None t0, load_time, train_time, val_time, plot_time = time.time(), 0, 0, 0, 0 + accumulated_grad = None paddings = [] states = [] + for iteration, state in enumerate(pbar): if iteration == 0: summary(state) state = jax.device_put(state) @@ -741,6 +797,7 @@ def update(w, adam_state, accumulated_grad): states.append(state) if len(states) > opts.mol_repeats: states.pop(0) + if opts.shuffle: random.shuffle(states) # load_time, t0 = time.time()-t0, time.time() if opts.checkpoint != -1 and iteration % opts.checkpoint == 0: # and iteration > 0: @@ -761,7 +818,6 @@ def update(w, adam_state, accumulated_grad): save_time, t0 = time.time()-t0, time.time() - if len(states) < 50: print(len(states)) for j, state in enumerate(states): @@ -771,8 +827,15 @@ def update(w, adam_state, accumulated_grad): print(",", end="", flush=True) if j == 0: time_step1 = time.time()-_t0 - # todo: have hyper parameter that accumulates gradient or takes step? - w, adam_state = update(w, adam_state, grad) + if opts.grad_acc == 0 or len(states) < opts.mol_repeats: + w, adam_state = update(w, adam_state, grad) + else: + accumulated_grad = grad if accumulated_grad is None else jax.tree_map(lambda x, y: x + y, accumulated_grad, grad) + if j % opts.grad_acc == 0 and j > 0: # we assume opts.grad_acc divides opts.mol_repeats; prev was basically grad_acc=0 or grad_acc=mol_repeats, can now do hybrid. + w, adam_state = update(w, adam_state, grad) + accumulated_grad = None + print("#", end="", flush=True) + # todo: rename global_batch_size = len(states)*opts.bs @@ -983,7 +1046,6 @@ def get_partition( coords_all = [] weights_all = [] - # [ ] consider another grid? for ia in range(mol.natm): coords, vol = atom_grids_tab[mol.atom_symbol(ia)] coords = coords + atom_coords[ia] # [ngrid, 3] @@ -1008,7 +1070,9 @@ def build(self, atom_coords, state=None) : mol = self.mol atom_grids_tab = self.gen_atomic_grids( - mol, self.atom_grid, self.radi_method, self.level, self.prune + mol, self.atom_grid, self.radi_method, self.level, + self.prune, + #False, # WARNING: disabling self.prune; this makes sizes of C,N,O,F all a bit larger, but the same ; allow atom substitution ) coords, weights = get_partition( @@ -1038,6 +1102,7 @@ def grids_from_pyscf_mol( def init_dft(mol_str, opts, _coords=None, _weights=None, first=False, do_pyscf=True, state=None, pad_electrons=-1): + do_print = False #t0 = time.time() mol = build_mol(mol_str, opts.basis) if do_pyscf: pyscf_E, pyscf_hlgap, pycsf_forces = reference(mol_str, opts) @@ -1048,6 +1113,8 @@ def init_dft(mol_str, opts, _coords=None, _weights=None, first=False, do_pyscf=T E_nuc = mol.energy_nuc() # float = 202.4065 [Hartree] for C6H6. TODO(): Port to jax. from pyscf import dft + if do_print: print("grid", end="", flush=True) + #grids = pyscf.dft.gen_grid.Grids(mol) grids = DifferentiableGrids(mol) grids.level = opts.level @@ -1059,6 +1126,8 @@ def init_dft(mol_str, opts, _coords=None, _weights=None, first=False, do_pyscf=T coord_str = 'GTOval_cart_deriv1' if mol.cart else 'GTOval_sph_deriv1' grid_AO = mol.eval_gto(coord_str, grids.coords, 4) # (4, grid_size, N) = (4, 45624, 9) for C6H6. + if do_print: print("int1e", end="", flush=True) + # TODO(): Add integral math formulas for kinetic/nuclear/O/ERI. kinetic = mol.intor_symmetric('int1e_kin') # (N,N) nuclear = mol.intor_symmetric('int1e_nuc') # (N,N) @@ -1083,10 +1152,14 @@ def init_dft(mol_str, opts, _coords=None, _weights=None, first=False, do_pyscf=T eri_threshold = 0 batches = 1 nipu = 1 + + # todo: rewrite int2e_sph to only recompute changing atomic orbitals (will be N times faster). + if do_print: print("int2e",end ="", flush=True) nonzero_distinct_ERI = mol.intor("int2e_sph", aosym="s8") #ERI = [nonzero_distinct_ERI, nonzero_indices] #ERI = ERI ERI = np.zeros(1) + if do_print: print(nonzero_distinct_ERI.shape, nonzero_distinct_ERI.nbytes/10**9) #ERI = mol.intor("int2e_sph") def e(x): return np.expand_dims(x, axis=0) @@ -1373,6 +1446,8 @@ def reference(mol_str, opts): parser.add_argument('-bs', type=int, default=8) parser.add_argument('-val_bs', type=int, default=8) parser.add_argument('-mol_repeats', type=int, default=16) # How many time to optimize wrt each molecule. + parser.add_argument('-grad_acc', type=int, default=0) # integer, deciding how many steps to accumulate. + parser.add_argument('-shuffle', action="store_true") # whether to to shuffle the window of states each step. # energy computation speedups parser.add_argument('-foriloop', action="store_true") # whether to use jax.lax.foriloop for sparse_symmetric_eri (faster compile time but slower training. ) @@ -1387,7 +1462,9 @@ def reference(mol_str, opts): parser.add_argument('-skip', action="store_true", help="skip pyscf test case") # dataset + parser.add_argument('-nperturb', type=int, default=1, help="how many atoms to perturb (supports 1,2,3)") parser.add_argument('-qm9', action="store_true") + parser.add_argument('-qh9', action="store_true") parser.add_argument('-benzene', action="store_true") parser.add_argument('-hydrogens', action="store_true") parser.add_argument('-water', action="store_true") @@ -1400,6 +1477,7 @@ def reference(mol_str, opts): parser.add_argument('-wiggle_var', type=float, default=0.05, help="wiggle N(0, wiggle_var), bondlength=1.5/30") parser.add_argument('-eri_threshold', type=float, default=1e-10, help="loss function threshold only") parser.add_argument('-rotate_deg', type=float, default=90, help="how many degrees to rotate") + parser.add_argument('-test_dataloader', action="store_true", help="no training, just test/loop through dataloader. ") # models parser.add_argument('-nn', action="store_true", help="train nn, defaults to GD") @@ -1415,9 +1493,12 @@ def reference(mol_str, opts): opts = parser.parse_args() if opts.tiny or opts.small or opts.base or opts.large or opts.xlarge: opts.nn = True + assert opts.mol_repeats % opts.grad_acc == 0, "mol_repeats needs to be a multiple of grad_acc (gradient accumulation)." + args_dict = vars(opts) print(args_dict) + if opts.qm9: df = pd.read_pickle("alchemy/atom_9.pickle") df = df[df["spin"] == 0] # only consider spin=0 diff --git a/pyscf_ipu/direct/transformer.py b/pyscf_ipu/direct/transformer.py index 9cc2b9d0..bc17affc 100644 --- a/pyscf_ipu/direct/transformer.py +++ b/pyscf_ipu/direct/transformer.py @@ -58,7 +58,7 @@ def transformer_init( total_params += np.prod(params.embeddings.shape) print("%26s %26s %26s"%("params.embeddings",params.embeddings.shape, np.prod(params.embeddings.shape))) - rng, params.project_positions, shape = linear_init_uniform(rng, 12, d_model) + rng, params.project_positions, shape = linear_init_uniform(rng, 123, d_model) total_params += np.prod(shape) print("%26s %26s %26s"%("params.project_positions",shape, np.prod(shape))) @@ -95,23 +95,32 @@ def transformer_init( @partial(jax.jit, static_argnums=0) -def transformer(cfg, params, x: jnp.ndarray, position: jnp.ndarray, H_core: jnp.ndarray): +def transformer(cfg, params, x: jnp.ndarray, position: jnp.ndarray, H_core: jnp.ndarray, L_inv): """ cfg: Config, from transformer_init, holds hyperparameters params: Current transformer parameters, initialized in init x: 1D array of L integers, representing the input sequence output: L x n_vocab logits """ - L, = x.shape # x is just 1D. Vmap/pmap will handle batching - embeddings = cfg.lambda_e * params.embeddings[x, :] # L x Dm - - all_pairs = jnp.linalg.norm(position.reshape(1, -1, 3) - position.reshape(-1, 1, 3), axis=-1) - - # inspired by 3d point cloud transformers; - # nspired by andrew: use trigonometric functions as feature transformations - position = jnp.concatenate([position, jnp.cos(position), jnp.sin(position), jnp.tanh(position)], axis=1) #(N,3) -> (N,12) + L, Dm = embeddings.shape + + # Roughly get f( {R@ri+t}_i ) = f( {r_i}_i ) + position = position - jnp.mean(position, axis=0).reshape(1, 3) # makes jnp.mean(position, axis=0) = [0,0,0] + cov = jnp.cov(position.T) + eigvects = jnp.linalg.eigh(cov)[1] + position = position @ eigvects # makes jnp.cov(positions.T)=jnp.eye(3) + + # Mix of sin/cos and 3d point cloud transformers. + #position = jnp.concatenate([position, jnp.cos(position), jnp.sin(position), jnp.tanh(position)], axis=1) #(N,3) -> (N,12) + position = jnp.concatenate([position] + \ + [jnp.cos(position*f/20*2*np.pi) for f in range(20)] + \ + [jnp.sin(position*f/20*2*np.pi) for f in range(20)], + axis=1) #(N,3) -> (N,3+60+60) = (N, 123) positions = linear(params.project_positions, position) # L x Dm + del position + all_pairs = jnp.linalg.norm(positions.reshape(1, -1, Dm) - positions.reshape(-1, 1, Dm), axis=-1) + all_pairs = all_pairs / jnp.max(all_pairs) # Add (learned) positional encodings x = embeddings + positions # L x Dm @@ -128,12 +137,13 @@ def block(x, layer_num, layer): q = jnp.transpose(q.reshape(L, nheads, Dm//nheads), (1, 0, 2)) k = jnp.transpose(k.reshape(L, nheads, Dm//nheads), (1, 0, 2)) v = jnp.transpose(v.reshape(L, nheads, Dm//nheads), (1, 0, 2)) - score = (q @ jnp.transpose(k, (0, 2, 1))) / math.sqrt(Dm) + score = (q @ jnp.transpose(k, (0, 2, 1))) / math.sqrt(Dm//nheads) - # do like graphformer and append position here? - #if layer_num < 6: # doesn't look like it helps - # score += H_core - # score += all_pairs + if True: # todo: why does this improve loss from ~1000 to ~300 first step (qm9). + score += H_core + #score += all_pairs # => NaNs for some reason + #score += L_inv + score += L_inv @ H_core @ L_inv.T attn = jax.nn.softmax(score , axis=1) x = x + (attn @ v).reshape(L, Dm) @@ -149,16 +159,25 @@ def block(x, layer_num, layer): # Residual connection x = x + t2 - return x, score + return x # Apply the transformer layers # todo: cut jit time by making this jax.lax.foriloop - for layer_num, layer in enumerate(params.layers): - x, score = jax.checkpoint(block)(x, layer_num, layer) + for layer_num, layer in enumerate(params.layers[:-1]): + x = jax.checkpoint(block)(x, layer_num, layer) + + layer = params.layers[-1] + # Prediction is last attention (without nhead = 1), and q=k so score is symmetric! + nheads = 1 + t1 = vmap(standardize)(x) # L x Dm + t1 = elementwise_linear(layer.norm_self_attn, t1) # L x Dm + qkv = linear(layer.kqv, t1) + q,k,v = jnp.split(qkv, 3, axis=1) + q = jnp.transpose(q.reshape(L, nheads, Dm//nheads), (1, 0, 2)) + #v = jnp.transpose(v.reshape(L, nheads, Dm//nheads), (1, 0, 2)) + score = (q @ jnp.transpose(k, (0, 2, 1))) / math.sqrt(Dm*nheads) # symmetric: initial loss goes from 1200 to 980 (qm9). - # todo: if this isn't symmetric eigh gives imaginary eigenvalues? (bad) - M = score[0] # take first attention head - #M = (M + M.T)/2 # make symmetric! + M = score[0] return M import types @@ -274,7 +293,7 @@ def convert_to_float32(x): parser.add_argument('-large', action="store_true") parser.add_argument('-xlarge', action="store_true") opts = parser.parse_args() - + # initialize model # transformer tiny 5M d_model= 192 @@ -311,10 +330,10 @@ def convert_to_float32(x): extrapolate=False, mol_idx=0) summary(state) - output = jax.jit(jax.vmap(transformer, in_axes=(None, None, 0, 0, 0), out_axes=(0)), + output = jax.jit(jax.vmap(transformer, in_axes=(None, None, 0, 0, 0, 0), out_axes=(0)), static_argnums=(0,), backend="cpu")(cfg, \ - params, state.ao_types, state.pos.astype(jnp.float32), state.H_core.astype(jnp.float32)) + params, state.ao_types, state.pos.astype(jnp.float32), state.H_core.astype(jnp.float32), state.L_inv.astype(jnp.float32)) print(np.sum(output)) # 162.58726108305348 @@ -328,10 +347,10 @@ def convert_to_float32(x): new_params = pickle.load(open("checkpoints/example.pickle", "rb")) # check that output remains the same - new_output = jax.jit(jax.vmap(transformer, in_axes=(None, None, 0, 0, 0), out_axes=(0)), + new_output = jax.jit(jax.vmap(transformer, in_axes=(None, None, 0, 0, 0, 0), out_axes=(0)), static_argnums=(0,), backend="cpu")(cfg, \ - new_params, state.ao_types, state.pos.astype(jnp.float32), state.H_core.astype(jnp.float32)) + new_params, state.ao_types, state.pos.astype(jnp.float32), state.H_core.astype(jnp.float32), state.L_inv.astype(jnp.float32)) assert np.allclose(output, new_output) print("TEST CASE PASSED!") From 0a4abc959b1ebff41468a976cd53a7b5a8894ef8 Mon Sep 17 00:00:00 2001 From: Alexander Mathiasen Date: Sat, 6 Jan 2024 10:28:15 +0000 Subject: [PATCH 20/22] fixed k=q mistake --- pyscf_ipu/direct/transformer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyscf_ipu/direct/transformer.py b/pyscf_ipu/direct/transformer.py index bc17affc..45e0e050 100644 --- a/pyscf_ipu/direct/transformer.py +++ b/pyscf_ipu/direct/transformer.py @@ -174,6 +174,7 @@ def block(x, layer_num, layer): qkv = linear(layer.kqv, t1) q,k,v = jnp.split(qkv, 3, axis=1) q = jnp.transpose(q.reshape(L, nheads, Dm//nheads), (1, 0, 2)) + k = q #v = jnp.transpose(v.reshape(L, nheads, Dm//nheads), (1, 0, 2)) score = (q @ jnp.transpose(k, (0, 2, 1))) / math.sqrt(Dm*nheads) # symmetric: initial loss goes from 1200 to 980 (qm9). From 458bee1ccf306cd110e1c97b832336d735085508 Mon Sep 17 00:00:00 2001 From: Alexander Mathiasen Date: Sun, 14 Jan 2024 15:48:46 +0000 Subject: [PATCH 21/22] . --- pyscf_ipu/direct/train.py | 566 +++++++++++++++++++++++--------------- 1 file changed, 349 insertions(+), 217 deletions(-) diff --git a/pyscf_ipu/direct/train.py b/pyscf_ipu/direct/train.py index c94dc837..18d1c5b8 100644 --- a/pyscf_ipu/direct/train.py +++ b/pyscf_ipu/direct/train.py @@ -1,8 +1,9 @@ import os -os.environ['OMP_NUM_THREADS'] = '8' +os.environ['OMP_NUM_THREADS'] = '16' import jax jax.config.update('jax_enable_x64', True) import jax.numpy as jnp +import scipy import numpy as np import pyscf import optax @@ -18,6 +19,7 @@ import random random.seed(42) +MD17_WATER, MD17_ETHANOL, MD17_ALDEHYDE, MD17_URACIL = 1, 2, 3, 4 cfg, HARTREE_TO_EV, EPSILON_B3LYP, HYB_B3LYP = None, 27.2114079527, 1e-20, 0.2 def T(x): return jnp.transpose(x, (0,2,1)) @@ -25,20 +27,24 @@ def T(x): return jnp.transpose(x, (0,2,1)) B, BxNxN, BxNxK = None, None, None # Only need to recompute: L_inv, grid_AO, grid_weights, H_core, ERI and E_nuc. -def dm_energy(W: BxNxK, state, normal, nn): +def dm_energy(W: BxNxK, state, normal, nn, cfg=None, opts=None): if nn: W = jax.vmap(transformer, in_axes=(None, None, 0, 0, 0, 0), out_axes=(0))(cfg, \ W, state.ao_types, state.pos.astype(jnp.float32), state.H_core.astype(jnp.float32), state.L_inv.astype(jnp.float32)) + #W, state.ao_types, state.pos.astype(jnp.float64), state.H_core.astype(jnp.float64), state.L_inv.astype(jnp.float64)) W = W.astype(jnp.float64) # we can interpret state.H_core + W as hamiltonian, and predict hlgap from these! - L_inv_Q: BxNxN = state.L_inv_T @ jnp.linalg.eigh(state.L_inv @ (state.H_core + W) @ state.L_inv_T)[1] # O(B*N*K^2) FLOP O(B*N*K) FLOP/FLIO + H = state.H_core + W + L_inv_Q: BxNxN = state.L_inv_T @ jnp.linalg.eigh(state.L_inv @ H @ state.L_inv_T)[1] # O(B*N*K^2) FLOP O(B*N*K) FLOP/FLIO density_matrix: BxNxN = 2 * (L_inv_Q*state.mask) @ T(L_inv_Q) # O(B*N*K^2) FLOP/FLIO E_xc: B = exchange_correlation(density_matrix, state, normal, opts.xc_f32) # O(B*gsize*N^2) FLOP O(gsize*N^2) FLIO - diff_JK: BxNxN = JK(density_matrix, state, normal, opts.foriloop, opts.eri_f32) # O(B*num_ERIs) FLOP O(num_ERIs) FLIO + diff_JK: BxNxN = JK(density_matrix, state, normal, opts.foriloop, opts.eri_f32, opts.bs) # O(B*num_ERIs) FLOP O(num_ERIs) FLIO energies: B = E_xc + state.E_nuc + jnp.sum((density_matrix * (state.H_core + diff_JK/2)).reshape(W.shape[0], -1), axis=-1) energy: float = jnp.sum(energies) - return energy, (energies, E_xc, density_matrix, W) + return energy, (energies, E_xc, density_matrix, W, H) + + def sparse_mult(values, dm, state, gsize): in_ = dm.take(state.cols, axis=0) @@ -70,7 +76,7 @@ def exchange_correlation(density_matrix: BxNxN, state, normal, xc_f32): E_xc = jnp.sum(rho[:, 0] * state.grid_weights * E_xc, axis=-1).reshape(B) return E_xc -def JK(density_matrix, state, normal, jax_foriloop, eri_f32): +def JK(density_matrix, state, normal, jax_foriloop, eri_f32, bs): if normal: J = jnp.einsum('bijkl,bji->bkl', state.ERI, density_matrix) K = jnp.einsum('bijkl,bjk->bil', state.ERI, density_matrix) @@ -80,32 +86,43 @@ def JK(density_matrix, state, normal, jax_foriloop, eri_f32): if eri_f32: density_matrix = density_matrix.astype(jnp.float32) - '''diff_JK: BxNxN = jax.vmap(sparse_symmetric_einsum, in_axes=(None, None, 0, None))( - state.nonzero_distinct_ERI[0], - state.nonzero_indices[0], - density_matrix, - jax_foriloop - ) - diff_JK: BxNxN = diff_JK - jax.vmap(sparse_symmetric_einsum, in_axes=(0, None, 0, None))( - state.diffs_ERI, - state.indxs, - density_matrix, - jax_foriloop - )''' - - diff_JK: BxNxN = jax.vmap(sparse_einsum, in_axes=(None, None, 0, None))( - state.nonzero_distinct_ERI[0], - state.precomputed_nonzero_indices, - density_matrix, - jax_foriloop - ) - correction = jax.vmap(sparse_einsum, in_axes=(0, None, 0, None))( - state.diffs_ERI, - state.precomputed_indxs, - density_matrix, - jax_foriloop - ) - diff_JK: BxNxN = diff_JK - correction + if bs == 1: + diff_JK: BxNxN = jax.vmap(sparse_symmetric_einsum, in_axes=(None, None, 0, None))( + state.nonzero_distinct_ERI[0], + state.nonzero_indices[0], + density_matrix, + jax_foriloop + ) + + + else: + '''diff_JK: BxNxN = jax.vmap(sparse_symmetric_einsum, in_axes=(None, None, 0, None))( + state.nonzero_distinct_ERI[0], + state.nonzero_indices[0], + density_matrix, + jax_foriloop + ) + diff_JK: BxNxN = diff_JK - jax.vmap(sparse_symmetric_einsum, in_axes=(0, None, 0, None))( + state.diffs_ERI, + state.indxs, + density_matrix, + jax_foriloop + )''' + + diff_JK: BxNxN = jax.vmap(sparse_einsum, in_axes=(None, None, 0, None))( + state.nonzero_distinct_ERI[0], + state.precomputed_nonzero_indices, + density_matrix, + jax_foriloop + ) + if bs > 1: + correction = jax.vmap(sparse_einsum, in_axes=(0, None, 0, None))( + state.diffs_ERI, + state.precomputed_indxs, + density_matrix, + jax_foriloop + ) + diff_JK: BxNxN = diff_JK - correction return diff_JK.astype(jnp.float64) @@ -125,8 +142,9 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, pad_sparse_diff_grid=200000, mol_idx=42, ): - - do_print = False + start_time = time.time() + do_print = opts.do_print + if do_print: print("\t[%.4fs] start of 'batched_state'. "%(time.time()-start_time)) # pad molecule if using nn. if not opts.nn: pad_electrons, pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = \ @@ -142,24 +160,12 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, pad_nonzero_distinct_ERI = 20000 pad_sparse_diff_grid = 20000 - if opts.qm9: + if opts.qm9 or opts.qh9: pad_electrons=60 - '''pad_diff_ERIs=120000 - pad_distinct_ERIs=400000 - pad_grid_AO=50000 - pad_nonzero_distinct_ERI=400000 - pad_sparse_diff_grid=400000''' - #padding_estimate = [37426, 149710, 17010, 140122, 138369] - - # idea: - # train w/o atom pertubation until convergence, then 1, then 2, then 3. - padding_estimate = [48330, 163222, 17034, 159361, 139505] - if opts.nperturb == 2 or opts.nperturb == 1: padding_estimate = [int(a*2.1) for a in padding_estimate] else: padding_estimate = [int(a*1.1) for a in padding_estimate] - pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = padding_estimate if opts.basis == "def2-svp": @@ -185,7 +191,39 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, if opts.waters: pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = [a//3 for a in [ pad_diff_ERIs , pad_distinct_ERIs , pad_grid_AO , pad_nonzero_distinct_ERI , pad_sparse_diff_grid ]] - mol = build_mol(mol_str, opts.basis) + if opts.md17 > 0: + if opts.md17 == MD17_WATER: + if opts.level == 1: padding_estimate = [ 3361, 5024, 10172 , 5024 , 155958] + if opts.level == 3: padding_estimate = [ 3361, 5024, 34310 , 5024 , 494370] + if opts.bs == 2 and opts.wiggle_var == 0: padding_estimate = [ 1, 5024, 10172 , 5024 , 1] + padding_estimate = [int(a*1.5) for a in padding_estimate] + pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = padding_estimate + + elif opts.md17 == MD17_ETHANOL: + pad_electrons = 72 + #padding_estimate = [ 99042, 197660, 34310*5 , 197308 , 494370*5] + #padding_estimate = [113522, 224275, 30348, 197308, 609163] + if opts.level == 1: padding_estimate = [ 1, 415186, 30348, 415145, 1] + padding_estimate = [int(a*1.1) for a in padding_estimate] + pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = padding_estimate + + elif opts.md17 == MD17_ALDEHYDE: + pad_electrons = 90 + #padding_estimate = [130646, 224386 , 30348, 223233, 626063] + #padding_estimate = [ 34074, 204235 , 30348, 203632, 285934] + if opts.level == 1: padding_estimate = [1, 939479, 35704, 939479, 1] + padding_estimate = [int(a*1.1) for a in padding_estimate] + pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = padding_estimate + + elif opts.md17 == MD17_URACIL: + #raise Exception("not ready yet") + pad_electrons = 132 + padding_estimate = [1, 3769764 , 51184, 3769764, 1] + padding_estimate = [int(a*1.05) for a in padding_estimate] + pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = padding_estimate + + + mol = build_mol(mol_str, opts.basis, unit="bohr") pad_electrons = min(pad_electrons, mol.nao_nr()) # Set seed to ensure different rotation; initially all workers did same rotation! @@ -195,7 +233,13 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, water1_xyz = np.array([mol_str[i][1] for i in range(0,3)]) water2_xyz = np.array([mol_str[i][1] for i in range(3,6)]) - if opts.qm9: + + if opts.md17 > 0: + natm = len(mol_str) + atom_num = random.sample(range(natm), 1)[0] + atoms = np.array([mol_str[i][1] for i in range(0,natm)]) + + if opts.qm9 or opts.qh9: # pick random atom to perturb (of the first 9 heavy ones) atom_num1, atom_num2, atom_num3 = random.sample(range(9), 3) atoms = np.array([mol_str[i][1] for i in range(0,10)]) @@ -235,7 +279,7 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, states = [] for iteration in range(bs): - + if do_print: print("\t[%.4fs] initializing state %i. "%(time.time()-start_time, iteration)) if opts.alanine: from rdkit import Chem from rdkit.Chem import AllChem @@ -277,6 +321,10 @@ def get_atom_positions(mol): import wandb wandb.log({"mol_valid=%s"%validation: create_rdkit_mol(str, pos) })''' + if opts.md17 > 0: + if iteration > 0: + mol_str[atom_num][1] = tuple(atoms[atom_num] + np.random.normal(0,opts.wiggle_var, (3))) + if opts.waters: # todo: rotate both water molecules and draw x=phi, y=psi. rotation_matrix = np.linalg.qr(np.random.normal(size=(3,3)))[0] center = water2_xyz.mean(axis=0) @@ -293,7 +341,7 @@ def get_atom_positions(mol): pos = np.concatenate([np.array(mol_str[j][1]).reshape(1, 3) for j in range(len(mol_str))]) wandb.log({"%s_mol_%i"%({True: "valid", False: "train"}[validation], iteration): create_rdkit_mol(str, pos) })''' - elif opts.qm9: + elif opts.qm9 or opts.qh9: if iteration == 0 and (validation or extrapolate): pass else: s = mol_str[atom_num1][0]+ mol_str[atom_num2][0]+ mol_str[atom_num3][0] @@ -318,16 +366,14 @@ def get_atom_positions(mol): c, w = state.grid_coords, state.grid_weights elif iteration <= 1 or not opts.prof: # when profiling create fake molecule to skip waiting state = init_dft(mol_str, opts, c, w, do_pyscf=do_pyscf and iteration < 3, state=state, pad_electrons=pad_electrons) - states.append(state) - if do_print: print("cat states") state = cats(states) N = state.N[0] + if do_print: print("\t[%.4fs] concatenated states. "%(time.time()-start_time)) - if do_print: print("get sparsity") #_nonzero = None @@ -341,24 +387,24 @@ def get_atom_positions(mol): #_nonzero = indxs if _nonzero is None else np.logical_or(_nonzero, indxs) + if do_print: print("\t[%.4fs] got sparsity. "%(time.time()-start_time)) # Merge nonzero indices and prepare (ij, kl). # rep is the number of repetitions we include in the sparse representation. - if do_print: print("union") # bottleneck for def2-svp. todo: fix above logical_or trick to do faster. union = nonzero[0] for i in range(1, len(nonzero)): # this takes 12s/it for def2-svp. union = np.union1d(union, nonzero[i]) nonzero_indices = union - #exit() + if do_print: print("\t[%.4fs] got union of sparsity. "%(time.time()-start_time)) from sparse_symmetric_ERI import get_i_j, num_repetitions_fast ij, kl = get_i_j(nonzero_indices) rep = num_repetitions_fast(ij, kl) + if do_print: print("\t[%.4fs] got (ij) and reps. "%(time.time()-start_time)) batches = opts.eri_bs es = [] - if do_print: print("pad") for e,i in zip(state.nonzero_distinct_ERI, state.nonzero_indices): nonzero_distinct_ERI = e[nonzero_indices] / rep remainder = nonzero_indices.shape[0] % (batches) @@ -369,9 +415,10 @@ def get_atom_positions(mol): state.nonzero_distinct_ERI = np.concatenate([np.expand_dims(a, axis=0) for a in es]) - if do_print: print("pad ijkl indices") + if do_print: print("\t[%.4fs] padded ERI and nonzero_indices. . "%(time.time()-start_time)) i, j = get_i_j(ij.reshape(-1)) k, l = get_i_j(kl.reshape(-1)) + if do_print: print("\t[%.4fs] got ijkl. "%(time.time()-start_time)) if remainder != 0: i = np.pad(i, ((0,batches-remainder))) @@ -380,13 +427,15 @@ def get_atom_positions(mol): l = np.pad(l, ((0,batches-remainder))) nonzero_indices = np.vstack([i,j,k,l]).T.reshape(batches, -1, 4).astype(np.int32) # todo: use int16 or int32 here? state.nonzero_indices = nonzero_indices + if do_print: print("\t[%.4fs] padded and vstacked ijkl. "%(time.time()-start_time)) # batching (w/ same sparsity pattern across batch) allows precomputing all {ss,dm}_indices instead of computing in sparse_sym_eri every iteration. # function below does this. # todo: consider removing, didn't get expecting 3x (only 5%; not sure if additional memory/complication justifies). - if do_print: print("precompute indices. ") from sparse_symmetric_ERI import precompute_indices + + if opts.normal: diff_state = None else: main_grid_AO = state.grid_AO[:1] @@ -395,39 +444,41 @@ def get_atom_positions(mol): sparse_diffs_grid_AO = diffs_grid_AO[:, 0, rows,cols] # use the same sparsity pattern across a batch. - diff_ERIs = state.nonzero_distinct_ERI[:1] - state.nonzero_distinct_ERI - diff_indxs = state.nonzero_indices.reshape(1, batches, -1, 4) - nzr = np.abs(diff_ERIs[1]).reshape(batches, -1) > 1e-10 + if opts.bs > 1: + diff_ERIs = state.nonzero_distinct_ERI[:1] - state.nonzero_distinct_ERI + diff_indxs = state.nonzero_indices.reshape(1, batches, -1, 4) + nzr = np.abs(diff_ERIs[1]).reshape(batches, -1) > 1e-10 - diff_ERIs = diff_ERIs[:, nzr].reshape(bs, -1) - diff_indxs = diff_indxs[:, nzr].reshape(-1, 4) + diff_ERIs = diff_ERIs[:, nzr].reshape(bs, -1) + diff_indxs = diff_indxs[:, nzr].reshape(-1, 4) - remainder = np.sum(nzr) % batches - if remainder != 0: - diff_ERIs = np.pad(diff_ERIs, ((0,0),(0,batches-remainder))) - diff_indxs = np.pad(diff_indxs, ((0,batches-remainder),(0,0))) + remainder = np.sum(nzr) % batches + if remainder != 0: + diff_ERIs = np.pad(diff_ERIs, ((0,0),(0,batches-remainder))) + diff_indxs = np.pad(diff_indxs, ((0,batches-remainder),(0,0))) - diff_ERIs = diff_ERIs.reshape(bs, batches, -1) - diff_indxs = diff_indxs.reshape(batches, -1, 4) + diff_ERIs = diff_ERIs.reshape(bs, batches, -1) + diff_indxs = diff_indxs.reshape(batches, -1, 4) - precomputed_indxs = precompute_indices(diff_indxs, N).astype(np.int16) + if opts.bs > 1: precomputed_indxs = precompute_indices(diff_indxs, N).astype(np.int16) - if pad_diff_ERIs == -1: - state.indxs=diff_indxs - state.diffs_ERI=diff_ERIs - assert False, "deal with precomputed_indxs; only added in else branch below" - else: - max_pad_diff_ERIs = diff_ERIs.shape[2] - # pad ERIs with 0 and indices with -1 so they point to 0. - assert diff_indxs.shape[1] == diff_ERIs.shape[2] - pad = pad_diff_ERIs - diff_indxs.shape[1] - assert pad > 0, (pad_diff_ERIs, diff_indxs.shape[1]) - state.indxs = np.pad(diff_indxs, ((0,0), (0, pad), (0, 0)), 'constant', constant_values=(-1)) - state.diffs_ERI = np.pad(diff_ERIs, ((0,0), (0, 0), (0, pad))) # pad zeros - #print(diff_indxs.shape, precomputed_indxs.shape) - state.precomputed_indxs = np.pad(precomputed_indxs, ((0,0), (0,0),(0,0), (0, pad), (0,0)), 'constant', constant_values=(-1)) - - #if opts.wandb: wandb.log({"pad_diff_ERIs": pad/diff_ERIs.shape[2]}) + if pad_diff_ERIs == -1: + state.indxs=diff_indxs + state.diffs_ERI=diff_ERIs + assert False, "deal with precomputed_indxs; only added in else branch below" + else: + max_pad_diff_ERIs = diff_ERIs.shape[2] + if do_print: print("\t[%.4fs] max_pad_diff_ERIs=%i"%(time.time()-start_time, max_pad_diff_ERIs)) + # pad ERIs with 0 and indices with -1 so they point to 0. + assert diff_indxs.shape[1] == diff_ERIs.shape[2] + pad = pad_diff_ERIs - diff_indxs.shape[1] + assert pad > 0, (pad_diff_ERIs, diff_indxs.shape[1]) + state.indxs = np.pad(diff_indxs, ((0,0), (0, pad), (0, 0)), 'constant', constant_values=(-1)) + state.diffs_ERI = np.pad(diff_ERIs, ((0,0), (0, 0), (0, pad))) # pad zeros + #print(diff_indxs.shape, precomputed_indxs.shape) + if opts.bs > 1: state.precomputed_indxs = np.pad(precomputed_indxs, ((0,0), (0,0),(0,0), (0, pad), (0,0)), 'constant', constant_values=(-1)) + + #if opts.wandb: wandb.log({"pad_diff_ERIs": pad/diff_ERIs.shape[2]}) state.rows=rows state.cols=cols @@ -439,6 +490,7 @@ def get_atom_positions(mol): if pad_sparse_diff_grid != -1: max_pad_sparse_diff_grid = state.rows.shape[0] + if do_print: print("\t[%.4fs] max_pad_sparse_diff_grid=%i"%(time.time()-start_time, max_pad_sparse_diff_grid)) assert state.sparse_diffs_grid_AO.shape[1] == state.rows.shape[0] assert state.sparse_diffs_grid_AO.shape[1] == state.cols.shape[0] pad = pad_sparse_diff_grid - state.rows.shape[0] @@ -457,6 +509,7 @@ def get_atom_positions(mol): # todo: looks like we're padding, then looking for zeros, then padding; this can be simplified. if pad_distinct_ERIs != -1: max_pad_distinct_ERIs = state.nonzero_distinct_ERI.shape[2] + if do_print: print("\t[%.4fs] max_pad_distinct_ERIs=%i"%(time.time()-start_time, max_pad_diff_ERIs)) assert state.nonzero_distinct_ERI.shape[2] == state.nonzero_indices.shape[2] pad = pad_distinct_ERIs - state.nonzero_distinct_ERI.shape[2] assert pad > 0, (pad_distinct_ERIs, state.nonzero_distinct_ERI.shape[2]) @@ -467,6 +520,7 @@ def get_atom_positions(mol): if pad_grid_AO != -1: max_pad_grid_AO = state.grid_AO.shape[2] + if do_print: print("\t[%.4fs] max_pad_grid_AO=%i"%(time.time()-start_time, max_pad_grid_AO)) prev_size = state.grid_AO.shape[2] assert state.grid_AO.shape[2] == state.grid_weights.shape[1] @@ -501,11 +555,12 @@ def get_atom_positions(mol): state.nonzero_distinct_ERI = state.nonzero_distinct_ERI.reshape(1, batches, -1) state.nonzero_indices = state.nonzero_indices.reshape(1, batches, -1, 4) - precomputed_nonzero_indices = precompute_indices(state.nonzero_indices[0], N).astype(np.int16) + if opts.bs > 1: precomputed_nonzero_indices = precompute_indices(state.nonzero_indices[0], N).astype(np.int16) #print(state.nonzero_indices.shape, precomputed_nonzero_indices.shape) if pad_nonzero_distinct_ERI != -1: max_pad_nonzero_distinct_ERI = state.nonzero_distinct_ERI.shape[2] + if do_print: print("\t[%.4fs] max_pad_nonzero_distinct_ERI=%i"%(time.time()-start_time, max_pad_nonzero_distinct_ERI)) assert state.nonzero_distinct_ERI.shape[2] == state.nonzero_indices.shape[2] pad = pad_nonzero_distinct_ERI - state.nonzero_distinct_ERI.shape[2] @@ -513,7 +568,7 @@ def get_atom_positions(mol): state.nonzero_distinct_ERI = np.pad(state.nonzero_distinct_ERI, ((0,0),(0,0),(0,pad))) state.nonzero_indices = np.pad(state.nonzero_indices, ((0,0),(0,0),(0,pad), (0,0)), 'constant', constant_values=(-1)) - state.precomputed_nonzero_indices = np.pad(precomputed_nonzero_indices, ((0,0), (0,0), (0,0), (0, pad),(0,0)), 'constant', constant_values=(-1)) + if opts.bs > 1: state.precomputed_nonzero_indices = np.pad(precomputed_nonzero_indices, ((0,0), (0,0), (0,0), (0, pad),(0,0)), 'constant', constant_values=(-1)) #print(state.precomputed_nonzero_indices.shape, state.nonzero_indices.shape) #if opts.wandb: wandb.log({"pad_grid_AO": pad/state.grid_AO.shape[2]}) @@ -535,6 +590,7 @@ def get_atom_positions(mol): def nanoDFT(mol_str, opts): + start_time = time.time() print() # Initialize validation set. # This consists of DFT tensors initialized with PySCF/CPU. @@ -546,6 +602,8 @@ def nanoDFT(mol_str, opts): run = wandb.init(project='ndft_alanine') elif opts.qm9: run = wandb.init(project='ndft_qm9') + elif opts.md17 > 0: + run = wandb.init(project='md17') else: run = wandb.init(project='ndft') opts.name = run.name @@ -588,13 +646,17 @@ def nanoDFT(mol_str, opts): d_model= 1024 n_heads = 16 n_layers = 24 - if opts.large: - d_model= 1280 + if opts.large: # this is 600M; + d_model= 1280 # 80*16 n_heads = 16 n_layers = 36 - if opts.xlarge: - d_model= 1600 - n_heads = 25 + if opts.largep: # interpolated between large and largep. + d_model= 91*16 # halway from 80 to 100 + n_heads = 16*1 + n_layers = 43 + if opts.xlarge: # this is 1.3B; decrease parameter count 30%. + d_model= 1600 # 100*16 + n_heads = 25 n_layers = 48 if opts.nn: @@ -606,6 +668,7 @@ def nanoDFT(mol_str, opts): n_heads =n_heads, d_ff =d_model*4, ) + print("[%.4fs] initialized transformer. "%(time.time()-start_time) ) params = params.to_float32() if opts.resume: @@ -616,30 +679,35 @@ def nanoDFT(mol_str, opts): if opts.nn: #https://arxiv.org/pdf/1706.03762.pdf see 5.3 optimizer - - # try to mimic karpathy as closely as possible ;) - # https://github.com/karpathy/nanoGPT/blob/master/train.py - # still differs on - # [ ] weight initialization - - def custom_schedule(it, learning_rate=opts.lr, min_lr=opts.lr/10, warmup_iters=2000, lr_decay_iters=600000): # 600k/30 = 20k; so hit mi + def custom_schedule(it, learning_rate=opts.lr, min_lr=opts.min_lr, warmup_iters=opts.warmup_iters, lr_decay_iters=opts.lr_decay): cond1 = (it < warmup_iters) * learning_rate * it / warmup_iters cond2 = (it > lr_decay_iters) * min_lr decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) coeff = 0.5 * (1.0 + jnp.cos(jnp.pi * decay_ratio)) cond3 = (it >= warmup_iters) * (it <= lr_decay_iters) * (min_lr + coeff * (learning_rate - min_lr)) - return cond1 + cond2 + cond3 + if not opts.resume: return cond1 + cond2 + cond3 + else: return learning_rate adam = optax.chain( optax.clip_by_global_norm(1), - optax.scale_by_adam(b1=0.9, b2=0.95, eps=1e-12), + #optax.scale_by_adam(b1=0.9, b2=0.95, eps=1e-12), + optax.scale_by_adam(b1=0.99, b2=0.999, eps=1e-12), + #optax.scale_by_factored_rms(), # use this for larger model (more memory efficient) optax.add_decayed_weights(0.1), optax.scale_by_schedule(custom_schedule), optax.scale(-1), + #optax.ema(opts.ema) if opts.ema != 0 else None ) - w = params + df = None + if opts.qh9: + df = pd.read_pickle("qh9/qh9stable_processed_shuffled.pickle") + df = df[df["N_sto3g"]==55] + print(df.shape) + elif opts.qm9: + df = pd.read_pickle("alchemy/processed_atom_9.pickle") # spin=0 and only CNOFH molecules + if nao != -1: df = df[df["nao"]==nao] from torch.utils.data import DataLoader, Dataset class OnTheFlyQM9(Dataset): @@ -647,31 +715,20 @@ class OnTheFlyQM9(Dataset): # dataloader is very keen on throwing segfaults (e.g. using jnp in dataloader throws segfaul). # problem: second epoch always gives segfault. # hacky fix; make __len__ = real_length*num_epochs and __getitem__ do idx%real_num_examples - def __init__(self, opts, nao=294, train=True, num_epochs=10**9, extrapolate=False): - - if opts.qh9: - df = pd.read_pickle("qh9/qh9stable_processed_shuffled.pickle") - if nao != -1: df = df[df["N_sto3g"]==55] - print(df.shape) - else: - # only take molecules with use {CNOFH}, nao=nao and spin=0. - #qm9 - df = pd.read_pickle("alchemy/processed_atom_9.pickle") # spin=0 and only CNOFH molecules - if nao != -1: df = df[df["nao"]==nao] - - + def __init__(self, opts, df=None, nao=294, train=True, num_epochs=10**9, extrapolate=False): # df.sample is not deterministic; moved to pre-processing, so file is shuffled already. # this shuffling is important, because it makes the last 10 samples iid (used for validation) #df = df.sample(frac=1).reset_index(drop=True) # is this deterministic? - - if train: self.mol_strs = df["pyscf"].values[:-10] - else: self.mol_strs = df["pyscf"].values[-10:] + if opts.qh9 or opts.qm9: + if train: self.mol_strs = df["pyscf"].values[:-10] + else: self.mol_strs = df["pyscf"].values[-10:] #print(df["pyscf"].) # todo: print smile strings self.num_epochs = num_epochs self.opts = opts self.validation = not train self.extrapolate = extrapolate + self.do_pyscf = self.validation or self.extrapolate self.benzene = [ ["C", ( 0.0000, 0.0000, 0.0000)], @@ -697,39 +754,58 @@ def __init__(self, opts, nao=294, train=True, num_epochs=10**9, extrapolate=Fals ] if opts.benzene: self.mol_strs = [self.benzene] - if opts.waters: self.mol_strs = [self.waters] + + if opts.md17 > 0: + mol = {MD17_WATER: "water", MD17_ALDEHYDE: "malondialdehyde", MD17_ETHANOL: "ethanol", MD17_URACIL: "uracil"}[opts.md17] + mode = {True: "train", False: "val"}[train] + filename = "md17/%s_%s.pickle"%(mode, mol) + df = pd.read_pickle(filename) + + self.mol_strs = df["pyscf"].values.tolist() + N = int(np.sqrt(df["H"].values.tolist()[0].reshape(-1).size)) + self.H = [a.reshape(N, N) for a in df["H"].values.tolist()] + self.E = df["E"].values.tolist() + self.mol_strs = [eval(a) for a in self.mol_strs] + else: + self.H = [0 for _ in self.mol_strs] + self.E = [0 for _ in self.mol_strs] + + if opts.alanine: self.mol_strs = mol_str if train: self.bs = opts.bs else: self.bs = opts.val_bs - def __len__(self): - return len(self.mol_strs)*self.num_epochs + + def __len__(self): return len(self.mol_strs)*self.num_epochs def __getitem__(self, idx): return batched_state(self.mol_strs[idx%len(self.mol_strs)], self.opts, self.bs, \ - wiggle_num=0, do_pyscf=self.validation or self.extrapolate, validation=False, \ - extrapolate=self.extrapolate, mol_idx=idx) + wiggle_num=0, do_pyscf=self.do_pyscf, validation=False, \ + extrapolate=self.extrapolate, mol_idx=idx), self.H[idx%len(self.mol_strs)], self.E[idx%len(self.mol_strs)] - val_qm9 = OnTheFlyQM9(opts, train=False) - ext_qm9 = OnTheFlyQM9(opts, extrapolate=True) + print("[%.4fs] initialized datasets. "%(time.time()-start_time) ) + val_qm9 = OnTheFlyQM9(opts, train=False, df=df) + print("[%.4fs] initialized datasets. "%(time.time()-start_time) ) + ext_qm9 = OnTheFlyQM9(opts, extrapolate=True, df=df) + print("[%.4fs] initialized datasets. "%(time.time()-start_time) ) - # parallel dataloader bug; precompute here is not slow but causes dataloader later to die. - # run once to quickly precompute. if opts.precompute: val_state = val_qm9[0] ext_state = ext_qm9[0] exit() - qm9 = OnTheFlyQM9(opts, train=True) + qm9 = OnTheFlyQM9(opts, train=True, df=df) + print("[%.4fs] initialized datasets. "%(time.time()-start_time) ) if opts.workers != 0: train_dataloader = DataLoader(qm9, batch_size=1, pin_memory=True, shuffle=False, drop_last=True, num_workers=opts.workers, prefetch_factor=2, collate_fn=lambda x: x[0]) else: train_dataloader = DataLoader(qm9, batch_size=1, pin_memory=True, shuffle=False, drop_last=True, num_workers=opts.workers, collate_fn=lambda x: x[0]) pbar = tqdm(train_dataloader) + print("[%.4fs] initialized dataloaders. "%(time.time()-start_time) ) if opts.test_dataloader: t0 = time.time() - for iteration, state in enumerate(pbar): + for iteration, (state, H, E) in enumerate(pbar): if iteration == 0: summary(state) print(time.time()-t0) t0 = time.time() @@ -750,9 +826,10 @@ def __next__(self): return self.item adam = optax.adabelief(opts.lr) summary(states[0]) - vandg = jax.jit(jax.value_and_grad(dm_energy, has_aux=True), backend=opts.backend, static_argnames=("normal", 'nn')) - valf = jax.jit(dm_energy, backend=opts.backend, static_argnames=("normal", 'nn')) + vandg = jax.jit(jax.value_and_grad(dm_energy, has_aux=True), backend=opts.backend, static_argnames=("normal", 'nn', "cfg", "opts")) + valf = jax.jit(dm_energy, backend=opts.backend, static_argnames=("normal", 'nn', "cfg", "opts")) adam_state = adam.init(w) + print("[%.4fs] jitted vandg and valf."%(time.time()-start_time) ) if opts.resume: print("loading adam state") @@ -760,6 +837,7 @@ def __next__(self): return self.item print("done") w, adam_state = jax.device_put(w), jax.device_put(adam_state) + print("[%.4fs] jax.device_put(w,adam_state)."%(time.time()-start_time) ) @partial(jax.jit, backend=opts.backend) @@ -781,7 +859,9 @@ def update(w, adam_state, accumulated_grad): paddings = [] states = [] - for iteration, state in enumerate(pbar): + + print("[%.4fs] first iteration."%(time.time()-start_time) ) + for iteration, (state, H, E) in enumerate(pbar): if iteration == 0: summary(state) state = jax.device_put(state) @@ -800,59 +880,65 @@ def update(w, adam_state, accumulated_grad): if opts.shuffle: random.shuffle(states) # load_time, t0 = time.time()-t0, time.time() - if opts.checkpoint != -1 and iteration % opts.checkpoint == 0: # and iteration > 0: - t0 = time.time() - try: - name = opts.name.replace("-", "_") - path_model = "checkpoints/%s_%i_model.pickle"%(name, iteration) - path_adam = "checkpoints/%s_%i_adam_state.pickle"%(name, iteration) - print("trying to checkpoint to %s and %s"%(path_model, path_adam)) - pickle.dump(jax.device_get(w), open(path_model, "wb")) - pickle.dump(jax.device_get(adam_state), open(path_adam, "wb")) - print("done!") - print("\t-resume \"%s\""%(path_model.replace("_model.pickle", ""))) - except: - print("fail!") - pass - print("tried saving model took %fs"%(time.time()-t0)) - save_time, t0 = time.time()-t0, time.time() - - - if len(states) < 50: print(len(states)) - for j, state in enumerate(states): - print(". ", end="", flush=True) - if j == 0: _t0 =time.time() - (val, (vals, E_xc, density_matrix, _W)), grad = vandg(w, state, opts.normal, opts.nn) - print(",", end="", flush=True) - if j == 0: time_step1 = time.time()-_t0 - - if opts.grad_acc == 0 or len(states) < opts.mol_repeats: - w, adam_state = update(w, adam_state, grad) - else: - accumulated_grad = grad if accumulated_grad is None else jax.tree_map(lambda x, y: x + y, accumulated_grad, grad) - if j % opts.grad_acc == 0 and j > 0: # we assume opts.grad_acc divides opts.mol_repeats; prev was basically grad_acc=0 or grad_acc=mol_repeats, can now do hybrid. - w, adam_state = update(w, adam_state, grad) - accumulated_grad = None - print("#", end="", flush=True) + if len(states) < 50: print(len(states), opts.name) + + reps = 1 + if opts.md17 == 4: reps = 5 + if opts.md17 == 3: reps = 2 + if opts.md17 == 2: reps = 2 + + for _ in range(reps): + for j, state in enumerate(states): + print(". ", end="", flush=True) + if j == 0: _t0 =time.time() + (val, (vals, E_xc, density_matrix, _W, _H)), grad = vandg(w, state, opts.normal, opts.nn, cfg, opts) + print(",", end="", flush=True) + if j == 0: time_step1 = time.time()-_t0 + + if opts.grad_acc == 0 or len(states) < opts.mol_repeats: + w, adam_state = update(w, adam_state, grad) + else: + accumulated_grad = grad if accumulated_grad is None else jax.tree_map(lambda x, y: x + y, accumulated_grad, grad) + + if (j+1) % opts.grad_acc == 0 and j > 0: # we assume opts.grad_acc divides opts.mol_repeats; prev was basically grad_acc=0 or grad_acc=mol_repeats, can now do hybrid. + w, adam_state = update(w, adam_state, grad) + accumulated_grad = None + print("#", end="", flush=True) + + + if opts.checkpoint != -1 and adam_state[1].count % opts.checkpoint == 0 and adam_state[1].count > 0: + t0 = time.time() + try: + name = opts.name.replace("-", "_") + path_model = "checkpoints/%s_%i_model.pickle"%(name, iteration) + path_adam = "checkpoints/%s_%i_adam_state.pickle"%(name, iteration) + print("trying to checkpoint to %s and %s"%(path_model, path_adam)) + pickle.dump(jax.device_get(w), open(path_model, "wb")) + pickle.dump(jax.device_get(adam_state), open(path_adam, "wb")) + print("done!") + print("\t-resume \"%s\""%(path_model.replace("_model.pickle", ""))) + except: + print("fail!") + pass + print("tried saving model took %fs"%(time.time()-t0)) + save_time, t0 = time.time()-t0, time.time() + # todo: rename global_batch_size = len(states)*opts.bs if opts.wandb: dct["global_batch_size"] = global_batch_size train_time, t0 = time.time()-t0, time.time() - - # plot grad norm - #if iteration % 10 == 0: - # for k,v in accumulated_grad.items(): dct[k + "_norm"] = np.linalg.norm(v .reshape(-1) ) update_time, t0 = time.time()-t0, time.time() if not opts.nn: str = "error=" + "".join(["%.7f "%(vals[i]*HARTREE_TO_EV-state.pyscf_E[i]) for i in range(2)]) + " [eV]" str += "pyscf=%.7f us=%.7f"%(state.pyscf_E[0]/HARTREE_TO_EV, vals[0]) else: - pbar.set_description("train=".join(["%.2f"%i for i in vals[:1]]) + "[Ha] "+ valid_str + "time=%.1f %.1f %.1f %.1f %.1f %.1f"%(load_time, time_step1, train_time, update_time, val_time, plot_time)) + #print(vals[0], E) + pbar.set_description("train=%.4f"%(vals[0]*HARTREE_TO_EV) + "[eV] "+ valid_str + "time=%.1f %.1f %.1f %.1f %.1f %.1f"%(load_time, time_step1, train_time, update_time, val_time, plot_time)) if opts.wandb: dct["time_load"] = load_time @@ -860,44 +946,66 @@ def update(w, adam_state, accumulated_grad): dct["time_train"] = train_time dct["time_val"] = val_time plot_iteration = iteration % 10 == 0 - for i in range(0, 2): - if not opts.nn: - dct['train_l%i'%i ] = np.abs(vals[i]*HARTREE_TO_EV-state.pyscf_E[i]) - dct['train_pyscf%i'%i ] = np.abs(state.pyscf_E[i]) - dct['train_E%i'%i ] = np.abs(vals[i]*HARTREE_TO_EV) - if plot_iteration: - dct['img/dm%i'%i] = wandb.Image(np.expand_dims(density_matrix[i], axis=-1)) - dct['img/W%i'%i] = wandb.Image(np.expand_dims(_W[i], axis=-1)) + + dct["train_E"] = np.abs(E*HARTREE_TO_EV) + dct["train_E_pred"] = np.abs(vals[0]*HARTREE_TO_EV) step = adam_state[1].count plot_time, t0 = time.time()-t0, time.time() - - - # TODO: Plot molecules and val/ext angles. if opts.nn and (iteration < 250 or iteration % 10 == 0): - if val_state is None: val_state = jax.device_put(val_qm9[0]) - _, (valid_vals, _, vdensity_matrix, vW) = valf(w, val_state, opts.normal, opts.nn) - if ext_state is None: ext_state = jax.device_put(ext_qm9[0]) - _, (ext_vals, _, edensity_matrix, eW) = valf(w, ext_state, opts.normal, opts.nn) + val_idx = 1 + if val_state is None: val_state, val_H, val_E = jax.device_put(val_qm9[val_idx]) # todo: cat 8 of these. + _, (valid_vals, _, vdensity_matrix, vW, _val_H) = valf(w, val_state, opts.normal, opts.nn, cfg, opts) + + if opts.md17 > 0: + def get_H_from_dm(dm): + import pyscf + from pyscf import gto, dft + m = pyscf.gto.Mole(atom=val_qm9.mol_strs[val_idx], basis="def2-svp", unit="bohr") + m.build() + mf = dft.RKS(m) + mf.xc = 'B3LYP5' + mf.verbose = 0 + mf.diis_space = 8 + mf.conv_tol = 1e-13 + mf.grad_tol = 3.16e-5 + mf.grids.level = 3 + #mf.kernel() + h_core = mf.get_hcore() + S = mf.get_ovlp() + vxc = mf.get_veff(m, dm) + H = h_core + vxc + S = mf.get_ovlp() + return H, S + + matrix = np.array(vdensity_matrix[0]) + N = int(np.sqrt(matrix.size)) + _val_H, S = get_H_from_dm(matrix.reshape(N, N)) + + # compare eigenvalues + pred_vals = scipy.linalg.eigh(_val_H, S)[0] + label_vals = scipy.linalg.eigh(val_H, S)[0] + MAE_vals = np.mean(np.abs(pred_vals - label_vals)) + dct["val_eps"] = MAE_vals + lr = custom_schedule(step) - valid_str = "lr=%.3e"%lr + "val=" + "".join(["%.4f "%(valid_vals[i]*HARTREE_TO_EV-val_state.pyscf_E[i]) for i in range(0, 3)]) + " [eV]" - valid_str += "ext=" + "".join(["%.4f "%(ext_vals[i]*HARTREE_TO_EV-ext_state.pyscf_E[i]) for i in range(0, 3)]) + " [eV]" + valid_str = "lr=%.3e"%lr + "val=%.4f [eV] "%(valid_vals[0]*HARTREE_TO_EV-val_E*HARTREE_TO_EV) + "mae_H=%.4f "%( + np.mean(np.abs(val_H/np.abs(val_H) - _val_H/np.abs(_val_H))) + ) + if opts.md17> 0:valid_str+= " eps=%.4f"%(MAE_vals) + valid_str += "val'=" + "".join(["%.4f "%(valid_vals[i]*HARTREE_TO_EV-val_state.pyscf_E[i]) for i in range(0, 3)]) + " [eV]" + + dct['val_E'] = np.abs(valid_vals[0]*HARTREE_TO_EV-val_E*HARTREE_TO_EV ) + dct['val_H_MAE'] = np.mean(np.abs(val_H - _val_H)) # perhaps sign doesn't matter? + if opts.wandb: - for i in range(0, opts.val_bs): + for i in range(0, 3): dct['valid_l%i'%i ] = np.abs(valid_vals[i]*HARTREE_TO_EV-val_state.pyscf_E[i]) dct['valid_E%i'%i ] = np.abs(valid_vals[i]*HARTREE_TO_EV) dct['valid_pyscf%i'%i ] = np.abs(val_state.pyscf_E[i]) - dct['img/val_dm%i'%i] = wandb.Image(np.expand_dims(vdensity_matrix[i], axis=-1)) - dct['img/val_W%i'%i] = wandb.Image(np.expand_dims(vW[i], axis=-1)) - - dct['ext_l%i'%i ] = np.abs(ext_vals[i]*HARTREE_TO_EV-ext_state.pyscf_E[i]) - dct['ext_E%i'%i ] = np.abs(ext_vals[i]*HARTREE_TO_EV) - dct['ext_pyscf%i'%i ] = np.abs(ext_state.pyscf_E[i]) - dct['img/ext_dm%i'%i] = wandb.Image(np.expand_dims(edensity_matrix[i], axis=-1)) - dct['img/ext_W%i'%i] = wandb.Image(np.expand_dims(eW[i], axis=-1)) dct["scheduled_lr"] = lr @@ -1104,8 +1212,8 @@ def grids_from_pyscf_mol( def init_dft(mol_str, opts, _coords=None, _weights=None, first=False, do_pyscf=True, state=None, pad_electrons=-1): do_print = False #t0 = time.time() - mol = build_mol(mol_str, opts.basis) - if do_pyscf: pyscf_E, pyscf_hlgap, pycsf_forces = reference(mol_str, opts) + mol = build_mol(mol_str, opts.basis, unit="bohr") + if do_pyscf: pyscf_E, pyscf_hlgap, pycsf_forces = reference(mol_str, opts, unit="bohr") else: pyscf_E, pyscf_hlgap, pyscf_forces = np.zeros(1), np.zeros(1), np.zeros(1) N = mol.nao_nr() # N=66 for C6H6 (number of atomic **and** molecular orbitals) @@ -1366,12 +1474,19 @@ def hcore_deriv(atm_id, aoslices, h1): # <\nabla|1/r|> def pyscf_reference(mol_str, opts): from pyscf import __config__ __config__.dft_rks_RKS_grids_level = opts.level - mol = build_mol(mol_str, opts.basis) + mol = build_mol(mol_str, opts.basis, unit="bohr") mol.max_cycle = 50 mf = pyscf.scf.RKS(mol) - mf.max_cycle = 50 - mf.xc = "b3lyp5" + #mf.max_cycle = 50 + #mf.xc = "b3lyp5" + #mf.diis_space = 8 + mf.xc = 'B3LYP5' + mf.verbose = 0 # put this to 4 to check i set parameters correctly! mf.diis_space = 8 + # options from qh9 + mf.conv_tol=1e-13 + mf.grad_tol=3.16e-5 + mf.grids.level = 3 pyscf_energies = [] pyscf_hlgaps = [] lumo = mol.nelectron//2 @@ -1410,23 +1525,23 @@ def print_difference(nanoDFT_E, nanoDFT_forces, nanoDFT_logged_E, nanoDFT_hlgap, cosine_similarity = dot_products / (norm_X * norm_Y) print("Force cosine similarity:",cosine_similarity) -def build_mol(mol_str, basis_name): +def build_mol(mol_str, basis_name, unit="bohr"): mol = pyscf.gto.mole.Mole() - mol.build(atom=mol_str, unit="Angstrom", basis=basis_name, spin=0, verbose=0) + mol.build(atom=mol_str, unit=unit, basis=basis_name, spin=0, verbose=0) return mol -def reference(mol_str, opts): +def reference(mol_str, opts, unit="bohr"): import pickle import hashlib if opts.skip: return np.zeros(1), np.zeros(1), np.zeros(1) - filename = "precomputed/%s.pkl"%hashlib.sha256((str(mol_str) + str(opts.basis) + str(opts.level)).encode('utf-8')).hexdigest() + filename = "precomputed/%s.pkl"%hashlib.sha256((str(mol_str) + str(opts.basis) + str(opts.level) + unit).encode('utf-8')).hexdigest() print(filename) if not os.path.exists(filename): pyscf_E, pyscf_hlgap, pyscf_forces = pyscf_reference(mol_str, opts) with open(filename, "wb") as file: - pickle.dump([pyscf_E, pyscf_hlgap, pyscf_forces], file) + pickle.dump([pyscf_E, pyscf_hlgap, pyscf_forces, unit], file) else: - pyscf_E, pyscf_hlgap, pyscf_forces = pickle.load(open(filename, "rb")) + pyscf_E, pyscf_hlgap, pyscf_forces, unit = pickle.load(open(filename, "rb")) return pyscf_E, pyscf_hlgap, pyscf_forces @@ -1441,10 +1556,15 @@ def reference(mol_str, opts): # GD options parser.add_argument('-backend', type=str, default="cpu") - parser.add_argument('-lr', type=float, default=2.5e-4) + parser.add_argument('-lr', type=float, default=5e-4) + parser.add_argument('-min_lr', type=float, default=1e-7) + parser.add_argument('-warmup_iters', type=float, default=1000) + parser.add_argument('-lr_decay', type=float, default=200000) + parser.add_argument('-ema', type=float, default=0.0) + parser.add_argument('-steps', type=int, default=100000) parser.add_argument('-bs', type=int, default=8) - parser.add_argument('-val_bs', type=int, default=8) + parser.add_argument('-val_bs', type=int, default=3) parser.add_argument('-mol_repeats', type=int, default=16) # How many time to optimize wrt each molecule. parser.add_argument('-grad_acc', type=int, default=0) # integer, deciding how many steps to accumulate. parser.add_argument('-shuffle', action="store_true") # whether to to shuffle the window of states each step. @@ -1462,14 +1582,16 @@ def reference(mol_str, opts): parser.add_argument('-skip', action="store_true", help="skip pyscf test case") # dataset - parser.add_argument('-nperturb', type=int, default=1, help="how many atoms to perturb (supports 1,2,3)") + parser.add_argument('-nperturb', type=int, default=0, help="How many atoms to perturb (supports 1,2,3)") parser.add_argument('-qm9', action="store_true") + parser.add_argument('-md17', type=int, default=-1) parser.add_argument('-qh9', action="store_true") parser.add_argument('-benzene', action="store_true") parser.add_argument('-hydrogens', action="store_true") parser.add_argument('-water', action="store_true") parser.add_argument('-waters', action="store_true") parser.add_argument('-alanine', action="store_true") + parser.add_argument('-do_print', action="store_true") # useful for debugging. parser.add_argument('-states', type=int, default=1) parser.add_argument('-workers', type=int, default=5) parser.add_argument('-precompute', action="store_true") # precompute labels; only run once for data{set/augmentation}. @@ -1479,6 +1601,8 @@ def reference(mol_str, opts): parser.add_argument('-rotate_deg', type=float, default=90, help="how many degrees to rotate") parser.add_argument('-test_dataloader', action="store_true", help="no training, just test/loop through dataloader. ") + + # models parser.add_argument('-nn', action="store_true", help="train nn, defaults to GD") parser.add_argument('-tiny', action="store_true") @@ -1487,23 +1611,31 @@ def reference(mol_str, opts): parser.add_argument('-medium', action="store_true") parser.add_argument('-large', action="store_true") parser.add_argument('-xlarge', action="store_true") + parser.add_argument('-largep', action="store_true") # large "plus" parser.add_argument("-checkpoint", default=-1, type=int, help="which iteration to save model (default -1 = no saving)") # checkpoint model parser.add_argument("-resume", default="", help="path to checkpoint pickle file") # checkpoint model opts = parser.parse_args() if opts.tiny or opts.small or opts.base or opts.large or opts.xlarge: opts.nn = True - assert opts.mol_repeats % opts.grad_acc == 0, "mol_repeats needs to be a multiple of grad_acc (gradient accumulation)." + assert opts.grad_acc == 0 or opts.mol_repeats % opts.grad_acc == 0, "mol_repeats needs to be a multiple of grad_acc (gradient accumulation)." + + class HashableNamespace: + def __init__(self, namespace): self.__dict__.update(namespace.__dict__) + def __hash__(self): return hash(tuple(sorted(self.__dict__.items()))) + opts = HashableNamespace(opts) args_dict = vars(opts) print(args_dict) - if opts.qm9: df = pd.read_pickle("alchemy/atom_9.pickle") df = df[df["spin"] == 0] # only consider spin=0 mol_strs = df["pyscf"].values + if opts.qh9: + mol_strs = [] + # benzene if opts.benzene: mol_strs = [[ @@ -1526,7 +1658,7 @@ def reference(mol_str, opts): ["H", ( 0.0000, 0.0000, 0.0000)], ["H", ( 1.4000, 0.0000, 0.0000)], ]] - if opts.water: + if opts.md17 > 0 : mol_strs = [[ ["O", ( 0.0000, 0.0000, 0.0000)], ["H", ( 0.0000, 1.4000, 0.0000)], @@ -1573,4 +1705,4 @@ def reference(mol_str, opts): exit() pyscf_E, pyscf_hlgap, pyscf_forces = reference(mol_str, opts) nanoDFT_forces = grad(mol, grid_coords, grid_weights, mo_coeff, mo_energy, np.array(dm), np.array(H)) - print_difference(nanoDFT_E, nanoDFT_forces, 0 , nanoDFT_hlgap, pyscf_E, pyscf_forces, pyscf_hlgap) + print_difference(nanoDFT_E, nanoDFT_forces, 0 , nanoDFT_hlgap, pyscf_E, pyscf_forces, pyscf_hlgap) \ No newline at end of file From 071e17768595386da2a3b3562ff98c64cf7606f2 Mon Sep 17 00:00:00 2001 From: Adam Krzywaniak Date: Sat, 20 Jan 2024 17:28:33 +0000 Subject: [PATCH 22/22] code for paper plots --- pyscf_ipu/direct/another_plot.py | 79 ++ pyscf_ipu/direct/inference_heatmap_plot.py | 303 ++++++++ .../direct/inference_heatmap_plot_small_bs.py | 311 ++++++++ pyscf_ipu/direct/plot_heatmap_for_paper.py | 44 ++ pyscf_ipu/direct/train.py | 734 +++++++----------- pyscf_ipu/direct/transformer.py | 72 +- 6 files changed, 1035 insertions(+), 508 deletions(-) create mode 100644 pyscf_ipu/direct/another_plot.py create mode 100644 pyscf_ipu/direct/inference_heatmap_plot.py create mode 100644 pyscf_ipu/direct/inference_heatmap_plot_small_bs.py create mode 100644 pyscf_ipu/direct/plot_heatmap_for_paper.py diff --git a/pyscf_ipu/direct/another_plot.py b/pyscf_ipu/direct/another_plot.py new file mode 100644 index 00000000..0b73ffe4 --- /dev/null +++ b/pyscf_ipu/direct/another_plot.py @@ -0,0 +1,79 @@ +import pickle +import numpy as np +import matplotlib.pyplot as plt + + +ml_file = "heatmap_data_009.pkl" +pyscf_file = "heatmap_pyscf_009.pkl" +# Load data from the pickle file +with open(ml_file, 'rb') as file: + data_list = pickle.load(file) + +with open(pyscf_file, 'rb') as file: + pyscf_list = pickle.load(file) + +# Extract phi, psi, and values from the loaded data +phi_values, psi_values, heatmap_val = zip(*data_list) + +# Extract phi, psi, and values from the loaded data +phi_values_p, psi_values_p, heatmap_pyscf = zip(*pyscf_list) + +matrix_size = int(len(data_list) ** 0.5) + +heatmap_val = np.array(heatmap_val).reshape(matrix_size, matrix_size) +heatmap_pyscf = np.array(heatmap_pyscf).reshape(matrix_size, matrix_size) + +# valid_E = NN(molecule) \approx E +# state.pyscf_E = DFT(molecule) = E +# state.valid_l = | NN(molecule) - DFT(molecule) | +# +heatmap_pyscf = -heatmap_pyscf + +phi_coordinates, psi_coordinates = np.meshgrid(np.linspace(min(phi_values), max(phi_values), matrix_size), + np.linspace(min(psi_values), max(psi_values), matrix_size)) + +fig, ax = plt.subplots(2,3, figsize=(10, 8)) +# im = ax[0,0].imshow( heatmap_val ) +im = ax[0,0].imshow(heatmap_val, cmap='viridis', origin='lower', extent=[min(psi_values), max(psi_values), min(phi_values), max(phi_values)]) + +# ax[0,0].set_xlim(phi_values) +# ax[0,0].set_ylim(psi_values) +im2 = ax[0,1].imshow( heatmap_pyscf, cmap='viridis', origin='lower', extent=[min(psi_values), max(psi_values), min(phi_values), max(phi_values)]) +diff = ax[0,2].imshow( np.abs(heatmap_val - heatmap_pyscf), cmap='viridis', origin='lower', extent=[min(psi_values), max(psi_values), min(phi_values), max(phi_values)]) + +log = ax[1,0].imshow( np.log(np.abs(heatmap_val )), cmap='viridis', origin='lower', extent=[min(psi_values), max(psi_values), min(phi_values), max(phi_values)]) +log2 = ax[1,1].imshow( np.log(np.abs(heatmap_pyscf )), cmap='viridis', origin='lower', extent=[min(psi_values), max(psi_values), min(phi_values), max(phi_values)]) +difflog = ax[1,2].imshow( np.log(np.abs((heatmap_val - heatmap_pyscf))), cmap='viridis', origin='lower', extent=[min(psi_values), max(psi_values), min(phi_values), max(phi_values)]) + +for i in range(3): + for j in range(2): + ax[j, i].set_xticks(np.arange(phi_values[0], phi_values[-1], 45)) + ax[j, i].set_yticks(np.arange(psi_values[0], psi_values[-1], 45)) + # ax[j, i].set_xlim([phi_values[0], phi_values[-1]]) + # ax[j, i].set_ylim([psi_values[0], psi_values[-1]]) + ax[j, i].set_xlabel("phi [deg]") + ax[j, i].set_ylabel("psi [deg]") + +# orient = 'vertical' +orient = 'horizontal' +cbar = fig.colorbar(im, ax=ax[0, 0], orientation=orient, fraction=0.05, pad=0.28) +cbar = fig.colorbar(im2, ax=ax[0, 1], orientation=orient, fraction=0.05, pad=0.28) +cbar = fig.colorbar(diff, ax=ax[0, 2], orientation=orient, fraction=0.05, pad=0.28) +cbar = fig.colorbar(log, ax=ax[1, 0], orientation=orient, fraction=0.05, pad=0.28) +cbar = fig.colorbar(log2, ax=ax[1, 1], orientation=orient, fraction=0.05, pad=0.28) +cbar = fig.colorbar(difflog, ax=ax[1, 2], orientation=orient, fraction=0.05, pad=0.28) + +# for a in ax.reshape(-1): a.axis("off") +ax[0,0].set_title("NN Energy") +ax[0,1].set_title("PySCF Energy") +ax[0,2].set_title("|NN-PySCF| Energy") + +ax[1,0].set_title("NN log(|Energy|)") +ax[1,1].set_title("PySCF log(|Energy|)") +ax[1,2].set_title("|NN-PySCF| log(|Energy|)") +# ax[0,0].set_ylabel("Energy") # may fail with axis("off") +# ax[1,0].set_ylabel("log(|Energy|)") # may fail with axis("off") +plt.tight_layout() + +# Save the plot to a PNG file +plt.savefig("poc.png") \ No newline at end of file diff --git a/pyscf_ipu/direct/inference_heatmap_plot.py b/pyscf_ipu/direct/inference_heatmap_plot.py new file mode 100644 index 00000000..cd86d5b5 --- /dev/null +++ b/pyscf_ipu/direct/inference_heatmap_plot.py @@ -0,0 +1,303 @@ +import pickle +import jax +jax.config.update('jax_enable_x64', True) +import jax.numpy as jnp +import numpy as np + +HARTREE_TO_EV, EPSILON_B3LYP, HYB_B3LYP = 27.2114079527, 1e-20, 0.2 + +import argparse +parser = argparse.ArgumentParser() +parser.add_argument('-basis', type=str, default="sto3g") +parser.add_argument('-level', type=int, default=0) + +# GD options +parser.add_argument('-backend', type=str, default="cpu") +parser.add_argument('-lr', type=float, default=2.5e-4) +parser.add_argument('-steps', type=int, default=100000) +parser.add_argument('-bs', type=int, default=8) +parser.add_argument('-val_bs', type=int, default=8) +parser.add_argument('-mol_repeats', type=int, default=16) # How many time to optimize wrt each molecule. + +# energy computation speedups +parser.add_argument('-foriloop', action="store_true") # whether to use jax.lax.foriloop for sparse_symmetric_eri (faster compile time but slower training. ) +parser.add_argument('-xc_f32', action="store_true") +parser.add_argument('-eri_f32', action="store_true") +parser.add_argument('-eri_bs', type=int, default=8) + +parser.add_argument('-normal', action="store_true") +parser.add_argument('-wandb', action="store_true") +parser.add_argument('-prof', action="store_true") +parser.add_argument('-visualize', action="store_true") +parser.add_argument('-skip', action="store_true", help="skip pyscf test case") + +# dataset +parser.add_argument('-qm9', action="store_true") +parser.add_argument('-benzene', action="store_true") +parser.add_argument('-hydrogens', action="store_true") +parser.add_argument('-water', action="store_true") +parser.add_argument('-waters', action="store_true") +parser.add_argument('-alanine', action="store_true") +parser.add_argument('-states', type=int, default=1) +parser.add_argument('-workers', type=int, default=5) +parser.add_argument('-precompute', action="store_true") # precompute labels; only run once for data{set/augmentation}. + # do noise schedule, start small slowly increase +parser.add_argument('-wiggle_var', type=float, default=0.05, help="wiggle N(0, wiggle_var), bondlength=1.5/30") +parser.add_argument('-eri_threshold', type=float, default=1e-10, help="loss function threshold only") +parser.add_argument('-rotate_deg', type=float, default=90, help="how many degrees to rotate") + +# models +parser.add_argument('-nn', action="store_true", help="train nn, defaults to GD") +parser.add_argument('-tiny', action="store_true") +parser.add_argument('-small', action="store_true") +parser.add_argument('-base', action="store_true") +parser.add_argument('-medium', action="store_true") +parser.add_argument('-large', action="store_true") +parser.add_argument('-xlarge', action="store_true") + +parser.add_argument("-checkpoint", default=-1, type=int, help="which iteration to save model (default -1 = no saving)") # checkpoint model +parser.add_argument("-resume", default="", help="path to checkpoint pickle file") # checkpoint model + +# inference heatmap plot args +parser.add_argument("-heatmap_step", type=int, default=10) +parser.add_argument("-plot_range", type=int, default=360) +opts = parser.parse_args() + +assert opts.val_bs * opts.heatmap_step == opts.plot_range, "[Temporary dependency] Try adjusting VAL_BS and HEATMAP_STEP so that their product is equal to PLOT_RANGE (by default 360)" + +if opts.tiny or opts.small or opts.base or opts.large or opts.xlarge: opts.nn = True + +if opts.alanine: + mol_str = [[ # 22 atoms (12 hydrogens) => 10 heavy atoms (i.e. larger than QM9). + ["H", ( 2.000 , 1.000, -0.000)], + ["C", ( 2.000 , 2.090, 0.000)], + ["H", ( 1.486 , 2.454, 0.890)], + ["H", ( 1.486 , 2.454, -0.890)], + ["C", ( 3.427 , 2.641, -0.000)], + ["O", ( 4.391 , 1.877, -0.000)], + ["N", ( 3.555 , 3.970, -0.000)], + ["H", ( 2.733 , 4.556, -0.000)], + ["C", ( 4.853 , 4.614, -0.000)], # carbon alpha + ["H", ( 5.408 , 4.316, 0.890)], # hydrogne attached to carbon alpha + ["C", ( 5.661 , 4.221, -1.232)], # carbon beta + ["H", ( 5.123 , 4.521, -2.131)], # hydrogens attached to carbon beta + ["H", ( 6.630 , 4.719, -1.206)], # hydrogens attached to carbon beta + ["H", ( 5.809 , 3.141, -1.241)], # hydrogens attached to carbon beta + ["C", ( 4.713 , 6.129, 0.000)], + ["O", ( 3.601 , 6.653, 0.000)], + ["N", ( 5.846 , 6.835, 0.000)], + ["H", ( 6.737 , 6.359, -0.000)], + ["C", ( 5.846 , 8.284, 0.000)], + ["H", ( 4.819 , 8.648, 0.000)], + ["H", ( 6.360 , 8.648, 0.890)], + ["H", ( 6.360 , 8.648, -0.890)], + ]] + +B, BxNxN, BxNxK = None, None, None +cfg = None +from train import dm_energy + +from transformer import transformer_init +from train import nao +# global cfg +'''Model ViT model embedding #heads #layers #params training throughput +dimension resolution (im/sec) +DeiT-Ti N/A 192 3 12 5M 224 2536 +DeiT-S N/A 384 6 12 22M 224 940 +DeiT-B ViT-B 768 12 12 86M 224 292 +Parameters Layers dmodel +117M 12 768 +345M 24 1024 +762M 36 1280 +1542M 48 1600 +''' +if opts.tiny: # 5M + d_model= 192 + n_heads = 6 + n_layers = 12 +if opts.small: + d_model= 384 + n_heads = 6 + n_layers = 12 +if opts.base: + d_model= 768 + n_heads = 12 + n_layers = 12 +if opts.medium: + d_model= 1024 + n_heads = 16 + n_layers = 24 +if opts.large: + d_model= 1280 + n_heads = 16 + n_layers = 36 +if opts.xlarge: + d_model= 1600 + n_heads = 25 + n_layers = 48 + +if opts.nn: + rnd_key = jax.random.PRNGKey(42) + n_vocab = nao("C", opts.basis) + nao("N", opts.basis) + \ + nao("O", opts.basis) + nao("F", opts.basis) + \ + nao("H", opts.basis) + rnd_key, cfg, params, total_params = transformer_init( + rnd_key, + n_vocab, + d_model =d_model, + n_layers=n_layers, + n_heads =n_heads, + d_ff =d_model*4, + ) + +# vandg = jax.jit(jax.value_and_grad(dm_energy, has_aux=True), backend=opts.backend, static_argnames=("normal", 'nn')) +valf = jax.jit(dm_energy, backend=opts.backend, static_argnames=("normal", 'nn', "cfg", "opts")) + +from train import batched_state +from torch.utils.data import DataLoader, Dataset +class OnTheFlyQM9(Dataset): + # prepares dft tensors with pyscf "on the fly". + # dataloader is very keen on throwing segfaults (e.g. using jnp in dataloader throws segfaul). + # problem: second epoch always gives segfault. + # hacky fix; make __len__ = real_length*num_epochs and __getitem__ do idx%real_num_examples + def __init__(self, opts, nao=294, train=True, num_epochs=10**9, extrapolate=False, init_phi_psi = None): + # only take molecules with use {CNOFH}, nao=nao and spin=0. + import pandas as pd + df = pd.read_pickle("alchemy/processed_atom_9.pickle") # spin=0 and only CNOFH molecules + if nao != -1: df = df[df["nao"]==nao] + # df.sample is not deterministic; moved to pre-processing, so file is shuffled already. + # this shuffling is important, because it makes the last 10 samples iid (used for validation) + #df = df.sample(frac=1).reset_index(drop=True) # is this deterministic? + + if train: self.mol_strs = df["pyscf"].values[:-10] + else: self.mol_strs = df["pyscf"].values[-10:] + #print(df["pyscf"].) # todo: print smile strings + + self.num_epochs = num_epochs + self.opts = opts + self.validation = not train + self.extrapolate = extrapolate + self.init_phi_psi = init_phi_psi + + # self.benzene = [ + # ["C", ( 0.0000, 0.0000, 0.0000)], + # ["C", ( 1.4000, 0.0000, 0.0000)], + # ["C", ( 2.1000, 1.2124, 0.0000)], + # ["C", ( 1.4000, 2.4249, 0.0000)], + # ["C", ( 0.0000, 2.4249, 0.0000)], + # ["C", (-0.7000, 1.2124, 0.0000)], + # ["H", (-0.5500, -0.9526, 0.0000)], + # ["H", (-0.5500, 3.3775, 0.0000)], + # ["H", ( 1.9500, -0.9526, 0.0000)], + # ["H", (-1.8000, 1.2124, 0.0000)], + # ["H", ( 3.2000, 1.2124, 0.0000)], + # ["H", ( 1.9500, 3.3775, 0.0000)] + # ] + # self.waters = [ + # ["O", (-1.464, 0.099, 0.300)], + # ["H", (-1.956, 0.624, -0.340)], + # ["H", (-1.797, -0.799, 0.206)], + # ["O", ( 1.369, 0.146, -0.395)], + # ["H", ( 1.894, 0.486, 0.335)], + # ["H", ( 0.451, 0.165, -0.083)] + # ] + + # if opts.benzene: self.mol_strs = [self.benzene] + # if opts.waters: self.mol_strs = [self.waters] + if opts.alanine: self.mol_strs = mol_str + + if train: self.bs = opts.bs + else: self.bs = opts.val_bs + + def __len__(self): + return len(self.mol_strs)*self.num_epochs + + def __getitem__(self, idx): + return batched_state(self.mol_strs[idx%len(self.mol_strs)], self.opts, self.bs, \ + wiggle_num=0, do_pyscf=self.validation or self.extrapolate, validation=False, \ + extrapolate=self.extrapolate, mol_idx=idx, init_phi_psi = self.init_phi_psi, inference=True, inference_psi_step=opts.heatmap_step) + + +print("loading checkpoint") +weights = pickle.load(open("%s_model.pickle"%opts.resume, "rb")) +print("done loading. ") + +# print("loading adam state") +# adam_state = pickle.load(open("%s_adam_state.pickle"%opts.resume, "rb")) +# print("done") + +# weights, adam_state = jax.device_put(weights), jax.device_put(adam_state) +weights = jax.device_put(weights) + +from train import HashableNamespace + +# make `opts` hashable so that JAX will not complain about the static parameter that is passed as arg +opts = HashableNamespace(opts) + +data = [] +pyscf = [] +# data.append((1,1,344)) +# data.append((2,4,323)) +# data.append((3,3,334)) +# data.append((4,2,331)) + +for phi in range(0, opts.plot_range, opts.heatmap_step): + for psi in range(0, opts.plot_range, opts.val_bs * opts.heatmap_step): + val_qm9 = OnTheFlyQM9(opts, train=False, init_phi_psi=(phi, psi)) + val_state = jax.device_put(val_qm9[0]) + # print("\n^^^^^^^^^^^\nJUST VAL QM9 [0]:", val_qm9[0]) + # print("WHOLE VAL QM9:", val_qm9) + print("VAL_QM9[0].pyscf_E:", val_qm9[0].pyscf_E) + _, (valid_vals, _, vdensity_matrix, vW) = valf(weights, val_state, opts.normal, opts.nn, cfg, opts) + + valid_l = np.abs(valid_vals*HARTREE_TO_EV-val_state.pyscf_E) + valid_E = np.abs(valid_vals*HARTREE_TO_EV) + + print("valid_l: ", valid_l, "\nvalid_E: ", valid_E, "\nphi ", phi, " psi ", psi) + + for i in range(0, opts.val_bs): + data.append((phi, psi + i * opts.heatmap_step, valid_E[i])) + pyscf.append((phi, psi + i * opts.heatmap_step, val_state.pyscf_E[i].item())) + + # data.append((phi, psi, valid_E[0])) + +#data = np.log(np.abs(data)) +import matplotlib.pyplot as plt +from scipy.interpolate import griddata +# Extract phi, psi, and values from the data +phi_values, psi_values, heatmap_values = zip(*data) + +# Define a grid +phi_grid, psi_grid = np.meshgrid(np.linspace(min(phi_values), max(phi_values), 100), + np.linspace(min(psi_values), max(psi_values), 100)) +# Interpolate values on the grid +heatmap_interpolated = griddata((phi_values, psi_values), heatmap_values, (phi_grid, psi_grid), method='cubic', fill_value=0) + + +# Create a filled contour plot +plt.contourf(psi_grid, phi_grid, heatmap_interpolated, cmap='viridis', levels=100) +plt.colorbar(label='Intensity') + +# Set axis labels and title +plt.xlabel('Psi Angle') +plt.ylabel('Phi Angle') +plt.title('2D Heatmap with Interpolation') + +# Save the plot to a PNG file +plt.savefig('heatmap_plot.png') + +# Show the plot +plt.show() + +import pickle + +print("DATA ML", data) +print("DATA PYSCF", pyscf) +# Save data to a pickle file +with open('heatmap_data.pkl', 'wb') as file: + pickle.dump(data, file) + + +# Save pyscf to a pickle file +with open('heatmap_pyscf.pkl', 'wb') as file: + pickle.dump(pyscf, file) \ No newline at end of file diff --git a/pyscf_ipu/direct/inference_heatmap_plot_small_bs.py b/pyscf_ipu/direct/inference_heatmap_plot_small_bs.py new file mode 100644 index 00000000..6bcbfa9d --- /dev/null +++ b/pyscf_ipu/direct/inference_heatmap_plot_small_bs.py @@ -0,0 +1,311 @@ +import pickle +import jax +jax.config.update('jax_enable_x64', True) +import jax.numpy as jnp +import numpy as np + +HARTREE_TO_EV, EPSILON_B3LYP, HYB_B3LYP = 27.2114079527, 1e-20, 0.2 + +import argparse +parser = argparse.ArgumentParser() +parser.add_argument('-basis', type=str, default="sto3g") +parser.add_argument('-level', type=int, default=0) + +# GD options +parser.add_argument('-backend', type=str, default="cpu") +parser.add_argument('-lr', type=float, default=2.5e-4) +parser.add_argument('-steps', type=int, default=100000) +parser.add_argument('-bs', type=int, default=8) +parser.add_argument('-val_bs', type=int, default=8) +parser.add_argument('-mol_repeats', type=int, default=16) # How many time to optimize wrt each molecule. + +# energy computation speedups +parser.add_argument('-foriloop', action="store_true") # whether to use jax.lax.foriloop for sparse_symmetric_eri (faster compile time but slower training. ) +parser.add_argument('-xc_f32', action="store_true") +parser.add_argument('-eri_f32', action="store_true") +parser.add_argument('-eri_bs', type=int, default=8) + +parser.add_argument('-normal', action="store_true") +parser.add_argument('-wandb', action="store_true") +parser.add_argument('-prof', action="store_true") +parser.add_argument('-visualize', action="store_true") +parser.add_argument('-skip', action="store_true", help="skip pyscf test case") + +# dataset +parser.add_argument('-qm9', action="store_true") +parser.add_argument('-benzene', action="store_true") +parser.add_argument('-hydrogens', action="store_true") +parser.add_argument('-water', action="store_true") +parser.add_argument('-waters', action="store_true") +parser.add_argument('-alanine', action="store_true") +parser.add_argument('-states', type=int, default=1) +parser.add_argument('-workers', type=int, default=5) +parser.add_argument('-precompute', action="store_true") # precompute labels; only run once for data{set/augmentation}. + # do noise schedule, start small slowly increase +parser.add_argument('-wiggle_var', type=float, default=0.05, help="wiggle N(0, wiggle_var), bondlength=1.5/30") +parser.add_argument('-eri_threshold', type=float, default=1e-10, help="loss function threshold only") +parser.add_argument('-rotate_deg', type=float, default=90, help="how many degrees to rotate") + +# models +parser.add_argument('-nn', action="store_true", help="train nn, defaults to GD") +parser.add_argument('-tiny', action="store_true") +parser.add_argument('-small', action="store_true") +parser.add_argument('-base', action="store_true") +parser.add_argument('-medium', action="store_true") +parser.add_argument('-large', action="store_true") +parser.add_argument('-xlarge', action="store_true") + +parser.add_argument("-checkpoint", default=-1, type=int, help="which iteration to save model (default -1 = no saving)") # checkpoint model +parser.add_argument("-resume", default="", help="path to checkpoint pickle file") # checkpoint model + +# inference heatmap plot args +parser.add_argument("-heatmap_step", type=int, default=10) +parser.add_argument("-plot_range", type=int, default=360) +opts = parser.parse_args() + +# assert opts.val_bs * opts.heatmap_step == opts.plot_range, "[Temporary dependency] Try adjusting VAL_BS and HEATMAP_STEP so that their product is equal to PLOT_RANGE (by default 360)" +assert (opts.plot_range % (opts.val_bs * opts.heatmap_step)) == 0, "batch * step will not fit within the range with integer number of subranges" +if opts.tiny or opts.small or opts.base or opts.large or opts.xlarge: opts.nn = True + +if opts.alanine: + mol_str = [[ # 22 atoms (12 hydrogens) => 10 heavy atoms (i.e. larger than QM9). + ["H", ( 2.000 , 1.000, -0.000)], + ["C", ( 2.000 , 2.090, 0.000)], + ["H", ( 1.486 , 2.454, 0.890)], + ["H", ( 1.486 , 2.454, -0.890)], + ["C", ( 3.427 , 2.641, -0.000)], + ["O", ( 4.391 , 1.877, -0.000)], + ["N", ( 3.555 , 3.970, -0.000)], + ["H", ( 2.733 , 4.556, -0.000)], + ["C", ( 4.853 , 4.614, -0.000)], # carbon alpha + ["H", ( 5.408 , 4.316, 0.890)], # hydrogne attached to carbon alpha + ["C", ( 5.661 , 4.221, -1.232)], # carbon beta + ["H", ( 5.123 , 4.521, -2.131)], # hydrogens attached to carbon beta + ["H", ( 6.630 , 4.719, -1.206)], # hydrogens attached to carbon beta + ["H", ( 5.809 , 3.141, -1.241)], # hydrogens attached to carbon beta + ["C", ( 4.713 , 6.129, 0.000)], + ["O", ( 3.601 , 6.653, 0.000)], + ["N", ( 5.846 , 6.835, 0.000)], + ["H", ( 6.737 , 6.359, -0.000)], + ["C", ( 5.846 , 8.284, 0.000)], + ["H", ( 4.819 , 8.648, 0.000)], + ["H", ( 6.360 , 8.648, 0.890)], + ["H", ( 6.360 , 8.648, -0.890)], + ]] + +B, BxNxN, BxNxK = None, None, None +cfg = None +from train import dm_energy + +from transformer import transformer_init +from train import nao +# global cfg +'''Model ViT model embedding #heads #layers #params training throughput +dimension resolution (im/sec) +DeiT-Ti N/A 192 3 12 5M 224 2536 +DeiT-S N/A 384 6 12 22M 224 940 +DeiT-B ViT-B 768 12 12 86M 224 292 +Parameters Layers dmodel +117M 12 768 +345M 24 1024 +762M 36 1280 +1542M 48 1600 +''' +if opts.tiny: # 5M + d_model= 192 + n_heads = 6 + n_layers = 12 +if opts.small: + d_model= 384 + n_heads = 6 + n_layers = 12 +if opts.base: + d_model= 768 + n_heads = 12 + n_layers = 12 +if opts.medium: + d_model= 1024 + n_heads = 16 + n_layers = 24 +if opts.large: + d_model= 1280 + n_heads = 16 + n_layers = 36 +if opts.xlarge: + d_model= 1600 + n_heads = 25 + n_layers = 48 + +if opts.nn: + rnd_key = jax.random.PRNGKey(42) + n_vocab = nao("C", opts.basis) + nao("N", opts.basis) + \ + nao("O", opts.basis) + nao("F", opts.basis) + \ + nao("H", opts.basis) + rnd_key, cfg, params, total_params = transformer_init( + rnd_key, + n_vocab, + d_model =d_model, + n_layers=n_layers, + n_heads =n_heads, + d_ff =d_model*4, + ) + +# vandg = jax.jit(jax.value_and_grad(dm_energy, has_aux=True), backend=opts.backend, static_argnames=("normal", 'nn')) +valf = jax.jit(dm_energy, backend=opts.backend, static_argnames=("normal", 'nn', "cfg", "opts")) + +from train import batched_state +from torch.utils.data import DataLoader, Dataset +class OnTheFlyQM9(Dataset): + # prepares dft tensors with pyscf "on the fly". + # dataloader is very keen on throwing segfaults (e.g. using jnp in dataloader throws segfaul). + # problem: second epoch always gives segfault. + # hacky fix; make __len__ = real_length*num_epochs and __getitem__ do idx%real_num_examples + def __init__(self, opts, nao=294, train=True, num_epochs=10**9, extrapolate=False, init_phi_psi = None): + # only take molecules with use {CNOFH}, nao=nao and spin=0. + import pandas as pd + df = pd.read_pickle("alchemy/processed_atom_9.pickle") # spin=0 and only CNOFH molecules + if nao != -1: df = df[df["nao"]==nao] + # df.sample is not deterministic; moved to pre-processing, so file is shuffled already. + # this shuffling is important, because it makes the last 10 samples iid (used for validation) + #df = df.sample(frac=1).reset_index(drop=True) # is this deterministic? + + if train: self.mol_strs = df["pyscf"].values[:-10] + else: self.mol_strs = df["pyscf"].values[-10:] + #print(df["pyscf"].) # todo: print smile strings + + self.num_epochs = num_epochs + self.opts = opts + self.validation = not train + self.extrapolate = extrapolate + self.init_phi_psi = init_phi_psi + + # self.benzene = [ + # ["C", ( 0.0000, 0.0000, 0.0000)], + # ["C", ( 1.4000, 0.0000, 0.0000)], + # ["C", ( 2.1000, 1.2124, 0.0000)], + # ["C", ( 1.4000, 2.4249, 0.0000)], + # ["C", ( 0.0000, 2.4249, 0.0000)], + # ["C", (-0.7000, 1.2124, 0.0000)], + # ["H", (-0.5500, -0.9526, 0.0000)], + # ["H", (-0.5500, 3.3775, 0.0000)], + # ["H", ( 1.9500, -0.9526, 0.0000)], + # ["H", (-1.8000, 1.2124, 0.0000)], + # ["H", ( 3.2000, 1.2124, 0.0000)], + # ["H", ( 1.9500, 3.3775, 0.0000)] + # ] + # self.waters = [ + # ["O", (-1.464, 0.099, 0.300)], + # ["H", (-1.956, 0.624, -0.340)], + # ["H", (-1.797, -0.799, 0.206)], + # ["O", ( 1.369, 0.146, -0.395)], + # ["H", ( 1.894, 0.486, 0.335)], + # ["H", ( 0.451, 0.165, -0.083)] + # ] + + # if opts.benzene: self.mol_strs = [self.benzene] + # if opts.waters: self.mol_strs = [self.waters] + if opts.alanine: self.mol_strs = mol_str + + if train: self.bs = opts.bs + else: self.bs = opts.val_bs + + def __len__(self): + return len(self.mol_strs)*self.num_epochs + + def __getitem__(self, idx): + return batched_state(self.mol_strs[idx%len(self.mol_strs)], self.opts, self.bs, \ + wiggle_num=0, do_pyscf=self.validation or self.extrapolate, validation=False, \ + extrapolate=self.extrapolate, mol_idx=idx, init_phi_psi = self.init_phi_psi, inference=True, inference_psi_step=opts.heatmap_step) + + +print("loading checkpoint") +weights = pickle.load(open("%s_model.pickle"%opts.resume, "rb")) +print("done loading. ") + +# print("loading adam state") +# adam_state = pickle.load(open("%s_adam_state.pickle"%opts.resume, "rb")) +# print("done") + +# weights, adam_state = jax.device_put(weights), jax.device_put(adam_state) +weights = jax.device_put(weights) + +from train import HashableNamespace + +# make `opts` hashable so that JAX will not complain about the static parameter that is passed as arg +opts = HashableNamespace(opts) + +data = [] +pyscf = [] +# data.append((1,1,344)) +# data.append((2,4,323)) +# data.append((3,3,334)) +# data.append((4,2,331)) + +valid_E = None +val_state = None +for phi in range(0, opts.plot_range, opts.heatmap_step): + # psi_start = 0 + # psi_end = psi_start + opts.val_bs * opts.heatmap_step + # while psi_end <= opts.plot_range: + # for psi in range(psi_start, psi_end, opts.heatmap_step): + for psi in range(0, opts.plot_range, opts.val_bs * opts.heatmap_step): + # print(psi, psi_start, psi_end, "<<<<<<<<<<<<<<<<<<") + val_qm9 = OnTheFlyQM9(opts, train=False, init_phi_psi=(phi, psi)) + val_state = jax.device_put(val_qm9[0]) + # print("\n^^^^^^^^^^^\nJUST VAL QM9 [0]:", val_qm9[0]) + # print("WHOLE VAL QM9:", val_qm9) + print("VAL_QM9[0].pyscf_E:", val_qm9[0].pyscf_E) + _, (valid_vals, _, vdensity_matrix, vW) = valf(weights, val_state, opts.normal, opts.nn, cfg, opts) + + valid_l = np.abs(valid_vals*HARTREE_TO_EV-val_state.pyscf_E) + valid_E = np.abs(valid_vals*HARTREE_TO_EV) + + print("valid_l: ", valid_l, "\nvalid_E: ", valid_E, "\nphi ", phi, " psi ", psi) + + for i in range(0, opts.val_bs): + data.append((phi, psi + i * opts.heatmap_step, valid_E[i])) + pyscf.append((phi, psi + i * opts.heatmap_step, val_state.pyscf_E[i].item())) + # psi_start = 0 + psi_end + # psi_end += opts.val_bs * opts.heatmap_step + # data.append((phi, psi, valid_E[0])) + +#data = np.log(np.abs(data)) +import matplotlib.pyplot as plt +from scipy.interpolate import griddata +# Extract phi, psi, and values from the data +phi_values, psi_values, heatmap_values = zip(*data) + +# Define a grid +phi_grid, psi_grid = np.meshgrid(np.linspace(min(phi_values), max(phi_values), 100), + np.linspace(min(psi_values), max(psi_values), 100)) +# Interpolate values on the grid +heatmap_interpolated = griddata((phi_values, psi_values), heatmap_values, (phi_grid, psi_grid), method='cubic', fill_value=0) + + +# Create a filled contour plot +plt.contourf(psi_grid, phi_grid, heatmap_interpolated, cmap='viridis', levels=100) +plt.colorbar(label='Intensity') + +# Set axis labels and title +plt.xlabel('Psi Angle') +plt.ylabel('Phi Angle') +plt.title('2D Heatmap with Interpolation') + +# Save the plot to a PNG file +plt.savefig('heatmap_plot.png') + +# Show the plot +plt.show() + +import pickle + +print("DATA ML", data) +print("DATA PYSCF", pyscf) +# Save data to a pickle file +with open('heatmap_data_bs2.pkl', 'wb') as file: + pickle.dump(data, file) + + +# Save pyscf to a pickle file +with open('heatmap_pyscf_bs2.pkl', 'wb') as file: + pickle.dump(pyscf, file) \ No newline at end of file diff --git a/pyscf_ipu/direct/plot_heatmap_for_paper.py b/pyscf_ipu/direct/plot_heatmap_for_paper.py new file mode 100644 index 00000000..58fab939 --- /dev/null +++ b/pyscf_ipu/direct/plot_heatmap_for_paper.py @@ -0,0 +1,44 @@ +import pickle +import numpy as np +import matplotlib.pyplot as plt +from scipy.interpolate import griddata + +import argparse +parser = argparse.ArgumentParser() +parser.add_argument('-data_file', type=str) +parser.add_argument('-output_name', type=str, default="default_output.png") +parser.add_argument('-log', type=bool, default=False) +opts = parser.parse_args() + +# Load data from the pickle file +with open(opts.data_file, 'rb') as file: + data_list = pickle.load(file) + + +# Extract phi, psi, and values from the loaded data +phi_values, psi_values, heatmap_values = zip(*data_list) + +if opts.log: + heatmap_values = np.log(np.abs(heatmap_values - np.mean(heatmap_values))) + +print(heatmap_values) +# Create a meshgrid of phi and psi coordinates +phi_coordinates, psi_coordinates = np.meshgrid(np.linspace(min(phi_values), max(phi_values), 100), + np.linspace(min(psi_values), max(psi_values), 100)) + +# Interpolate values on the grid +heatmap_interpolated = griddata((phi_values, psi_values), heatmap_values, (phi_coordinates, psi_coordinates), method='cubic', fill_value=0) + +# Display the 2D matrix as an image +plt.imshow(heatmap_interpolated, cmap='viridis', origin='lower', extent=[min(psi_values), max(psi_values), min(phi_values), max(phi_values)]) +plt.colorbar(label='Intensity') # Add colorbar with label + +# Set axis labels and title +plt.xlabel('Psi Angle') +plt.ylabel('Phi Angle') +plt.title('2D Heatmap from Pickle File') + +# Save the plot to a PNG file +plt.savefig(opts.output_name) + +# Show the plot diff --git a/pyscf_ipu/direct/train.py b/pyscf_ipu/direct/train.py index 18d1c5b8..fee2dbcf 100644 --- a/pyscf_ipu/direct/train.py +++ b/pyscf_ipu/direct/train.py @@ -1,9 +1,8 @@ import os -os.environ['OMP_NUM_THREADS'] = '16' +os.environ['OMP_NUM_THREADS'] = '8' import jax jax.config.update('jax_enable_x64', True) import jax.numpy as jnp -import scipy import numpy as np import pyscf import optax @@ -16,10 +15,7 @@ import math from functools import partial import pickle -import random -random.seed(42) -MD17_WATER, MD17_ETHANOL, MD17_ALDEHYDE, MD17_URACIL = 1, 2, 3, 4 cfg, HARTREE_TO_EV, EPSILON_B3LYP, HYB_B3LYP = None, 27.2114079527, 1e-20, 0.2 def T(x): return jnp.transpose(x, (0,2,1)) @@ -27,24 +23,20 @@ def T(x): return jnp.transpose(x, (0,2,1)) B, BxNxN, BxNxK = None, None, None # Only need to recompute: L_inv, grid_AO, grid_weights, H_core, ERI and E_nuc. -def dm_energy(W: BxNxK, state, normal, nn, cfg=None, opts=None): +def dm_energy(W: BxNxK, state, normal, nn, cfg, opts):#): if nn: - W = jax.vmap(transformer, in_axes=(None, None, 0, 0, 0, 0), out_axes=(0))(cfg, \ - W, state.ao_types, state.pos.astype(jnp.float32), state.H_core.astype(jnp.float32), state.L_inv.astype(jnp.float32)) - #W, state.ao_types, state.pos.astype(jnp.float64), state.H_core.astype(jnp.float64), state.L_inv.astype(jnp.float64)) + W = jax.vmap(transformer, in_axes=(None, None, 0, 0, 0), out_axes=(0))(cfg, \ + W, state.ao_types, state.pos.astype(jnp.float32), state.H_core.astype(jnp.float32)) W = W.astype(jnp.float64) # we can interpret state.H_core + W as hamiltonian, and predict hlgap from these! - H = state.H_core + W - L_inv_Q: BxNxN = state.L_inv_T @ jnp.linalg.eigh(state.L_inv @ H @ state.L_inv_T)[1] # O(B*N*K^2) FLOP O(B*N*K) FLOP/FLIO + L_inv_Q: BxNxN = state.L_inv_T @ jnp.linalg.eigh(state.L_inv @ (state.H_core + W) @ state.L_inv_T)[1] # O(B*N*K^2) FLOP O(B*N*K) FLOP/FLIO density_matrix: BxNxN = 2 * (L_inv_Q*state.mask) @ T(L_inv_Q) # O(B*N*K^2) FLOP/FLIO E_xc: B = exchange_correlation(density_matrix, state, normal, opts.xc_f32) # O(B*gsize*N^2) FLOP O(gsize*N^2) FLIO - diff_JK: BxNxN = JK(density_matrix, state, normal, opts.foriloop, opts.eri_f32, opts.bs) # O(B*num_ERIs) FLOP O(num_ERIs) FLIO + diff_JK: BxNxN = JK(density_matrix, state, normal, opts.foriloop, opts.eri_f32) # O(B*num_ERIs) FLOP O(num_ERIs) FLIO energies: B = E_xc + state.E_nuc + jnp.sum((density_matrix * (state.H_core + diff_JK/2)).reshape(W.shape[0], -1), axis=-1) energy: float = jnp.sum(energies) - return energy, (energies, E_xc, density_matrix, W, H) - - + return energy, (energies, E_xc, density_matrix, W) def sparse_mult(values, dm, state, gsize): in_ = dm.take(state.cols, axis=0) @@ -62,7 +54,6 @@ def exchange_correlation(density_matrix: BxNxN, state, normal, xc_f32): if False: main: BxGsizexN = state.main_grid_AO @ density_matrix # (1, gsize, N) @ (B, N, N) = O(B gsize N^2) FLOPs and O(gsize*N + N^2 +B * gsize * N) FLIOs correction: BxGsizexN = jax.vmap(sparse_mult, in_axes=(0,0,None, None))(state.sparse_diffs_grid_AO, density_matrix, state, gsize) - # todo: remove state.grid_AO w/ sparsity tricks => reduce memory 10x. rho_a = jnp.einsum("bpij,bqij->bpi", state.grid_AO, main.reshape(B,1,gsize,N)) rho_b = jnp.einsum("bpij,bqij->bpi", state.grid_AO, correction.reshape(B,1,gsize,N)) rho = rho_a - rho_b @@ -76,7 +67,7 @@ def exchange_correlation(density_matrix: BxNxN, state, normal, xc_f32): E_xc = jnp.sum(rho[:, 0] * state.grid_weights * E_xc, axis=-1).reshape(B) return E_xc -def JK(density_matrix, state, normal, jax_foriloop, eri_f32, bs): +def JK(density_matrix, state, normal, jax_foriloop, eri_f32): if normal: J = jnp.einsum('bijkl,bji->bkl', state.ERI, density_matrix) K = jnp.einsum('bijkl,bjk->bil', state.ERI, density_matrix) @@ -86,43 +77,32 @@ def JK(density_matrix, state, normal, jax_foriloop, eri_f32, bs): if eri_f32: density_matrix = density_matrix.astype(jnp.float32) - if bs == 1: - diff_JK: BxNxN = jax.vmap(sparse_symmetric_einsum, in_axes=(None, None, 0, None))( - state.nonzero_distinct_ERI[0], - state.nonzero_indices[0], - density_matrix, - jax_foriloop - ) - - - else: - '''diff_JK: BxNxN = jax.vmap(sparse_symmetric_einsum, in_axes=(None, None, 0, None))( - state.nonzero_distinct_ERI[0], - state.nonzero_indices[0], - density_matrix, - jax_foriloop - ) - diff_JK: BxNxN = diff_JK - jax.vmap(sparse_symmetric_einsum, in_axes=(0, None, 0, None))( - state.diffs_ERI, - state.indxs, - density_matrix, - jax_foriloop - )''' - - diff_JK: BxNxN = jax.vmap(sparse_einsum, in_axes=(None, None, 0, None))( - state.nonzero_distinct_ERI[0], - state.precomputed_nonzero_indices, - density_matrix, - jax_foriloop - ) - if bs > 1: - correction = jax.vmap(sparse_einsum, in_axes=(0, None, 0, None))( - state.diffs_ERI, - state.precomputed_indxs, - density_matrix, - jax_foriloop - ) - diff_JK: BxNxN = diff_JK - correction + '''diff_JK: BxNxN = jax.vmap(sparse_symmetric_einsum, in_axes=(None, None, 0, None))( + state.nonzero_distinct_ERI[0], + state.nonzero_indices[0], + density_matrix, + jax_foriloop + ) + diff_JK: BxNxN = diff_JK - jax.vmap(sparse_symmetric_einsum, in_axes=(0, None, 0, None))( + state.diffs_ERI, + state.indxs, + density_matrix, + jax_foriloop + )''' + + diff_JK: BxNxN = jax.vmap(sparse_einsum, in_axes=(None, None, 0, None))( + state.nonzero_distinct_ERI[0], + state.precomputed_nonzero_indices, + density_matrix, + jax_foriloop + ) + correction = jax.vmap(sparse_einsum, in_axes=(0, None, 0, None))( + state.diffs_ERI, + state.precomputed_indxs, + density_matrix, + jax_foriloop + ) + diff_JK: BxNxN = diff_JK - correction return diff_JK.astype(jnp.float64) @@ -141,10 +121,10 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, pad_nonzero_distinct_ERI=200000, pad_sparse_diff_grid=200000, mol_idx=42, + init_phi_psi=None, + inference=False, + inference_psi_step=5, # degrees ): - start_time = time.time() - do_print = opts.do_print - if do_print: print("\t[%.4fs] start of 'batched_state'. "%(time.time()-start_time)) # pad molecule if using nn. if not opts.nn: pad_electrons, pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = \ @@ -160,19 +140,18 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, pad_nonzero_distinct_ERI = 20000 pad_sparse_diff_grid = 20000 - if opts.qm9 or opts.qh9: + if opts.qm9: pad_electrons=60 - padding_estimate = [48330, 163222, 17034, 159361, 139505] - if opts.nperturb == 2 or opts.nperturb == 1: padding_estimate = [int(a*2.1) for a in padding_estimate] - else: padding_estimate = [int(a*1.1) for a in padding_estimate] - + '''pad_diff_ERIs=120000 + pad_distinct_ERIs=400000 + pad_grid_AO=50000 + pad_nonzero_distinct_ERI=400000 + pad_sparse_diff_grid=400000''' + #padding_estimate = [37426, 149710, 17010, 140122, 138369] + padding_estimate = [48330, 163222, 17034, 159361, 139505] + padding_estimate = [int(a*1.1) for a in padding_estimate] pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = padding_estimate - if opts.basis == "def2-svp": - # disable padding temporarily - max_pad_electrons, max_pad_diff_ERIs, max_pad_distinct_ERIs, max_pad_grid_AO, max_pad_nonzero_distinct_ERI, max_pad_sparse_diff_grid = \ - -1, -1, -1, -1, -1, -1 - if opts.alanine: # todo: (adam) the ERI padding may change when rotating molecule more! pad_electrons = 70 @@ -191,39 +170,7 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, if opts.waters: pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = [a//3 for a in [ pad_diff_ERIs , pad_distinct_ERIs , pad_grid_AO , pad_nonzero_distinct_ERI , pad_sparse_diff_grid ]] - if opts.md17 > 0: - if opts.md17 == MD17_WATER: - if opts.level == 1: padding_estimate = [ 3361, 5024, 10172 , 5024 , 155958] - if opts.level == 3: padding_estimate = [ 3361, 5024, 34310 , 5024 , 494370] - if opts.bs == 2 and opts.wiggle_var == 0: padding_estimate = [ 1, 5024, 10172 , 5024 , 1] - padding_estimate = [int(a*1.5) for a in padding_estimate] - pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = padding_estimate - - elif opts.md17 == MD17_ETHANOL: - pad_electrons = 72 - #padding_estimate = [ 99042, 197660, 34310*5 , 197308 , 494370*5] - #padding_estimate = [113522, 224275, 30348, 197308, 609163] - if opts.level == 1: padding_estimate = [ 1, 415186, 30348, 415145, 1] - padding_estimate = [int(a*1.1) for a in padding_estimate] - pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = padding_estimate - - elif opts.md17 == MD17_ALDEHYDE: - pad_electrons = 90 - #padding_estimate = [130646, 224386 , 30348, 223233, 626063] - #padding_estimate = [ 34074, 204235 , 30348, 203632, 285934] - if opts.level == 1: padding_estimate = [1, 939479, 35704, 939479, 1] - padding_estimate = [int(a*1.1) for a in padding_estimate] - pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = padding_estimate - - elif opts.md17 == MD17_URACIL: - #raise Exception("not ready yet") - pad_electrons = 132 - padding_estimate = [1, 3769764 , 51184, 3769764, 1] - padding_estimate = [int(a*1.05) for a in padding_estimate] - pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = padding_estimate - - - mol = build_mol(mol_str, opts.basis, unit="bohr") + mol = build_mol(mol_str, opts.basis) pad_electrons = min(pad_electrons, mol.nao_nr()) # Set seed to ensure different rotation; initially all workers did same rotation! @@ -233,53 +180,28 @@ def batched_state(mol_str, opts, bs, wiggle_num=0, water1_xyz = np.array([mol_str[i][1] for i in range(0,3)]) water2_xyz = np.array([mol_str[i][1] for i in range(3,6)]) - - if opts.md17 > 0: - natm = len(mol_str) - atom_num = random.sample(range(natm), 1)[0] - atoms = np.array([mol_str[i][1] for i in range(0,natm)]) - - if opts.qm9 or opts.qh9: - # pick random atom to perturb (of the first 9 heavy ones) - atom_num1, atom_num2, atom_num3 = random.sample(range(9), 3) - atoms = np.array([mol_str[i][1] for i in range(0,10)]) - atom_type = [mol_str[i][0] for i in range(0,10)] + if opts.qm9: + atoms = np.array([mol_str[i][1] for i in range(0,3)]) + # pick random atom to permute (of the first 9 heavy ones) + atom_num = int(np.random.uniform(0, 8)) if opts.alanine: # train on [-180, 180], validate [-180, 180] extrapolate [-360, 360]\[180, -180] # todo: draw picture (in training loop) - if extrapolate: + if extrapolate and not inference: phi, psi = [float(a) for a in np.random.uniform(180, 360, 2)] - else: + elif inference: + phi, psi = init_phi_psi + else: phi, psi = [float(a) for a in np.random.uniform(0, 180, 2)] angles = [] - # Combinatorics of atom substitutions (n_electrons has to be even). - # P := single atom modification, n_electrons remain even (C-> O, O->C, N->F, F->N) - # B1, B2 := single atom modification, n_electrons becomes odd (C->{N,F}, O->{N,F}, N->{O,C}, F->{O,C}) - I = {"C":"C", "O":"O", "F":"F", "N":"N"} - P = {"C": "O", "O":"C", "F":"N", "N":"F"} - B1 = {"C":"N", "O":"N", "F":"O", "N":"O"} - B2 = {"C":"F", "O":"F", "F":"C", "N":"C"} - - if opts.nperturb == 3: - allowed_pertubations=[ - (I,I,I), (I,I,P), (I,P,I), (I,P,P), (P,I,I), (P,I,P), (P,P,I), (P,P,P), - (I,B1,B2), (B1,I,B2), (B1,B2,I), (I,B2,B1), (B2,I,B1), (B2,B1,I), (P,B1,B2), (B1,P,B2), (B1,B2,P), (P,B2,B1), (B2,P,B1), (B2,B1,P), - (I,B1,B1), (B1,I,B1), (B1,B1,I), (I,B1,B1), (B1,I,B1), (B1,B1,I), (P,B1,B1), (B1,P,B1), (B1,B1,P), (P,B1,B1), (B1,P,B1), (B1,B1,P), - (I,B2,B2), (B2,I,B2), (B2,B2,I), (I,B2,B2), (B2,I,B2), (B2,B2,I), (P,B2,B2), (B2,P,B2), (B2,B2,P), (P,B2,B2), (B2,P,B2), (B2,B2,P), - ] - if opts.nperturb == 2: - allowed_pertubations=[ (I, I, I), (I, I, P), (I, P, I), (I, P, P), (I, B1, B1), (I, B1, B2), (I, B2, B1), (I, B2, B2), ] - if opts.nperturb == 1: - allowed_pertubations=[ (I, I, I), (I, I, P), ]*10 - - if opts.nperturb: random.shuffle(allowed_pertubations) - states = [] for iteration in range(bs): - if do_print: print("\t[%.4fs] initializing state %i. "%(time.time()-start_time, iteration)) + import copy + new_str = copy.deepcopy(mol_str) + if opts.alanine: from rdkit import Chem from rdkit.Chem import AllChem @@ -296,86 +218,94 @@ def get_atom_positions(mol): conf = mol.GetConformer() return np.concatenate([xyz(conf.GetAtomPosition(i)) for i in range(mol.GetNumAtoms())], axis=0) - str = [mol_str[j][0] for j in range(len(mol_str))] - pos = np.concatenate([np.array(mol_str[j][1]).reshape(1, 3) for j in range(len(mol_str))]) + str = [new_str[j][0] for j in range(len(new_str))] + pos = np.concatenate([np.array(new_str[j][1]).reshape(1, 3) for j in range(len(new_str))]) # todo: save=wandb.log({"pair": angle1, angle2, NN_energy ) (rotation, NN_energy) for train/val molecule (for val also save PySCF energy) # only saving angles (angle not paired up with energy) AllChem.SetDihedralDeg(molecule.GetConformer(), *phi_atoms, phi) - angle = psi + float(np.random.uniform(0, opts.rotate_deg, 1)) # perhaps add 45 and mod 360? + angle = psi + float(np.random.uniform(0, opts.rotate_deg, 1)) # perhaps add 45 and mod 360? # todo: check math whether val/extra/train have uniform distribution on their respective domains. - if extrapolate: # make sure angle is in [] + if extrapolate and not inference: # make sure angle is in [] angle = angle % 180 + 180 # angle should be in [180, 360] - else: + elif inference: + angle = psi + iteration * inference_psi_step # overwrite the angle when in inference mode with fixed, not randomized, step + angle = angle % 360 # angle should be [0, 360] for heatmap + else: # validation angle = angle % 180 # angle should be [0, 180] AllChem.SetDihedralDeg(molecule.GetConformer(), *psi_atoms, angle ) pos = get_atom_positions(molecule) angles.append((phi, angle)) - for j in range(len(mol_str)): mol_str[j][1] = tuple(pos[j]) + for j in range(len(new_str)): new_str[j][1] = tuple(pos[j]) '''if iteration == 0 and opts.wandb: from plot import create_rdkit_mol import wandb wandb.log({"mol_valid=%s"%validation: create_rdkit_mol(str, pos) })''' - if opts.md17 > 0: - if iteration > 0: - mol_str[atom_num][1] = tuple(atoms[atom_num] + np.random.normal(0,opts.wiggle_var, (3))) - if opts.waters: # todo: rotate both water molecules and draw x=phi, y=psi. rotation_matrix = np.linalg.qr(np.random.normal(size=(3,3)))[0] center = water2_xyz.mean(axis=0) water_xyz = np.dot(water2_xyz - center, rotation_matrix) + center - mol_str[3][1] = tuple(water_xyz[0]) - mol_str[4][1] = tuple(water_xyz[1]) - mol_str[5][1] = tuple(water_xyz[2]) + new_str[3][1] = tuple(water_xyz[0]) + new_str[4][1] = tuple(water_xyz[1]) + new_str[5][1] = tuple(water_xyz[2]) '''if opts.wandb and iteration == 0: from plot import create_rdkit_mol import wandb - str = [mol_str[j][0] for j in range(len(mol_str))] - pos = np.concatenate([np.array(mol_str[j][1]).reshape(1, 3) for j in range(len(mol_str))]) + str = [new_str[j][0] for j in range(len(new_str))] + pos = np.concatenate([np.array(new_str[j][1]).reshape(1, 3) for j in range(len(new_str))]) wandb.log({"%s_mol_%i"%({True: "valid", False: "train"}[validation], iteration): create_rdkit_mol(str, pos) })''' - elif opts.qm9 or opts.qh9: - if iteration == 0 and (validation or extrapolate): pass - else: - s = mol_str[atom_num1][0]+ mol_str[atom_num2][0]+ mol_str[atom_num3][0] - - # for small stuff, wiggle a single atom a bit. - if opts.nperturb <= 1: mol_str[atom_num3][1] = tuple(atoms[atom_num3] + np.random.normal(0, opts.wiggle_var, (3))) + elif opts.qm9: + # todo: find dihedral to rotate over similar to alanine dipeptide. + # broken; rotate first three atoms around their center of mass + # this breaks molecule; should use dihedral angle as done with the dipeptide. + #rotation_matrix = np.linalg.qr(np.random.normal(size=(3,3)))[0] + #center = atoms.mean(axis=0) + #rotated_atoms = np.dot(atoms - center, rotation_matrix) + center - # assuming all atoms are non-hydrogen; if wrong don't do any pertubations. - if opts.nperturb > 0 and mol_str[atom_num1][0] != "H" and mol_str[atom_num2][0] != "H" and mol_str[atom_num3][0] != "H": - A,B,C = allowed_pertubations[iteration] - mol_str[atom_num1][0] = A[atom_type[atom_num1]] - mol_str[atom_num2][0] = B[atom_type[atom_num2]] - mol_str[atom_num3][0] = C[atom_type[atom_num3]] + # for extrapolation, do even more. - s += "->" + mol_str[atom_num1][0]+ mol_str[atom_num2][0]+ mol_str[atom_num3][0] - if do_print: print(iteration, s) + if iteration == 0 and (validation or extrapolate): + pass + else: + #new_str[0][1] = tuple(atoms[0] + np.random.normal(0, opts.wiggle_var, (3))) + new_str[atom_num][1] = tuple(atoms[atom_num] + np.random.normal(0, opts.wiggle_var, (3))) + #new_str[1][1] = tuple(atoms[1] + np.random.normal(0, opts.wiggle_var, (3))) + #new_str[2][1] = tuple(atoms[2] + np.random.normal(0, opts.wiggle_var, (3))) - # -nperturb 2 is a lot slower; changing self.prune=False roughly doubles AO matrices. + '''if opts.wandb and iteration == 0: + from plot import create_rdkit_mol + import wandb + str = [new_str[j][0] for j in range(len(new_str))] + pos = np.concatenate([np.array(new_str[j][1]).reshape(1, 3) for j in range(len(new_str))]) + wandb.log({"%s_mol_%i"%({True: "valid", False: "train"}[validation], iteration): create_rdkit_mol(str, pos) })''' if iteration == 0: - state = init_dft(mol_str, opts, do_pyscf=do_pyscf, pad_electrons=pad_electrons) + state = init_dft(new_str, opts, do_pyscf=do_pyscf, pad_electrons=pad_electrons) c, w = state.grid_coords, state.grid_weights elif iteration <= 1 or not opts.prof: # when profiling create fake molecule to skip waiting - state = init_dft(mol_str, opts, c, w, do_pyscf=do_pyscf and iteration < 3, state=state, pad_electrons=pad_electrons) + state = init_dft(new_str, opts, c, w, do_pyscf=do_pyscf and iteration < 80, state=state, pad_electrons=pad_electrons) states.append(state) - + # If we add energy here we get plot basically! + # todo: save and store in training loop, then we can match with energy + # can't get to work in wandb, but can just use download api and the plot. + '''if opts.alanine and opts.wandb: + for phi, psi in angles: + if not validation: + wandb.log({"phi_train": phi , "psi_train": psi}) + else: + wandb.log({"phi_valid": phi, "psi_valid": psi})''' state = cats(states) N = state.N[0] - if do_print: print("\t[%.4fs] concatenated states. "%(time.time()-start_time)) - - - #_nonzero = None # Compute ERI sparsity. nonzero = [] @@ -385,23 +315,16 @@ def get_atom_positions(mol): e[indxs] = 0 nonzero.append(np.nonzero(e)[0]) - #_nonzero = indxs if _nonzero is None else np.logical_or(_nonzero, indxs) - - if do_print: print("\t[%.4fs] got sparsity. "%(time.time()-start_time)) - # Merge nonzero indices and prepare (ij, kl). # rep is the number of repetitions we include in the sparse representation. + #nonzero_indices = np.union1d(nonzero[0], nonzero[1]) union = nonzero[0] - for i in range(1, len(nonzero)): # this takes 12s/it for def2-svp. + for i in range(1, len(nonzero)): union = np.union1d(union, nonzero[i]) nonzero_indices = union - if do_print: print("\t[%.4fs] got union of sparsity. "%(time.time()-start_time)) - - from sparse_symmetric_ERI import get_i_j, num_repetitions_fast ij, kl = get_i_j(nonzero_indices) rep = num_repetitions_fast(ij, kl) - if do_print: print("\t[%.4fs] got (ij) and reps. "%(time.time()-start_time)) batches = opts.eri_bs es = [] @@ -415,10 +338,8 @@ def get_atom_positions(mol): state.nonzero_distinct_ERI = np.concatenate([np.expand_dims(a, axis=0) for a in es]) - if do_print: print("\t[%.4fs] padded ERI and nonzero_indices. . "%(time.time()-start_time)) i, j = get_i_j(ij.reshape(-1)) k, l = get_i_j(kl.reshape(-1)) - if do_print: print("\t[%.4fs] got ijkl. "%(time.time()-start_time)) if remainder != 0: i = np.pad(i, ((0,batches-remainder))) @@ -427,15 +348,12 @@ def get_atom_positions(mol): l = np.pad(l, ((0,batches-remainder))) nonzero_indices = np.vstack([i,j,k,l]).T.reshape(batches, -1, 4).astype(np.int32) # todo: use int16 or int32 here? state.nonzero_indices = nonzero_indices - if do_print: print("\t[%.4fs] padded and vstacked ijkl. "%(time.time()-start_time)) # batching (w/ same sparsity pattern across batch) allows precomputing all {ss,dm}_indices instead of computing in sparse_sym_eri every iteration. # function below does this. # todo: consider removing, didn't get expecting 3x (only 5%; not sure if additional memory/complication justifies). from sparse_symmetric_ERI import precompute_indices - - if opts.normal: diff_state = None else: main_grid_AO = state.grid_AO[:1] @@ -444,41 +362,39 @@ def get_atom_positions(mol): sparse_diffs_grid_AO = diffs_grid_AO[:, 0, rows,cols] # use the same sparsity pattern across a batch. - if opts.bs > 1: - diff_ERIs = state.nonzero_distinct_ERI[:1] - state.nonzero_distinct_ERI - diff_indxs = state.nonzero_indices.reshape(1, batches, -1, 4) - nzr = np.abs(diff_ERIs[1]).reshape(batches, -1) > 1e-10 + diff_ERIs = state.nonzero_distinct_ERI[:1] - state.nonzero_distinct_ERI + diff_indxs = state.nonzero_indices.reshape(1, batches, -1, 4) + nzr = np.abs(diff_ERIs[1]).reshape(batches, -1) > 1e-10 - diff_ERIs = diff_ERIs[:, nzr].reshape(bs, -1) - diff_indxs = diff_indxs[:, nzr].reshape(-1, 4) + diff_ERIs = diff_ERIs[:, nzr].reshape(bs, -1) + diff_indxs = diff_indxs[:, nzr].reshape(-1, 4) - remainder = np.sum(nzr) % batches - if remainder != 0: - diff_ERIs = np.pad(diff_ERIs, ((0,0),(0,batches-remainder))) - diff_indxs = np.pad(diff_indxs, ((0,batches-remainder),(0,0))) + remainder = np.sum(nzr) % batches + if remainder != 0: + diff_ERIs = np.pad(diff_ERIs, ((0,0),(0,batches-remainder))) + diff_indxs = np.pad(diff_indxs, ((0,batches-remainder),(0,0))) - diff_ERIs = diff_ERIs.reshape(bs, batches, -1) - diff_indxs = diff_indxs.reshape(batches, -1, 4) + diff_ERIs = diff_ERIs.reshape(bs, batches, -1) + diff_indxs = diff_indxs.reshape(batches, -1, 4) - if opts.bs > 1: precomputed_indxs = precompute_indices(diff_indxs, N).astype(np.int16) + precomputed_indxs = precompute_indices(diff_indxs, N).astype(np.int16) - if pad_diff_ERIs == -1: - state.indxs=diff_indxs - state.diffs_ERI=diff_ERIs - assert False, "deal with precomputed_indxs; only added in else branch below" - else: - max_pad_diff_ERIs = diff_ERIs.shape[2] - if do_print: print("\t[%.4fs] max_pad_diff_ERIs=%i"%(time.time()-start_time, max_pad_diff_ERIs)) - # pad ERIs with 0 and indices with -1 so they point to 0. - assert diff_indxs.shape[1] == diff_ERIs.shape[2] - pad = pad_diff_ERIs - diff_indxs.shape[1] - assert pad > 0, (pad_diff_ERIs, diff_indxs.shape[1]) - state.indxs = np.pad(diff_indxs, ((0,0), (0, pad), (0, 0)), 'constant', constant_values=(-1)) - state.diffs_ERI = np.pad(diff_ERIs, ((0,0), (0, 0), (0, pad))) # pad zeros - #print(diff_indxs.shape, precomputed_indxs.shape) - if opts.bs > 1: state.precomputed_indxs = np.pad(precomputed_indxs, ((0,0), (0,0),(0,0), (0, pad), (0,0)), 'constant', constant_values=(-1)) - - #if opts.wandb: wandb.log({"pad_diff_ERIs": pad/diff_ERIs.shape[2]}) + if pad_diff_ERIs == -1: + state.indxs=diff_indxs + state.diffs_ERI=diff_ERIs + assert False, "deal with precomputed_indxs; only added in else branch below" + else: + max_pad_diff_ERIs = diff_ERIs.shape[2] + # pad ERIs with 0 and indices with -1 so they point to 0. + assert diff_indxs.shape[1] == diff_ERIs.shape[2] + pad = pad_diff_ERIs - diff_indxs.shape[1] + assert pad > 0, (pad_diff_ERIs, diff_indxs.shape[1]) + state.indxs = np.pad(diff_indxs, ((0,0), (0, pad), (0, 0)), 'constant', constant_values=(-1)) + state.diffs_ERI = np.pad(diff_ERIs, ((0,0), (0, 0), (0, pad))) # pad zeros + #print(diff_indxs.shape, precomputed_indxs.shape) + state.precomputed_indxs = np.pad(precomputed_indxs, ((0,0), (0,0),(0,0), (0, pad), (0,0)), 'constant', constant_values=(-1)) + + #if opts.wandb: wandb.log({"pad_diff_ERIs": pad/diff_ERIs.shape[2]}) state.rows=rows state.cols=cols @@ -490,7 +406,6 @@ def get_atom_positions(mol): if pad_sparse_diff_grid != -1: max_pad_sparse_diff_grid = state.rows.shape[0] - if do_print: print("\t[%.4fs] max_pad_sparse_diff_grid=%i"%(time.time()-start_time, max_pad_sparse_diff_grid)) assert state.sparse_diffs_grid_AO.shape[1] == state.rows.shape[0] assert state.sparse_diffs_grid_AO.shape[1] == state.cols.shape[0] pad = pad_sparse_diff_grid - state.rows.shape[0] @@ -509,7 +424,6 @@ def get_atom_positions(mol): # todo: looks like we're padding, then looking for zeros, then padding; this can be simplified. if pad_distinct_ERIs != -1: max_pad_distinct_ERIs = state.nonzero_distinct_ERI.shape[2] - if do_print: print("\t[%.4fs] max_pad_distinct_ERIs=%i"%(time.time()-start_time, max_pad_diff_ERIs)) assert state.nonzero_distinct_ERI.shape[2] == state.nonzero_indices.shape[2] pad = pad_distinct_ERIs - state.nonzero_distinct_ERI.shape[2] assert pad > 0, (pad_distinct_ERIs, state.nonzero_distinct_ERI.shape[2]) @@ -520,7 +434,6 @@ def get_atom_positions(mol): if pad_grid_AO != -1: max_pad_grid_AO = state.grid_AO.shape[2] - if do_print: print("\t[%.4fs] max_pad_grid_AO=%i"%(time.time()-start_time, max_pad_grid_AO)) prev_size = state.grid_AO.shape[2] assert state.grid_AO.shape[2] == state.grid_weights.shape[1] @@ -555,12 +468,11 @@ def get_atom_positions(mol): state.nonzero_distinct_ERI = state.nonzero_distinct_ERI.reshape(1, batches, -1) state.nonzero_indices = state.nonzero_indices.reshape(1, batches, -1, 4) - if opts.bs > 1: precomputed_nonzero_indices = precompute_indices(state.nonzero_indices[0], N).astype(np.int16) + precomputed_nonzero_indices = precompute_indices(state.nonzero_indices[0], N).astype(np.int16) #print(state.nonzero_indices.shape, precomputed_nonzero_indices.shape) if pad_nonzero_distinct_ERI != -1: max_pad_nonzero_distinct_ERI = state.nonzero_distinct_ERI.shape[2] - if do_print: print("\t[%.4fs] max_pad_nonzero_distinct_ERI=%i"%(time.time()-start_time, max_pad_nonzero_distinct_ERI)) assert state.nonzero_distinct_ERI.shape[2] == state.nonzero_indices.shape[2] pad = pad_nonzero_distinct_ERI - state.nonzero_distinct_ERI.shape[2] @@ -568,7 +480,7 @@ def get_atom_positions(mol): state.nonzero_distinct_ERI = np.pad(state.nonzero_distinct_ERI, ((0,0),(0,0),(0,pad))) state.nonzero_indices = np.pad(state.nonzero_indices, ((0,0),(0,0),(0,pad), (0,0)), 'constant', constant_values=(-1)) - if opts.bs > 1: state.precomputed_nonzero_indices = np.pad(precomputed_nonzero_indices, ((0,0), (0,0), (0,0), (0, pad),(0,0)), 'constant', constant_values=(-1)) + state.precomputed_nonzero_indices = np.pad(precomputed_nonzero_indices, ((0,0), (0,0), (0,0), (0, pad),(0,0)), 'constant', constant_values=(-1)) #print(state.precomputed_nonzero_indices.shape, state.nonzero_indices.shape) #if opts.wandb: wandb.log({"pad_grid_AO": pad/state.grid_AO.shape[2]}) @@ -590,7 +502,6 @@ def get_atom_positions(mol): def nanoDFT(mol_str, opts): - start_time = time.time() print() # Initialize validation set. # This consists of DFT tensors initialized with PySCF/CPU. @@ -602,14 +513,10 @@ def nanoDFT(mol_str, opts): run = wandb.init(project='ndft_alanine') elif opts.qm9: run = wandb.init(project='ndft_qm9') - elif opts.md17 > 0: - run = wandb.init(project='md17') else: run = wandb.init(project='ndft') opts.name = run.name - wandb.log(vars(opts)) - else: opts.name = "%i"%time.time() @@ -646,17 +553,13 @@ def nanoDFT(mol_str, opts): d_model= 1024 n_heads = 16 n_layers = 24 - if opts.large: # this is 600M; - d_model= 1280 # 80*16 + if opts.large: + d_model= 1280 n_heads = 16 n_layers = 36 - if opts.largep: # interpolated between large and largep. - d_model= 91*16 # halway from 80 to 100 - n_heads = 16*1 - n_layers = 43 - if opts.xlarge: # this is 1.3B; decrease parameter count 30%. - d_model= 1600 # 100*16 - n_heads = 25 + if opts.xlarge: + d_model= 1600 + n_heads = 25 n_layers = 48 if opts.nn: @@ -668,7 +571,6 @@ def nanoDFT(mol_str, opts): n_heads =n_heads, d_ff =d_model*4, ) - print("[%.4fs] initialized transformer. "%(time.time()-start_time) ) params = params.to_float32() if opts.resume: @@ -679,35 +581,42 @@ def nanoDFT(mol_str, opts): if opts.nn: #https://arxiv.org/pdf/1706.03762.pdf see 5.3 optimizer - def custom_schedule(it, learning_rate=opts.lr, min_lr=opts.min_lr, warmup_iters=opts.warmup_iters, lr_decay_iters=opts.lr_decay): + + # try to mimic karpathy as closely as possible ;) + # https://github.com/karpathy/nanoGPT/blob/master/train.py + # still differs on + # [ ] weight initialization + + def custom_schedule(it, learning_rate=opts.lr, min_lr=opts.lr/10, warmup_iters=2000, lr_decay_iters=600000): # 600k/30 = 20k; so hit mi + #return learning_rate * it / warmup_iters # to allow jax jit? + # allow jax jit + '''if it < warmup_iters: return learning_rate * it / warmup_iters # linearly increase until hit warmup iters. + if it > lr_decay_iters: return min_lr # after decay (600k iterations) go to 10x lower + + # in between, decay learning rate using this function; this is from 2k steps to 600k steps + decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) + return min_lr + coeff * (learning_rate - min_lr)''' + #if it < warmup_iters: return learning_rate * it / warmup_iters cond1 = (it < warmup_iters) * learning_rate * it / warmup_iters cond2 = (it > lr_decay_iters) * min_lr + decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) coeff = 0.5 * (1.0 + jnp.cos(jnp.pi * decay_ratio)) cond3 = (it >= warmup_iters) * (it <= lr_decay_iters) * (min_lr + coeff * (learning_rate - min_lr)) - if not opts.resume: return cond1 + cond2 + cond3 - else: return learning_rate + return cond1 + cond2 + cond3 adam = optax.chain( optax.clip_by_global_norm(1), - #optax.scale_by_adam(b1=0.9, b2=0.95, eps=1e-12), - optax.scale_by_adam(b1=0.99, b2=0.999, eps=1e-12), - #optax.scale_by_factored_rms(), # use this for larger model (more memory efficient) - optax.add_decayed_weights(0.1), + optax.scale_by_adam(b1=0.9, b2=0.95, eps=1e-12), + optax.add_decayed_weights(0.1),#, configure_decay_mask(params)), optax.scale_by_schedule(custom_schedule), optax.scale(-1), - #optax.ema(opts.ema) if opts.ema != 0 else None ) + w = params - df = None - if opts.qh9: - df = pd.read_pickle("qh9/qh9stable_processed_shuffled.pickle") - df = df[df["N_sto3g"]==55] - print(df.shape) - elif opts.qm9: - df = pd.read_pickle("alchemy/processed_atom_9.pickle") # spin=0 and only CNOFH molecules - if nao != -1: df = df[df["nao"]==nao] from torch.utils.data import DataLoader, Dataset class OnTheFlyQM9(Dataset): @@ -715,20 +624,22 @@ class OnTheFlyQM9(Dataset): # dataloader is very keen on throwing segfaults (e.g. using jnp in dataloader throws segfaul). # problem: second epoch always gives segfault. # hacky fix; make __len__ = real_length*num_epochs and __getitem__ do idx%real_num_examples - def __init__(self, opts, df=None, nao=294, train=True, num_epochs=10**9, extrapolate=False): + def __init__(self, opts, nao=294, train=True, num_epochs=10**9, extrapolate=False): + # only take molecules with use {CNOFH}, nao=nao and spin=0. + df = pd.read_pickle("alchemy/processed_atom_9.pickle") # spin=0 and only CNOFH molecules + if nao != -1: df = df[df["nao"]==nao] # df.sample is not deterministic; moved to pre-processing, so file is shuffled already. # this shuffling is important, because it makes the last 10 samples iid (used for validation) #df = df.sample(frac=1).reset_index(drop=True) # is this deterministic? - if opts.qh9 or opts.qm9: - if train: self.mol_strs = df["pyscf"].values[:-10] - else: self.mol_strs = df["pyscf"].values[-10:] + + if train: self.mol_strs = df["pyscf"].values[:-10] + else: self.mol_strs = df["pyscf"].values[-10:] #print(df["pyscf"].) # todo: print smile strings self.num_epochs = num_epochs self.opts = opts self.validation = not train self.extrapolate = extrapolate - self.do_pyscf = self.validation or self.extrapolate self.benzene = [ ["C", ( 0.0000, 0.0000, 0.0000)], @@ -754,64 +665,34 @@ def __init__(self, opts, df=None, nao=294, train=True, num_epochs=10**9, extrapo ] if opts.benzene: self.mol_strs = [self.benzene] - - if opts.md17 > 0: - mol = {MD17_WATER: "water", MD17_ALDEHYDE: "malondialdehyde", MD17_ETHANOL: "ethanol", MD17_URACIL: "uracil"}[opts.md17] - mode = {True: "train", False: "val"}[train] - filename = "md17/%s_%s.pickle"%(mode, mol) - df = pd.read_pickle(filename) - - self.mol_strs = df["pyscf"].values.tolist() - N = int(np.sqrt(df["H"].values.tolist()[0].reshape(-1).size)) - self.H = [a.reshape(N, N) for a in df["H"].values.tolist()] - self.E = df["E"].values.tolist() - self.mol_strs = [eval(a) for a in self.mol_strs] - else: - self.H = [0 for _ in self.mol_strs] - self.E = [0 for _ in self.mol_strs] - - + if opts.waters: self.mol_strs = [self.waters] if opts.alanine: self.mol_strs = mol_str if train: self.bs = opts.bs else: self.bs = opts.val_bs - - def __len__(self): return len(self.mol_strs)*self.num_epochs + def __len__(self): + return len(self.mol_strs)*self.num_epochs def __getitem__(self, idx): return batched_state(self.mol_strs[idx%len(self.mol_strs)], self.opts, self.bs, \ - wiggle_num=0, do_pyscf=self.do_pyscf, validation=False, \ - extrapolate=self.extrapolate, mol_idx=idx), self.H[idx%len(self.mol_strs)], self.E[idx%len(self.mol_strs)] + wiggle_num=0, do_pyscf=self.validation or self.extrapolate, validation=False, \ + extrapolate=self.extrapolate, mol_idx=idx) - print("[%.4fs] initialized datasets. "%(time.time()-start_time) ) - val_qm9 = OnTheFlyQM9(opts, train=False, df=df) - print("[%.4fs] initialized datasets. "%(time.time()-start_time) ) - ext_qm9 = OnTheFlyQM9(opts, extrapolate=True, df=df) - print("[%.4fs] initialized datasets. "%(time.time()-start_time) ) + val_qm9 = OnTheFlyQM9(opts, train=False) + ext_qm9 = OnTheFlyQM9(opts, extrapolate=True) + # parallel dataloader bug; precompute here is not slow but causes dataloader later to die. + # run once to quickly precompute. if opts.precompute: val_state = val_qm9[0] ext_state = ext_qm9[0] exit() - qm9 = OnTheFlyQM9(opts, train=True, df=df) - print("[%.4fs] initialized datasets. "%(time.time()-start_time) ) + qm9 = OnTheFlyQM9(opts, train=True) if opts.workers != 0: train_dataloader = DataLoader(qm9, batch_size=1, pin_memory=True, shuffle=False, drop_last=True, num_workers=opts.workers, prefetch_factor=2, collate_fn=lambda x: x[0]) else: train_dataloader = DataLoader(qm9, batch_size=1, pin_memory=True, shuffle=False, drop_last=True, num_workers=opts.workers, collate_fn=lambda x: x[0]) pbar = tqdm(train_dataloader) - print("[%.4fs] initialized dataloaders. "%(time.time()-start_time) ) - - if opts.test_dataloader: - - t0 = time.time() - for iteration, (state, H, E) in enumerate(pbar): - if iteration == 0: summary(state) - print(time.time()-t0) - t0 = time.time() - print(state.pad_sizes.reshape(1, -1)) - - exit() else: @@ -829,7 +710,6 @@ def __next__(self): return self.item vandg = jax.jit(jax.value_and_grad(dm_energy, has_aux=True), backend=opts.backend, static_argnames=("normal", 'nn', "cfg", "opts")) valf = jax.jit(dm_energy, backend=opts.backend, static_argnames=("normal", 'nn', "cfg", "opts")) adam_state = adam.init(w) - print("[%.4fs] jitted vandg and valf."%(time.time()-start_time) ) if opts.resume: print("loading adam state") @@ -837,13 +717,11 @@ def __next__(self): return self.item print("done") w, adam_state = jax.device_put(w), jax.device_put(adam_state) - print("[%.4fs] jax.device_put(w,adam_state)."%(time.time()-start_time) ) @partial(jax.jit, backend=opts.backend) def update(w, adam_state, accumulated_grad): - if opts.grad_acc: accumulated_grad = jax.tree_map(lambda x: x / (opts.bs * opts.mol_repeats), accumulated_grad) - else: accumulated_grad = jax.tree_map(lambda x: x / opts.bs, accumulated_grad) + accumulated_grad = jax.tree_map(lambda x: x / opts.bs, accumulated_grad) updates, adam_state = adam.update(accumulated_grad, adam_state, w) w = optax.apply_updates(w, updates) return w, adam_state @@ -854,14 +732,10 @@ def update(w, adam_state, accumulated_grad): min_val, min_dm, mins, valid_str, step, val_state, ext_state = 0, 0, np.ones(opts.bs)*1e6, "", 0, None, None t0, load_time, train_time, val_time, plot_time = time.time(), 0, 0, 0, 0 - accumulated_grad = None paddings = [] states = [] - - - print("[%.4fs] first iteration."%(time.time()-start_time) ) - for iteration, (state, H, E) in enumerate(pbar): + for iteration, state in enumerate(pbar): if iteration == 0: summary(state) state = jax.device_put(state) @@ -877,68 +751,55 @@ def update(w, adam_state, accumulated_grad): states.append(state) if len(states) > opts.mol_repeats: states.pop(0) - if opts.shuffle: random.shuffle(states) # load_time, t0 = time.time()-t0, time.time() - + if opts.checkpoint != -1 and iteration % opts.checkpoint == 0: # and iteration > 0: + t0 = time.time() + try: + name = opts.name.replace("-", "_") + path_model = "checkpoints/%s_%i_model.pickle"%(name, iteration) + path_adam = "checkpoints/%s_%i_adam_state.pickle"%(name, iteration) + print("trying to checkpoint to %s and %s"%(path_model, path_adam)) + pickle.dump(jax.device_get(w), open(path_model, "wb")) + pickle.dump(jax.device_get(adam_state), open(path_adam, "wb")) + print("done!") + print("\t-resume \"%s\""%(path_model.replace("_model.pickle", ""))) + except: + print("fail!") + pass + print("tried saving model took %fs"%(time.time()-t0)) + save_time, t0 = time.time()-t0, time.time() + + + if len(states) < 50: print(len(states)) + + for j, state in enumerate(states): + print(". ", end="", flush=True) + if j == 0: _t0 =time.time() + (val, (vals, E_xc, density_matrix, _W)), grad = vandg(w, state, opts.normal, opts.nn, cfg, opts) + print(",", end="", flush=True) + if j == 0: time_step1 = time.time()-_t0 - if len(states) < 50: print(len(states), opts.name) - - reps = 1 - if opts.md17 == 4: reps = 5 - if opts.md17 == 3: reps = 2 - if opts.md17 == 2: reps = 2 - - for _ in range(reps): - for j, state in enumerate(states): - print(". ", end="", flush=True) - if j == 0: _t0 =time.time() - (val, (vals, E_xc, density_matrix, _W, _H)), grad = vandg(w, state, opts.normal, opts.nn, cfg, opts) - print(",", end="", flush=True) - if j == 0: time_step1 = time.time()-_t0 - - if opts.grad_acc == 0 or len(states) < opts.mol_repeats: - w, adam_state = update(w, adam_state, grad) - else: - accumulated_grad = grad if accumulated_grad is None else jax.tree_map(lambda x, y: x + y, accumulated_grad, grad) - - if (j+1) % opts.grad_acc == 0 and j > 0: # we assume opts.grad_acc divides opts.mol_repeats; prev was basically grad_acc=0 or grad_acc=mol_repeats, can now do hybrid. - w, adam_state = update(w, adam_state, grad) - accumulated_grad = None - print("#", end="", flush=True) - - - if opts.checkpoint != -1 and adam_state[1].count % opts.checkpoint == 0 and adam_state[1].count > 0: - t0 = time.time() - try: - name = opts.name.replace("-", "_") - path_model = "checkpoints/%s_%i_model.pickle"%(name, iteration) - path_adam = "checkpoints/%s_%i_adam_state.pickle"%(name, iteration) - print("trying to checkpoint to %s and %s"%(path_model, path_adam)) - pickle.dump(jax.device_get(w), open(path_model, "wb")) - pickle.dump(jax.device_get(adam_state), open(path_adam, "wb")) - print("done!") - print("\t-resume \"%s\""%(path_model.replace("_model.pickle", ""))) - except: - print("fail!") - pass - print("tried saving model took %fs"%(time.time()-t0)) - save_time, t0 = time.time()-t0, time.time() + # todo: have hyper parameter that accumulates gradient or takes step? + w, adam_state = update(w, adam_state, grad) # todo: rename global_batch_size = len(states)*opts.bs if opts.wandb: dct["global_batch_size"] = global_batch_size train_time, t0 = time.time()-t0, time.time() + + # plot grad norm + #if iteration % 10 == 0: + # for k,v in accumulated_grad.items(): dct[k + "_norm"] = np.linalg.norm(v .reshape(-1) ) update_time, t0 = time.time()-t0, time.time() if not opts.nn: str = "error=" + "".join(["%.7f "%(vals[i]*HARTREE_TO_EV-state.pyscf_E[i]) for i in range(2)]) + " [eV]" str += "pyscf=%.7f us=%.7f"%(state.pyscf_E[0]/HARTREE_TO_EV, vals[0]) else: - #print(vals[0], E) - pbar.set_description("train=%.4f"%(vals[0]*HARTREE_TO_EV) + "[eV] "+ valid_str + "time=%.1f %.1f %.1f %.1f %.1f %.1f"%(load_time, time_step1, train_time, update_time, val_time, plot_time)) + pbar.set_description("train=".join(["%.2f"%i for i in vals[:1]]) + "[Ha] "+ valid_str + "time=%.1f %.1f %.1f %.1f %.1f %.1f"%(load_time, time_step1, train_time, update_time, val_time, plot_time)) if opts.wandb: dct["time_load"] = load_time @@ -946,66 +807,45 @@ def update(w, adam_state, accumulated_grad): dct["time_train"] = train_time dct["time_val"] = val_time plot_iteration = iteration % 10 == 0 - - dct["train_E"] = np.abs(E*HARTREE_TO_EV) - dct["train_E_pred"] = np.abs(vals[0]*HARTREE_TO_EV) + for i in range(0, 2): + if not opts.nn: + dct['train_l%i'%i ] = np.abs(vals[i]*HARTREE_TO_EV-state.pyscf_E[i]) + dct['train_pyscf%i'%i ] = np.abs(state.pyscf_E[i]) + dct['train_E%i'%i ] = np.abs(vals[i]*HARTREE_TO_EV) + if plot_iteration: + dct['img/dm%i'%i] = wandb.Image(np.expand_dims(density_matrix[i], axis=-1)) + dct['img/W%i'%i] = wandb.Image(np.expand_dims(_W[i], axis=-1)) step = adam_state[1].count plot_time, t0 = time.time()-t0, time.time() - if opts.nn and (iteration < 250 or iteration % 10 == 0): - val_idx = 1 - if val_state is None: val_state, val_H, val_E = jax.device_put(val_qm9[val_idx]) # todo: cat 8 of these. - _, (valid_vals, _, vdensity_matrix, vW, _val_H) = valf(w, val_state, opts.normal, opts.nn, cfg, opts) - - if opts.md17 > 0: - def get_H_from_dm(dm): - import pyscf - from pyscf import gto, dft - m = pyscf.gto.Mole(atom=val_qm9.mol_strs[val_idx], basis="def2-svp", unit="bohr") - m.build() - mf = dft.RKS(m) - mf.xc = 'B3LYP5' - mf.verbose = 0 - mf.diis_space = 8 - mf.conv_tol = 1e-13 - mf.grad_tol = 3.16e-5 - mf.grids.level = 3 - #mf.kernel() - h_core = mf.get_hcore() - S = mf.get_ovlp() - vxc = mf.get_veff(m, dm) - H = h_core + vxc - S = mf.get_ovlp() - return H, S - - matrix = np.array(vdensity_matrix[0]) - N = int(np.sqrt(matrix.size)) - _val_H, S = get_H_from_dm(matrix.reshape(N, N)) - - # compare eigenvalues - pred_vals = scipy.linalg.eigh(_val_H, S)[0] - label_vals = scipy.linalg.eigh(val_H, S)[0] - MAE_vals = np.mean(np.abs(pred_vals - label_vals)) - dct["val_eps"] = MAE_vals - + - lr = custom_schedule(step) - valid_str = "lr=%.3e"%lr + "val=%.4f [eV] "%(valid_vals[0]*HARTREE_TO_EV-val_E*HARTREE_TO_EV) + "mae_H=%.4f "%( - np.mean(np.abs(val_H/np.abs(val_H) - _val_H/np.abs(_val_H))) - ) - if opts.md17> 0:valid_str+= " eps=%.4f"%(MAE_vals) - valid_str += "val'=" + "".join(["%.4f "%(valid_vals[i]*HARTREE_TO_EV-val_state.pyscf_E[i]) for i in range(0, 3)]) + " [eV]" + # TODO: Plot molecules and val/ext angles. + if opts.nn and (iteration < 250 or iteration % 10 == 0): - dct['val_E'] = np.abs(valid_vals[0]*HARTREE_TO_EV-val_E*HARTREE_TO_EV ) - dct['val_H_MAE'] = np.mean(np.abs(val_H - _val_H)) # perhaps sign doesn't matter? + if val_state is None: val_state = jax.device_put(val_qm9[0]) + _, (valid_vals, _, vdensity_matrix, vW) = valf(w, val_state, opts.normal, opts.nn, cfg, opts) + if ext_state is None: ext_state = jax.device_put(ext_qm9[0]) + _, (ext_vals, _, edensity_matrix, eW) = valf(w, ext_state, opts.normal, opts.nn, cfg, opts) + lr = custom_schedule(step) + valid_str = "lr=%.3e"%lr + "val=" + "".join(["%.4f "%(valid_vals[i]*HARTREE_TO_EV-val_state.pyscf_E[i]) for i in range(0, 3)]) + " [eV]" + valid_str += "ext=" + "".join(["%.4f "%(ext_vals[i]*HARTREE_TO_EV-ext_state.pyscf_E[i]) for i in range(0, 3)]) + " [eV]" if opts.wandb: - for i in range(0, 3): + for i in range(0, opts.val_bs): dct['valid_l%i'%i ] = np.abs(valid_vals[i]*HARTREE_TO_EV-val_state.pyscf_E[i]) dct['valid_E%i'%i ] = np.abs(valid_vals[i]*HARTREE_TO_EV) dct['valid_pyscf%i'%i ] = np.abs(val_state.pyscf_E[i]) + dct['img/val_dm%i'%i] = wandb.Image(np.expand_dims(vdensity_matrix[i], axis=-1)) + dct['img/val_W%i'%i] = wandb.Image(np.expand_dims(vW[i], axis=-1)) + + dct['ext_l%i'%i ] = np.abs(ext_vals[i]*HARTREE_TO_EV-ext_state.pyscf_E[i]) + dct['ext_E%i'%i ] = np.abs(ext_vals[i]*HARTREE_TO_EV) + dct['ext_pyscf%i'%i ] = np.abs(ext_state.pyscf_E[i]) + dct['img/ext_dm%i'%i] = wandb.Image(np.expand_dims(edensity_matrix[i], axis=-1)) + dct['img/ext_W%i'%i] = wandb.Image(np.expand_dims(eW[i], axis=-1)) dct["scheduled_lr"] = lr @@ -1154,6 +994,7 @@ def get_partition( coords_all = [] weights_all = [] + # [ ] consider another grid? for ia in range(mol.natm): coords, vol = atom_grids_tab[mol.atom_symbol(ia)] coords = coords + atom_coords[ia] # [ngrid, 3] @@ -1178,9 +1019,7 @@ def build(self, atom_coords, state=None) : mol = self.mol atom_grids_tab = self.gen_atomic_grids( - mol, self.atom_grid, self.radi_method, self.level, - self.prune, - #False, # WARNING: disabling self.prune; this makes sizes of C,N,O,F all a bit larger, but the same ; allow atom substitution + mol, self.atom_grid, self.radi_method, self.level, self.prune ) coords, weights = get_partition( @@ -1210,10 +1049,9 @@ def grids_from_pyscf_mol( def init_dft(mol_str, opts, _coords=None, _weights=None, first=False, do_pyscf=True, state=None, pad_electrons=-1): - do_print = False #t0 = time.time() - mol = build_mol(mol_str, opts.basis, unit="bohr") - if do_pyscf: pyscf_E, pyscf_hlgap, pycsf_forces = reference(mol_str, opts, unit="bohr") + mol = build_mol(mol_str, opts.basis) + if do_pyscf: pyscf_E, pyscf_hlgap, pycsf_forces = reference(mol_str, opts) else: pyscf_E, pyscf_hlgap, pyscf_forces = np.zeros(1), np.zeros(1), np.zeros(1) N = mol.nao_nr() # N=66 for C6H6 (number of atomic **and** molecular orbitals) @@ -1221,8 +1059,6 @@ def init_dft(mol_str, opts, _coords=None, _weights=None, first=False, do_pyscf=T E_nuc = mol.energy_nuc() # float = 202.4065 [Hartree] for C6H6. TODO(): Port to jax. from pyscf import dft - if do_print: print("grid", end="", flush=True) - #grids = pyscf.dft.gen_grid.Grids(mol) grids = DifferentiableGrids(mol) grids.level = opts.level @@ -1234,8 +1070,6 @@ def init_dft(mol_str, opts, _coords=None, _weights=None, first=False, do_pyscf=T coord_str = 'GTOval_cart_deriv1' if mol.cart else 'GTOval_sph_deriv1' grid_AO = mol.eval_gto(coord_str, grids.coords, 4) # (4, grid_size, N) = (4, 45624, 9) for C6H6. - if do_print: print("int1e", end="", flush=True) - # TODO(): Add integral math formulas for kinetic/nuclear/O/ERI. kinetic = mol.intor_symmetric('int1e_kin') # (N,N) nuclear = mol.intor_symmetric('int1e_nuc') # (N,N) @@ -1260,14 +1094,10 @@ def init_dft(mol_str, opts, _coords=None, _weights=None, first=False, do_pyscf=T eri_threshold = 0 batches = 1 nipu = 1 - - # todo: rewrite int2e_sph to only recompute changing atomic orbitals (will be N times faster). - if do_print: print("int2e",end ="", flush=True) nonzero_distinct_ERI = mol.intor("int2e_sph", aosym="s8") #ERI = [nonzero_distinct_ERI, nonzero_indices] #ERI = ERI ERI = np.zeros(1) - if do_print: print(nonzero_distinct_ERI.shape, nonzero_distinct_ERI.nbytes/10**9) #ERI = mol.intor("int2e_sph") def e(x): return np.expand_dims(x, axis=0) @@ -1474,19 +1304,12 @@ def hcore_deriv(atm_id, aoslices, h1): # <\nabla|1/r|> def pyscf_reference(mol_str, opts): from pyscf import __config__ __config__.dft_rks_RKS_grids_level = opts.level - mol = build_mol(mol_str, opts.basis, unit="bohr") + mol = build_mol(mol_str, opts.basis) mol.max_cycle = 50 mf = pyscf.scf.RKS(mol) - #mf.max_cycle = 50 - #mf.xc = "b3lyp5" - #mf.diis_space = 8 - mf.xc = 'B3LYP5' - mf.verbose = 0 # put this to 4 to check i set parameters correctly! + mf.max_cycle = 50 + mf.xc = "b3lyp5" mf.diis_space = 8 - # options from qh9 - mf.conv_tol=1e-13 - mf.grad_tol=3.16e-5 - mf.grids.level = 3 pyscf_energies = [] pyscf_hlgaps = [] lumo = mol.nelectron//2 @@ -1525,26 +1348,35 @@ def print_difference(nanoDFT_E, nanoDFT_forces, nanoDFT_logged_E, nanoDFT_hlgap, cosine_similarity = dot_products / (norm_X * norm_Y) print("Force cosine similarity:",cosine_similarity) -def build_mol(mol_str, basis_name, unit="bohr"): +def build_mol(mol_str, basis_name): mol = pyscf.gto.mole.Mole() - mol.build(atom=mol_str, unit=unit, basis=basis_name, spin=0, verbose=0) + mol.build(atom=mol_str, unit="Angstrom", basis=basis_name, spin=0, verbose=0) return mol -def reference(mol_str, opts, unit="bohr"): +def reference(mol_str, opts): import pickle import hashlib if opts.skip: return np.zeros(1), np.zeros(1), np.zeros(1) - filename = "precomputed/%s.pkl"%hashlib.sha256((str(mol_str) + str(opts.basis) + str(opts.level) + unit).encode('utf-8')).hexdigest() + filename = "precomputed/%s.pkl"%hashlib.sha256((str(mol_str) + str(opts.basis) + str(opts.level)).encode('utf-8')).hexdigest() print(filename) if not os.path.exists(filename): pyscf_E, pyscf_hlgap, pyscf_forces = pyscf_reference(mol_str, opts) with open(filename, "wb") as file: - pickle.dump([pyscf_E, pyscf_hlgap, pyscf_forces, unit], file) + pickle.dump([pyscf_E, pyscf_hlgap, pyscf_forces], file) else: - pyscf_E, pyscf_hlgap, pyscf_forces, unit = pickle.load(open(filename, "rb")) + pyscf_E, pyscf_hlgap, pyscf_forces = pickle.load(open(filename, "rb")) return pyscf_E, pyscf_hlgap, pyscf_forces +class HashableNamespace: + def __init__(self, namespace): + self.__dict__.update(namespace.__dict__) + + def __hash__(self): + # Convert the relevant attributes to a tuple for hashing + return hash(tuple(sorted(self.__dict__.items()))) + + if __name__ == "__main__": import os import argparse @@ -1556,18 +1388,11 @@ def reference(mol_str, opts, unit="bohr"): # GD options parser.add_argument('-backend', type=str, default="cpu") - parser.add_argument('-lr', type=float, default=5e-4) - parser.add_argument('-min_lr', type=float, default=1e-7) - parser.add_argument('-warmup_iters', type=float, default=1000) - parser.add_argument('-lr_decay', type=float, default=200000) - parser.add_argument('-ema', type=float, default=0.0) - + parser.add_argument('-lr', type=float, default=2.5e-4) parser.add_argument('-steps', type=int, default=100000) parser.add_argument('-bs', type=int, default=8) - parser.add_argument('-val_bs', type=int, default=3) + parser.add_argument('-val_bs', type=int, default=8) parser.add_argument('-mol_repeats', type=int, default=16) # How many time to optimize wrt each molecule. - parser.add_argument('-grad_acc', type=int, default=0) # integer, deciding how many steps to accumulate. - parser.add_argument('-shuffle', action="store_true") # whether to to shuffle the window of states each step. # energy computation speedups parser.add_argument('-foriloop', action="store_true") # whether to use jax.lax.foriloop for sparse_symmetric_eri (faster compile time but slower training. ) @@ -1582,16 +1407,12 @@ def reference(mol_str, opts, unit="bohr"): parser.add_argument('-skip', action="store_true", help="skip pyscf test case") # dataset - parser.add_argument('-nperturb', type=int, default=0, help="How many atoms to perturb (supports 1,2,3)") parser.add_argument('-qm9', action="store_true") - parser.add_argument('-md17', type=int, default=-1) - parser.add_argument('-qh9', action="store_true") parser.add_argument('-benzene', action="store_true") parser.add_argument('-hydrogens', action="store_true") parser.add_argument('-water', action="store_true") parser.add_argument('-waters', action="store_true") parser.add_argument('-alanine', action="store_true") - parser.add_argument('-do_print', action="store_true") # useful for debugging. parser.add_argument('-states', type=int, default=1) parser.add_argument('-workers', type=int, default=5) parser.add_argument('-precompute', action="store_true") # precompute labels; only run once for data{set/augmentation}. @@ -1599,9 +1420,6 @@ def reference(mol_str, opts, unit="bohr"): parser.add_argument('-wiggle_var', type=float, default=0.05, help="wiggle N(0, wiggle_var), bondlength=1.5/30") parser.add_argument('-eri_threshold', type=float, default=1e-10, help="loss function threshold only") parser.add_argument('-rotate_deg', type=float, default=90, help="how many degrees to rotate") - parser.add_argument('-test_dataloader', action="store_true", help="no training, just test/loop through dataloader. ") - - # models parser.add_argument('-nn', action="store_true", help="train nn, defaults to GD") @@ -1611,19 +1429,12 @@ def reference(mol_str, opts, unit="bohr"): parser.add_argument('-medium', action="store_true") parser.add_argument('-large', action="store_true") parser.add_argument('-xlarge', action="store_true") - parser.add_argument('-largep', action="store_true") # large "plus" parser.add_argument("-checkpoint", default=-1, type=int, help="which iteration to save model (default -1 = no saving)") # checkpoint model parser.add_argument("-resume", default="", help="path to checkpoint pickle file") # checkpoint model opts = parser.parse_args() - if opts.tiny or opts.small or opts.base or opts.large or opts.xlarge: opts.nn = True - assert opts.grad_acc == 0 or opts.mol_repeats % opts.grad_acc == 0, "mol_repeats needs to be a multiple of grad_acc (gradient accumulation)." - - class HashableNamespace: - def __init__(self, namespace): self.__dict__.update(namespace.__dict__) - def __hash__(self): return hash(tuple(sorted(self.__dict__.items()))) - opts = HashableNamespace(opts) + if opts.tiny or opts.small or opts.base or opts.large or opts.xlarge: opts.nn = True args_dict = vars(opts) print(args_dict) @@ -1633,9 +1444,6 @@ def __hash__(self): return hash(tuple(sorted(self.__dict__.items()))) df = df[df["spin"] == 0] # only consider spin=0 mol_strs = df["pyscf"].values - if opts.qh9: - mol_strs = [] - # benzene if opts.benzene: mol_strs = [[ @@ -1658,7 +1466,7 @@ def __hash__(self): return hash(tuple(sorted(self.__dict__.items()))) ["H", ( 0.0000, 0.0000, 0.0000)], ["H", ( 1.4000, 0.0000, 0.0000)], ]] - if opts.md17 > 0 : + if opts.water: mol_strs = [[ ["O", ( 0.0000, 0.0000, 0.0000)], ["H", ( 0.0000, 1.4000, 0.0000)], @@ -1699,10 +1507,12 @@ def __hash__(self): return hash(tuple(sorted(self.__dict__.items()))) ["H", ( 6.360 , 8.648, -0.890)], ]] + # make opts hashable so that JAX will not complain about the static parameter that is passed as arg + opts = HashableNamespace(opts) nanoDFT_E, (nanoDFT_hlgap, mo_energy, mo_coeff, grid_coords, grid_weights, dm, H) = nanoDFT(mol_strs, opts) exit() pyscf_E, pyscf_hlgap, pyscf_forces = reference(mol_str, opts) nanoDFT_forces = grad(mol, grid_coords, grid_weights, mo_coeff, mo_energy, np.array(dm), np.array(H)) - print_difference(nanoDFT_E, nanoDFT_forces, 0 , nanoDFT_hlgap, pyscf_E, pyscf_forces, pyscf_hlgap) \ No newline at end of file + print_difference(nanoDFT_E, nanoDFT_forces, 0 , nanoDFT_hlgap, pyscf_E, pyscf_forces, pyscf_hlgap) diff --git a/pyscf_ipu/direct/transformer.py b/pyscf_ipu/direct/transformer.py index 45e0e050..9cc2b9d0 100644 --- a/pyscf_ipu/direct/transformer.py +++ b/pyscf_ipu/direct/transformer.py @@ -58,7 +58,7 @@ def transformer_init( total_params += np.prod(params.embeddings.shape) print("%26s %26s %26s"%("params.embeddings",params.embeddings.shape, np.prod(params.embeddings.shape))) - rng, params.project_positions, shape = linear_init_uniform(rng, 123, d_model) + rng, params.project_positions, shape = linear_init_uniform(rng, 12, d_model) total_params += np.prod(shape) print("%26s %26s %26s"%("params.project_positions",shape, np.prod(shape))) @@ -95,32 +95,23 @@ def transformer_init( @partial(jax.jit, static_argnums=0) -def transformer(cfg, params, x: jnp.ndarray, position: jnp.ndarray, H_core: jnp.ndarray, L_inv): +def transformer(cfg, params, x: jnp.ndarray, position: jnp.ndarray, H_core: jnp.ndarray): """ cfg: Config, from transformer_init, holds hyperparameters params: Current transformer parameters, initialized in init x: 1D array of L integers, representing the input sequence output: L x n_vocab logits """ + L, = x.shape # x is just 1D. Vmap/pmap will handle batching + embeddings = cfg.lambda_e * params.embeddings[x, :] # L x Dm - L, Dm = embeddings.shape - - # Roughly get f( {R@ri+t}_i ) = f( {r_i}_i ) - position = position - jnp.mean(position, axis=0).reshape(1, 3) # makes jnp.mean(position, axis=0) = [0,0,0] - cov = jnp.cov(position.T) - eigvects = jnp.linalg.eigh(cov)[1] - position = position @ eigvects # makes jnp.cov(positions.T)=jnp.eye(3) - - # Mix of sin/cos and 3d point cloud transformers. - #position = jnp.concatenate([position, jnp.cos(position), jnp.sin(position), jnp.tanh(position)], axis=1) #(N,3) -> (N,12) - position = jnp.concatenate([position] + \ - [jnp.cos(position*f/20*2*np.pi) for f in range(20)] + \ - [jnp.sin(position*f/20*2*np.pi) for f in range(20)], - axis=1) #(N,3) -> (N,3+60+60) = (N, 123) + + all_pairs = jnp.linalg.norm(position.reshape(1, -1, 3) - position.reshape(-1, 1, 3), axis=-1) + + # inspired by 3d point cloud transformers; + # nspired by andrew: use trigonometric functions as feature transformations + position = jnp.concatenate([position, jnp.cos(position), jnp.sin(position), jnp.tanh(position)], axis=1) #(N,3) -> (N,12) positions = linear(params.project_positions, position) # L x Dm - del position - all_pairs = jnp.linalg.norm(positions.reshape(1, -1, Dm) - positions.reshape(-1, 1, Dm), axis=-1) - all_pairs = all_pairs / jnp.max(all_pairs) # Add (learned) positional encodings x = embeddings + positions # L x Dm @@ -137,13 +128,12 @@ def block(x, layer_num, layer): q = jnp.transpose(q.reshape(L, nheads, Dm//nheads), (1, 0, 2)) k = jnp.transpose(k.reshape(L, nheads, Dm//nheads), (1, 0, 2)) v = jnp.transpose(v.reshape(L, nheads, Dm//nheads), (1, 0, 2)) - score = (q @ jnp.transpose(k, (0, 2, 1))) / math.sqrt(Dm//nheads) + score = (q @ jnp.transpose(k, (0, 2, 1))) / math.sqrt(Dm) - if True: # todo: why does this improve loss from ~1000 to ~300 first step (qm9). - score += H_core - #score += all_pairs # => NaNs for some reason - #score += L_inv - score += L_inv @ H_core @ L_inv.T + # do like graphformer and append position here? + #if layer_num < 6: # doesn't look like it helps + # score += H_core + # score += all_pairs attn = jax.nn.softmax(score , axis=1) x = x + (attn @ v).reshape(L, Dm) @@ -159,26 +149,16 @@ def block(x, layer_num, layer): # Residual connection x = x + t2 - return x + return x, score # Apply the transformer layers # todo: cut jit time by making this jax.lax.foriloop - for layer_num, layer in enumerate(params.layers[:-1]): - x = jax.checkpoint(block)(x, layer_num, layer) - - layer = params.layers[-1] - # Prediction is last attention (without nhead = 1), and q=k so score is symmetric! - nheads = 1 - t1 = vmap(standardize)(x) # L x Dm - t1 = elementwise_linear(layer.norm_self_attn, t1) # L x Dm - qkv = linear(layer.kqv, t1) - q,k,v = jnp.split(qkv, 3, axis=1) - q = jnp.transpose(q.reshape(L, nheads, Dm//nheads), (1, 0, 2)) - k = q - #v = jnp.transpose(v.reshape(L, nheads, Dm//nheads), (1, 0, 2)) - score = (q @ jnp.transpose(k, (0, 2, 1))) / math.sqrt(Dm*nheads) # symmetric: initial loss goes from 1200 to 980 (qm9). + for layer_num, layer in enumerate(params.layers): + x, score = jax.checkpoint(block)(x, layer_num, layer) - M = score[0] + # todo: if this isn't symmetric eigh gives imaginary eigenvalues? (bad) + M = score[0] # take first attention head + #M = (M + M.T)/2 # make symmetric! return M import types @@ -294,7 +274,7 @@ def convert_to_float32(x): parser.add_argument('-large', action="store_true") parser.add_argument('-xlarge', action="store_true") opts = parser.parse_args() - + # initialize model # transformer tiny 5M d_model= 192 @@ -331,10 +311,10 @@ def convert_to_float32(x): extrapolate=False, mol_idx=0) summary(state) - output = jax.jit(jax.vmap(transformer, in_axes=(None, None, 0, 0, 0, 0), out_axes=(0)), + output = jax.jit(jax.vmap(transformer, in_axes=(None, None, 0, 0, 0), out_axes=(0)), static_argnums=(0,), backend="cpu")(cfg, \ - params, state.ao_types, state.pos.astype(jnp.float32), state.H_core.astype(jnp.float32), state.L_inv.astype(jnp.float32)) + params, state.ao_types, state.pos.astype(jnp.float32), state.H_core.astype(jnp.float32)) print(np.sum(output)) # 162.58726108305348 @@ -348,10 +328,10 @@ def convert_to_float32(x): new_params = pickle.load(open("checkpoints/example.pickle", "rb")) # check that output remains the same - new_output = jax.jit(jax.vmap(transformer, in_axes=(None, None, 0, 0, 0, 0), out_axes=(0)), + new_output = jax.jit(jax.vmap(transformer, in_axes=(None, None, 0, 0, 0), out_axes=(0)), static_argnums=(0,), backend="cpu")(cfg, \ - new_params, state.ao_types, state.pos.astype(jnp.float32), state.H_core.astype(jnp.float32), state.L_inv.astype(jnp.float32)) + new_params, state.ao_types, state.pos.astype(jnp.float32), state.H_core.astype(jnp.float32)) assert np.allclose(output, new_output) print("TEST CASE PASSED!")