diff --git a/pyscf_ipu/direct/alanine.pdb b/pyscf_ipu/direct/alanine.pdb new file mode 100644 index 00000000..8c9156ea --- /dev/null +++ b/pyscf_ipu/direct/alanine.pdb @@ -0,0 +1,22 @@ +ATOM 1 1HH3 ACE 1 2.000 1.000 -0.000 +ATOM 2 CH3 ACE 1 2.000 2.090 0.000 +ATOM 3 2HH3 ACE 1 1.486 2.454 0.890 +ATOM 4 3HH3 ACE 1 1.486 2.454 -0.890 +ATOM 5 C ACE 1 3.427 2.641 -0.000 +ATOM 6 O ACE 1 4.391 1.877 -0.000 +ATOM 7 N ALA 2 3.555 3.970 -0.000 +ATOM 8 H ALA 2 2.733 4.556 -0.000 +ATOM 9 CA ALA 2 4.853 4.614 -0.000 +ATOM 10 HA ALA 2 5.408 4.316 0.890 +ATOM 11 CB ALA 2 5.661 4.221 -1.232 +ATOM 12 1HB ALA 2 5.123 4.521 -2.131 +ATOM 13 2HB ALA 2 6.630 4.719 -1.206 +ATOM 14 3HB ALA 2 5.809 3.141 -1.241 +ATOM 15 C ALA 2 4.713 6.129 0.000 +ATOM 16 O ALA 2 3.601 6.653 0.000 +ATOM 17 N NME 3 5.846 6.835 0.000 +ATOM 18 H NME 3 6.737 6.359 -0.000 +ATOM 19 CH3 NME 3 5.846 8.284 0.000 +ATOM 20 1HH3 NME 3 4.819 8.648 0.000 +ATOM 21 2HH3 NME 3 6.360 8.648 0.890 +ATOM 22 3HH3 NME 3 6.360 8.648 -0.890 \ No newline at end of file diff --git a/pyscf_ipu/direct/alchemy/download.sh b/pyscf_ipu/direct/alchemy/download.sh new file mode 100755 index 00000000..2d9d4913 --- /dev/null +++ b/pyscf_ipu/direct/alchemy/download.sh @@ -0,0 +1,12 @@ +wget -O dev.zip https://alchemy.tencent.com/data/dev_v20190730.zip +wget -O valid.zip https://alchemy.tencent.com/data/valid_v20190730.zip +wget -O test.zip https://alchemy.tencent.com/data/test_v20190730.zip + +unzip dev.zip +unzip valid.zip +unzip test.zip + +wget -O alchemy.zip https://alchemy.tencent.com/data/alchemy-v20191129.zip +unzip alchemy.zip +mv Alchemy-v20191129/* . +rmdir Alchemy-v20191129 \ No newline at end of file diff --git a/pyscf_ipu/direct/alchemy/reproduce_pyscf.py b/pyscf_ipu/direct/alchemy/reproduce_pyscf.py new file mode 100644 index 00000000..262d1a37 --- /dev/null +++ b/pyscf_ipu/direct/alchemy/reproduce_pyscf.py @@ -0,0 +1,18 @@ +import pandas as pd +import pyscf +from pyscf import __config__ +__config__.dft_rks_RKS_grids_level = 3 +from pyscf import dft +import numpy as np + +df = pd.read_pickle("atom_9.pickle") +mol = pyscf.gto.Mole(atom=df["pyscf"].values[0], basis="6-31G(2df,p)", spin=0) +mol.build() +mf = pyscf.dft.RKS(mol) +mf.verbose = 4 +mf.xc = 'B3LYP5' # pyscf changed b3lyp from vwn5 to vwn3 to agree with gaussian. +print(mf.kernel()) +print(df["energy"].values[0]) +print(df["homo"].values[0]) +print(df["lumo"].values[0]) +print(df["gap"].values[0]) \ No newline at end of file diff --git a/pyscf_ipu/direct/alchemy/to_pickle.py b/pyscf_ipu/direct/alchemy/to_pickle.py new file mode 100644 index 00000000..9735f299 --- /dev/null +++ b/pyscf_ipu/direct/alchemy/to_pickle.py @@ -0,0 +1,81 @@ +import pandas as pd +import os +from rdkit import Chem +import numpy as np +import pyscf +from natsort import natsorted +from tqdm import tqdm + +# we test loading by reproducing labels with pyscf. +# (instead of checking e.g. np_to_sdf) + +def sdf_to_np(filename): + s = open('%s'%filename, 'r').read() + lines = s.split('V2000')[1].split('\n')[1:-1] + lines = [[a for a in line.split(' ') if a != '' ][:4] for line in lines if len(line)>35] + lines = [[line[3], (float(line[0]), float(line[1]), float(line[2]))] for line in lines ] + atom_str = [line[0] for line in lines] + atom_pos = np.concatenate([np.array(line[1:]).reshape(1, -1) for line in lines ] ) + return atom_str, atom_pos + +def np_to_pyscf(str, xyz): + atom_list = [] + + for i, atom_type in enumerate(str): + x, y, z = xyz[i] + atom_list.append([atom_type, (x, y, z)]) + + return atom_list + +def spin(pyscf_format): + try: + mol = pyscf.gto.Mole(atom=pyscf_format, basis="6-31g(2df,p)") + mol.build() + return 0 + except: + mol = pyscf.gto.Mole(atom=pyscf_format, basis="6-31g(2df,p)", spin=1) + mol.build() + return 1 + +def nao(pyscf_format, spin): + mol = pyscf.gto.Mole(atom=pyscf_format, basis="6-31g(2df,p)", spin=spin) + mol.build() + return mol.nao_nr() + +# load all labels in the final 200k version +df = pd.read_csv("final_version.csv") + +# add info on train/test/val split +train = pd.read_csv("dev/dev_target.csv") +valid = pd.read_csv("valid/valid_target.csv") +test = pd.DataFrame({"gdb_idx": os.listdir("test/sdf/atom_11") + os.listdir("test/sdf/atom_12")}) +df["train"] = df["gdb_idx"].isin(train["gdb_idx"]) +df["test"] = df["gdb_idx"].isin(test["gdb_idx"]) +df["valid"] = df["gdb_idx"].isin(valid["gdb_idx"]) + +# alchemy computes u0 = results['E_0K' ] = (E0 + ZPE, 'Eh'), so need to subtract zpve +# https://github.com/tencent-alchemy/alchemy-pyscf/blob/fa4f7ff46be308302ba1e95754701142b6c4bf7f/alchemypyscf/thermo.py#L215 +df["energy"] = df["U0\n(Ha, internal energy at 0 K)"] - df["zpve\n(Ha, zero point vibrational energy)"] + +for folder in ["atom_9", "atom_10", "atom_11", "atom_12"]: + files = natsorted(os.listdir(folder)) + + strs, xyzs, pyscfs, gdb_idxs, naos, spins = [], [], [], [], [], [] + + for f in tqdm(files): + try: + str, xyz = sdf_to_np("%s/%s"%(folder, f)) + strs.append(str) + xyzs.append(xyz) + pyscfs.append( np_to_pyscf(str, xyz) ) + gdb_idxs.append(int(f.replace(".sdf", ""))) + spins.append(spin(pyscfs[-1])) + naos.append(-1)#nao(pyscfs[-1], spins[-1])) + except: + print("broke %s"%f) + + df2 = pd.DataFrame({"gdb_idx": gdb_idxs, "pyscf": pyscfs, "str": strs, "xyz": xyzs, "nao": naos, "spin": spins}) + merged = pd.merge(df, df2, on="gdb_idx", how="inner") + + merged.to_pickle("%s.pickle"%folder) + break \ No newline at end of file diff --git a/pyscf_ipu/direct/another_plot.py b/pyscf_ipu/direct/another_plot.py new file mode 100644 index 00000000..0b73ffe4 --- /dev/null +++ b/pyscf_ipu/direct/another_plot.py @@ -0,0 +1,79 @@ +import pickle +import numpy as np +import matplotlib.pyplot as plt + + +ml_file = "heatmap_data_009.pkl" +pyscf_file = "heatmap_pyscf_009.pkl" +# Load data from the pickle file +with open(ml_file, 'rb') as file: + data_list = pickle.load(file) + +with open(pyscf_file, 'rb') as file: + pyscf_list = pickle.load(file) + +# Extract phi, psi, and values from the loaded data +phi_values, psi_values, heatmap_val = zip(*data_list) + +# Extract phi, psi, and values from the loaded data +phi_values_p, psi_values_p, heatmap_pyscf = zip(*pyscf_list) + +matrix_size = int(len(data_list) ** 0.5) + +heatmap_val = np.array(heatmap_val).reshape(matrix_size, matrix_size) +heatmap_pyscf = np.array(heatmap_pyscf).reshape(matrix_size, matrix_size) + +# valid_E = NN(molecule) \approx E +# state.pyscf_E = DFT(molecule) = E +# state.valid_l = | NN(molecule) - DFT(molecule) | +# +heatmap_pyscf = -heatmap_pyscf + +phi_coordinates, psi_coordinates = np.meshgrid(np.linspace(min(phi_values), max(phi_values), matrix_size), + np.linspace(min(psi_values), max(psi_values), matrix_size)) + +fig, ax = plt.subplots(2,3, figsize=(10, 8)) +# im = ax[0,0].imshow( heatmap_val ) +im = ax[0,0].imshow(heatmap_val, cmap='viridis', origin='lower', extent=[min(psi_values), max(psi_values), min(phi_values), max(phi_values)]) + +# ax[0,0].set_xlim(phi_values) +# ax[0,0].set_ylim(psi_values) +im2 = ax[0,1].imshow( heatmap_pyscf, cmap='viridis', origin='lower', extent=[min(psi_values), max(psi_values), min(phi_values), max(phi_values)]) +diff = ax[0,2].imshow( np.abs(heatmap_val - heatmap_pyscf), cmap='viridis', origin='lower', extent=[min(psi_values), max(psi_values), min(phi_values), max(phi_values)]) + +log = ax[1,0].imshow( np.log(np.abs(heatmap_val )), cmap='viridis', origin='lower', extent=[min(psi_values), max(psi_values), min(phi_values), max(phi_values)]) +log2 = ax[1,1].imshow( np.log(np.abs(heatmap_pyscf )), cmap='viridis', origin='lower', extent=[min(psi_values), max(psi_values), min(phi_values), max(phi_values)]) +difflog = ax[1,2].imshow( np.log(np.abs((heatmap_val - heatmap_pyscf))), cmap='viridis', origin='lower', extent=[min(psi_values), max(psi_values), min(phi_values), max(phi_values)]) + +for i in range(3): + for j in range(2): + ax[j, i].set_xticks(np.arange(phi_values[0], phi_values[-1], 45)) + ax[j, i].set_yticks(np.arange(psi_values[0], psi_values[-1], 45)) + # ax[j, i].set_xlim([phi_values[0], phi_values[-1]]) + # ax[j, i].set_ylim([psi_values[0], psi_values[-1]]) + ax[j, i].set_xlabel("phi [deg]") + ax[j, i].set_ylabel("psi [deg]") + +# orient = 'vertical' +orient = 'horizontal' +cbar = fig.colorbar(im, ax=ax[0, 0], orientation=orient, fraction=0.05, pad=0.28) +cbar = fig.colorbar(im2, ax=ax[0, 1], orientation=orient, fraction=0.05, pad=0.28) +cbar = fig.colorbar(diff, ax=ax[0, 2], orientation=orient, fraction=0.05, pad=0.28) +cbar = fig.colorbar(log, ax=ax[1, 0], orientation=orient, fraction=0.05, pad=0.28) +cbar = fig.colorbar(log2, ax=ax[1, 1], orientation=orient, fraction=0.05, pad=0.28) +cbar = fig.colorbar(difflog, ax=ax[1, 2], orientation=orient, fraction=0.05, pad=0.28) + +# for a in ax.reshape(-1): a.axis("off") +ax[0,0].set_title("NN Energy") +ax[0,1].set_title("PySCF Energy") +ax[0,2].set_title("|NN-PySCF| Energy") + +ax[1,0].set_title("NN log(|Energy|)") +ax[1,1].set_title("PySCF log(|Energy|)") +ax[1,2].set_title("|NN-PySCF| log(|Energy|)") +# ax[0,0].set_ylabel("Energy") # may fail with axis("off") +# ax[1,0].set_ylabel("log(|Energy|)") # may fail with axis("off") +plt.tight_layout() + +# Save the plot to a PNG file +plt.savefig("poc.png") \ No newline at end of file diff --git a/pyscf_ipu/direct/exchange_correlation/LICENSE b/pyscf_ipu/direct/exchange_correlation/LICENSE new file mode 100644 index 00000000..fa0086a9 --- /dev/null +++ b/pyscf_ipu/direct/exchange_correlation/LICENSE @@ -0,0 +1,373 @@ +Mozilla Public License Version 2.0 +================================== + +1. Definitions +-------------- + +1.1. "Contributor" + means each individual or legal entity that creates, contributes to + the creation of, or owns Covered Software. + +1.2. "Contributor Version" + means the combination of the Contributions of others (if any) used + by a Contributor and that particular Contributor's Contribution. + +1.3. "Contribution" + means Covered Software of a particular Contributor. + +1.4. "Covered Software" + means Source Code Form to which the initial Contributor has attached + the notice in Exhibit A, the Executable Form of such Source Code + Form, and Modifications of such Source Code Form, in each case + including portions thereof. + +1.5. "Incompatible With Secondary Licenses" + means + + (a) that the initial Contributor has attached the notice described + in Exhibit B to the Covered Software; or + + (b) that the Covered Software was made available under the terms of + version 1.1 or earlier of the License, but not also under the + terms of a Secondary License. + +1.6. "Executable Form" + means any form of the work other than Source Code Form. + +1.7. "Larger Work" + means a work that combines Covered Software with other material, in + a separate file or files, that is not Covered Software. + +1.8. "License" + means this document. + +1.9. "Licensable" + means having the right to grant, to the maximum extent possible, + whether at the time of the initial grant or subsequently, any and + all of the rights conveyed by this License. + +1.10. "Modifications" + means any of the following: + + (a) any file in Source Code Form that results from an addition to, + deletion from, or modification of the contents of Covered + Software; or + + (b) any new file in Source Code Form that contains any Covered + Software. + +1.11. "Patent Claims" of a Contributor + means any patent claim(s), including without limitation, method, + process, and apparatus claims, in any patent Licensable by such + Contributor that would be infringed, but for the grant of the + License, by the making, using, selling, offering for sale, having + made, import, or transfer of either its Contributions or its + Contributor Version. + +1.12. "Secondary License" + means either the GNU General Public License, Version 2.0, the GNU + Lesser General Public License, Version 2.1, the GNU Affero General + Public License, Version 3.0, or any later versions of those + licenses. + +1.13. "Source Code Form" + means the form of the work preferred for making modifications. + +1.14. "You" (or "Your") + means an individual or a legal entity exercising rights under this + License. For legal entities, "You" includes any entity that + controls, is controlled by, or is under common control with You. For + purposes of this definition, "control" means (a) the power, direct + or indirect, to cause the direction or management of such entity, + whether by contract or otherwise, or (b) ownership of more than + fifty percent (50%) of the outstanding shares or beneficial + ownership of such entity. + +2. License Grants and Conditions +-------------------------------- + +2.1. Grants + +Each Contributor hereby grants You a world-wide, royalty-free, +non-exclusive license: + +(a) under intellectual property rights (other than patent or trademark) + Licensable by such Contributor to use, reproduce, make available, + modify, display, perform, distribute, and otherwise exploit its + Contributions, either on an unmodified basis, with Modifications, or + as part of a Larger Work; and + +(b) under Patent Claims of such Contributor to make, use, sell, offer + for sale, have made, import, and otherwise transfer either its + Contributions or its Contributor Version. + +2.2. Effective Date + +The licenses granted in Section 2.1 with respect to any Contribution +become effective for each Contribution on the date the Contributor first +distributes such Contribution. + +2.3. Limitations on Grant Scope + +The licenses granted in this Section 2 are the only rights granted under +this License. No additional rights or licenses will be implied from the +distribution or licensing of Covered Software under this License. +Notwithstanding Section 2.1(b) above, no patent license is granted by a +Contributor: + +(a) for any code that a Contributor has removed from Covered Software; + or + +(b) for infringements caused by: (i) Your and any other third party's + modifications of Covered Software, or (ii) the combination of its + Contributions with other software (except as part of its Contributor + Version); or + +(c) under Patent Claims infringed by Covered Software in the absence of + its Contributions. + +This License does not grant any rights in the trademarks, service marks, +or logos of any Contributor (except as may be necessary to comply with +the notice requirements in Section 3.4). + +2.4. Subsequent Licenses + +No Contributor makes additional grants as a result of Your choice to +distribute the Covered Software under a subsequent version of this +License (see Section 10.2) or under the terms of a Secondary License (if +permitted under the terms of Section 3.3). + +2.5. Representation + +Each Contributor represents that the Contributor believes its +Contributions are its original creation(s) or it has sufficient rights +to grant the rights to its Contributions conveyed by this License. + +2.6. Fair Use + +This License is not intended to limit any rights You have under +applicable copyright doctrines of fair use, fair dealing, or other +equivalents. + +2.7. Conditions + +Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted +in Section 2.1. + +3. Responsibilities +------------------- + +3.1. Distribution of Source Form + +All distribution of Covered Software in Source Code Form, including any +Modifications that You create or to which You contribute, must be under +the terms of this License. You must inform recipients that the Source +Code Form of the Covered Software is governed by the terms of this +License, and how they can obtain a copy of this License. You may not +attempt to alter or restrict the recipients' rights in the Source Code +Form. + +3.2. Distribution of Executable Form + +If You distribute Covered Software in Executable Form then: + +(a) such Covered Software must also be made available in Source Code + Form, as described in Section 3.1, and You must inform recipients of + the Executable Form how they can obtain a copy of such Source Code + Form by reasonable means in a timely manner, at a charge no more + than the cost of distribution to the recipient; and + +(b) You may distribute such Executable Form under the terms of this + License, or sublicense it under different terms, provided that the + license for the Executable Form does not attempt to limit or alter + the recipients' rights in the Source Code Form under this License. + +3.3. Distribution of a Larger Work + +You may create and distribute a Larger Work under terms of Your choice, +provided that You also comply with the requirements of this License for +the Covered Software. If the Larger Work is a combination of Covered +Software with a work governed by one or more Secondary Licenses, and the +Covered Software is not Incompatible With Secondary Licenses, this +License permits You to additionally distribute such Covered Software +under the terms of such Secondary License(s), so that the recipient of +the Larger Work may, at their option, further distribute the Covered +Software under the terms of either this License or such Secondary +License(s). + +3.4. Notices + +You may not remove or alter the substance of any license notices +(including copyright notices, patent notices, disclaimers of warranty, +or limitations of liability) contained within the Source Code Form of +the Covered Software, except that You may alter any license notices to +the extent required to remedy known factual inaccuracies. + +3.5. Application of Additional Terms + +You may choose to offer, and to charge a fee for, warranty, support, +indemnity or liability obligations to one or more recipients of Covered +Software. However, You may do so only on Your own behalf, and not on +behalf of any Contributor. You must make it absolutely clear that any +such warranty, support, indemnity, or liability obligation is offered by +You alone, and You hereby agree to indemnify every Contributor for any +liability incurred by such Contributor as a result of warranty, support, +indemnity or liability terms You offer. You may include additional +disclaimers of warranty and limitations of liability specific to any +jurisdiction. + +4. Inability to Comply Due to Statute or Regulation +--------------------------------------------------- + +If it is impossible for You to comply with any of the terms of this +License with respect to some or all of the Covered Software due to +statute, judicial order, or regulation then You must: (a) comply with +the terms of this License to the maximum extent possible; and (b) +describe the limitations and the code they affect. Such description must +be placed in a text file included with all distributions of the Covered +Software under this License. Except to the extent prohibited by statute +or regulation, such description must be sufficiently detailed for a +recipient of ordinary skill to be able to understand it. + +5. Termination +-------------- + +5.1. The rights granted under this License will terminate automatically +if You fail to comply with any of its terms. However, if You become +compliant, then the rights granted under this License from a particular +Contributor are reinstated (a) provisionally, unless and until such +Contributor explicitly and finally terminates Your grants, and (b) on an +ongoing basis, if such Contributor fails to notify You of the +non-compliance by some reasonable means prior to 60 days after You have +come back into compliance. Moreover, Your grants from a particular +Contributor are reinstated on an ongoing basis if such Contributor +notifies You of the non-compliance by some reasonable means, this is the +first time You have received notice of non-compliance with this License +from such Contributor, and You become compliant prior to 30 days after +Your receipt of the notice. + +5.2. If You initiate litigation against any entity by asserting a patent +infringement claim (excluding declaratory judgment actions, +counter-claims, and cross-claims) alleging that a Contributor Version +directly or indirectly infringes any patent, then the rights granted to +You by any and all Contributors for the Covered Software under Section +2.1 of this License shall terminate. + +5.3. In the event of termination under Sections 5.1 or 5.2 above, all +end user license agreements (excluding distributors and resellers) which +have been validly granted by You or Your distributors under this License +prior to termination shall survive termination. + +************************************************************************ +* * +* 6. Disclaimer of Warranty * +* ------------------------- * +* * +* Covered Software is provided under this License on an "as is" * +* basis, without warranty of any kind, either expressed, implied, or * +* statutory, including, without limitation, warranties that the * +* Covered Software is free of defects, merchantable, fit for a * +* particular purpose or non-infringing. The entire risk as to the * +* quality and performance of the Covered Software is with You. * +* Should any Covered Software prove defective in any respect, You * +* (not any Contributor) assume the cost of any necessary servicing, * +* repair, or correction. This disclaimer of warranty constitutes an * +* essential part of this License. No use of any Covered Software is * +* authorized under this License except under this disclaimer. * +* * +************************************************************************ + +************************************************************************ +* * +* 7. Limitation of Liability * +* -------------------------- * +* * +* Under no circumstances and under no legal theory, whether tort * +* (including negligence), contract, or otherwise, shall any * +* Contributor, or anyone who distributes Covered Software as * +* permitted above, be liable to You for any direct, indirect, * +* special, incidental, or consequential damages of any character * +* including, without limitation, damages for lost profits, loss of * +* goodwill, work stoppage, computer failure or malfunction, or any * +* and all other commercial damages or losses, even if such party * +* shall have been informed of the possibility of such damages. This * +* limitation of liability shall not apply to liability for death or * +* personal injury resulting from such party's negligence to the * +* extent applicable law prohibits such limitation. Some * +* jurisdictions do not allow the exclusion or limitation of * +* incidental or consequential damages, so this exclusion and * +* limitation may not apply to You. * +* * +************************************************************************ + +8. Litigation +------------- + +Any litigation relating to this License may be brought only in the +courts of a jurisdiction where the defendant maintains its principal +place of business and such litigation shall be governed by laws of that +jurisdiction, without reference to its conflict-of-law provisions. +Nothing in this Section shall prevent a party's ability to bring +cross-claims or counter-claims. + +9. Miscellaneous +---------------- + +This License represents the complete agreement concerning the subject +matter hereof. If any provision of this License is held to be +unenforceable, such provision shall be reformed only to the extent +necessary to make it enforceable. Any law or regulation which provides +that the language of a contract shall be construed against the drafter +shall not be used to construe this License against a Contributor. + +10. Versions of the License +--------------------------- + +10.1. New Versions + +Mozilla Foundation is the license steward. Except as provided in Section +10.3, no one other than the license steward has the right to modify or +publish new versions of this License. Each version will be given a +distinguishing version number. + +10.2. Effect of New Versions + +You may distribute the Covered Software under the terms of the version +of the License under which You originally received the Covered Software, +or under the terms of any subsequent version published by the license +steward. + +10.3. Modified Versions + +If you create software not governed by this License, and you want to +create a new license for such software, you may create and use a +modified version of this License if you rename the license and remove +any references to the name of the license steward (except to note that +such modified license differs from this License). + +10.4. Distributing Source Code Form that is Incompatible With Secondary +Licenses + +If You choose to distribute Source Code Form that is Incompatible With +Secondary Licenses under the terms of this version of the License, the +notice described in Exhibit B of this License must be attached. + +Exhibit A - Source Code Form License Notice +------------------------------------------- + + This Source Code Form is subject to the terms of the Mozilla Public + License, v. 2.0. If a copy of the MPL was not distributed with this + file, You can obtain one at http://mozilla.org/MPL/2.0/. + +If it is not possible or desirable to put the notice in a particular +file, then You may include the notice in a location (such as a LICENSE +file in a relevant directory) where a recipient would be likely to look +for such a notice. + +You may add additional accurate notices of copyright ownership. + +Exhibit B - "Incompatible With Secondary Licenses" Notice +--------------------------------------------------------- + + This Source Code Form is "Incompatible With Secondary Licenses", as + defined by the Mozilla Public License, v. 2.0. \ No newline at end of file diff --git a/pyscf_ipu/direct/exchange_correlation/b3lyp.py b/pyscf_ipu/direct/exchange_correlation/b3lyp.py new file mode 100644 index 00000000..77a3ad01 --- /dev/null +++ b/pyscf_ipu/direct/exchange_correlation/b3lyp.py @@ -0,0 +1,141 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +import jax.numpy as jnp +import jax +import numpy as np + +def __lyp(n, gnn): + # precompute + A = 0.04918 + B = 0.132 + C = 0.2533 + Dd = 0.349 + CF = 0.3 * (3.0 * np.pi * np.pi) ** (2.0 / 3.0) + c0 = 2.0 ** (11.0 / 3.0) * (1/2)**(8/3) + c1 = (1/3 + 1/8)*4 + + # actual compute + log_n = jnp.log(n) + icbrtn = jnp.exp(log_n * (-1.0 / 3.0) ) + + P = 1.0 / (1.0 + Dd * icbrtn) + omega = jnp.exp(-C * icbrtn) * P + delta = icbrtn * (C + Dd * P) + + n_five_three = jnp.exp(log_n*(-5/3)) + + result = -A * ( + n * P + + B + * omega + * 1/ 4 *( + 2 * CF * n * c0+ + gnn * (60 - 14.0 * delta) /36 * n_five_three + - gnn *c1 * n_five_three + ) + ) + + return result + +def __vwn(n): + # Precompute stuff in np.float64 + p = np.array( [-0.10498, 0.0621813817393097900698817274255, 3.72744, 12.9352]) + f = p[0] * p[2] / (p[0] * p[0] + p[0] * p[2] + p[3]) - 1.0 + f_inv_p1 = 1/f+1 + f_2 = f * 0.5 + sqrt = np.sqrt(4.0 * p[3] - p[2] * p[2]) + precompute = p[2] * ( 1.0 / sqrt + - p[0] + / ( + (p[0] * p[0] + p[0] * p[2] + p[3]) + * sqrt + / (p[2] + 2.0 * p[0]) + ) + ) + log_s_c = np.log( 3.0 /(4*np.pi) ) / 6 + + # Below cast to same dtype as input (allow easier comparison between f32/f64). + dtype = n.dtype + p = p.astype(dtype) + f = f.astype(dtype) + f_inv_p1 = (f_inv_p1).astype(dtype) + f_2 = f_2.astype(dtype) + sqrt = sqrt.astype(dtype) + precompute = precompute.astype(dtype) + log_s_c =log_s_c.astype(dtype) + + # compute stuff that depends on n + log_s = - jnp.log(n) / 6 + log_s_c + s_2 = jnp.exp( log_s *2) + s = jnp.exp( log_s ) + z = sqrt / (2.0 * s + p[2]) + + result = n * p[1] * ( + log_s + #+ f * jnp.log( jnp.sqrt( s_2 + p[2] * s + p[3] ) / (s-p[0])**(1/f+1) ) # problem with float, 1/f+1 was done in np which automatically sticks to float64 + + f * jnp.log( jnp.sqrt( s_2 + p[2] * s + p[3] ) / (s-p[0])**(f_inv_p1) ) + + precompute * jnp.arctan(z) + + ) + + return result + +def __b88(a, gaa): + # precompute + c1 = (4.0 / 3.0) + c2 = (-8.0 / 3.0) + c3 = (-3.0 / 4.0) * (6.0 / np.pi) ** (1.0 / 3.0) * 2 + d = 0.0042 + d2 = d * 2. + d12 = d *12. + + # actual compute + log_a = jnp.log(a/2) + na43 = jnp.exp(log_a * c1) + chi2 = gaa / 4* jnp.exp(log_a * c2 ) + chi = jnp.exp(jnp.log( chi2 ) / 2 ) + b88 = -(d * na43 * chi2) / (1.0 + 6*d * chi * jnp.arcsinh(chi)) *2 + slaterx_a = c3 * na43 + return slaterx_a + b88 + +def __lda(rho): return -jnp.exp(1/3*jnp.log(rho) - 0.30305460484554375) + +CLIP_RHO_MIN = 1e-10 +CLIP_RHO_MAX = 1e15 + +def _b3lyp(rho, EPSILON_B3LYP=0): + rho0 = jnp.clip(rho[0], CLIP_RHO_MIN, CLIP_RHO_MAX) + norms = jnp.linalg.norm(rho[1:4]*2+CLIP_RHO_MIN, axis=0).T**2+EPSILON_B3LYP + return __lda(rho0)*0.08 + (__vwn(rho0)*0.19 + __b88(rho0, norms)*0.72 + __lyp(rho0, norms)*0.81) / rho0 + +def b3lyp(rho0, norms, EPSILON_B3LYP=0): + return __lda(rho0)*0.08 + (__vwn(rho0)*0.19 + __b88(rho0, norms)*0.72 + __lyp(rho0, norms)*0.81) / rho0 + + +def vxc_b3lyp(rho, EPSILON_B3LYP=0): + rho = jnp.concatenate([jnp.clip(rho[:1], CLIP_RHO_MIN, CLIP_RHO_MAX), rho[1:4]*2]) + + rho0 = rho.T[:, 0] + norms = jnp.linalg.norm(rho[1:], axis=0).T**2+EPSILON_B3LYP + + def lda(rho0): return jax.vmap(jax.value_and_grad(lambda x: __lda(x)*0.08)) (rho0) + def vwn(rho0): return jax.vmap(jax.value_and_grad(lambda x: __vwn(x)*0.19)) (rho0) + + # disabled gradient checkpointing + #def b88(rho0, norms): return jax.vmap(jax.value_and_grad(lambda rho0, norm: jax.checkpoint(__b88)(rho0, norm)*0.72, (0, 1))) (rho0, norms) + #def lyp(rho0, norms): return jax.vmap(jax.value_and_grad(lambda rho0, norm: jax.checkpoint(__lyp)(rho0, norm)*0.810, (0, 1))) (rho0, norms) + + def b88(rho0, norms): return jax.vmap(jax.value_and_grad(lambda rho0, norm: __b88(rho0, norm)*0.72, (0,1)))(rho0, norms) + def lyp(rho0, norms): return jax.vmap(jax.value_and_grad(lambda rho0, norm: __lyp(rho0, norm)*0.810, (0,1)))(rho0, norms) + + e_xc_lda, v_rho_lda = jax.jit(lda)(rho0) + e_xc_vwn, v_rho_vwn = jax.jit(vwn)(rho0) + e_xc_b88, (v_rho_b88, v_norm_b88) = jax.jit(b88)(rho0, norms) + e_xc_lyp, (v_rho_lyp, v_norm_lyp) = jax.jit(lyp)(rho0, norms) + + e_xc = e_xc_lda + (e_xc_vwn + e_xc_b88 + e_xc_lyp) / rho0 + v_xc_rho = v_rho_lda*4*rho0 + v_rho_vwn + v_rho_b88 + v_rho_lyp + v_xc_norms = v_norm_b88 + v_norm_lyp + + return e_xc, v_xc_rho, v_xc_norms + + diff --git a/pyscf_ipu/direct/inference_heatmap_plot.py b/pyscf_ipu/direct/inference_heatmap_plot.py new file mode 100644 index 00000000..cd86d5b5 --- /dev/null +++ b/pyscf_ipu/direct/inference_heatmap_plot.py @@ -0,0 +1,303 @@ +import pickle +import jax +jax.config.update('jax_enable_x64', True) +import jax.numpy as jnp +import numpy as np + +HARTREE_TO_EV, EPSILON_B3LYP, HYB_B3LYP = 27.2114079527, 1e-20, 0.2 + +import argparse +parser = argparse.ArgumentParser() +parser.add_argument('-basis', type=str, default="sto3g") +parser.add_argument('-level', type=int, default=0) + +# GD options +parser.add_argument('-backend', type=str, default="cpu") +parser.add_argument('-lr', type=float, default=2.5e-4) +parser.add_argument('-steps', type=int, default=100000) +parser.add_argument('-bs', type=int, default=8) +parser.add_argument('-val_bs', type=int, default=8) +parser.add_argument('-mol_repeats', type=int, default=16) # How many time to optimize wrt each molecule. + +# energy computation speedups +parser.add_argument('-foriloop', action="store_true") # whether to use jax.lax.foriloop for sparse_symmetric_eri (faster compile time but slower training. ) +parser.add_argument('-xc_f32', action="store_true") +parser.add_argument('-eri_f32', action="store_true") +parser.add_argument('-eri_bs', type=int, default=8) + +parser.add_argument('-normal', action="store_true") +parser.add_argument('-wandb', action="store_true") +parser.add_argument('-prof', action="store_true") +parser.add_argument('-visualize', action="store_true") +parser.add_argument('-skip', action="store_true", help="skip pyscf test case") + +# dataset +parser.add_argument('-qm9', action="store_true") +parser.add_argument('-benzene', action="store_true") +parser.add_argument('-hydrogens', action="store_true") +parser.add_argument('-water', action="store_true") +parser.add_argument('-waters', action="store_true") +parser.add_argument('-alanine', action="store_true") +parser.add_argument('-states', type=int, default=1) +parser.add_argument('-workers', type=int, default=5) +parser.add_argument('-precompute', action="store_true") # precompute labels; only run once for data{set/augmentation}. + # do noise schedule, start small slowly increase +parser.add_argument('-wiggle_var', type=float, default=0.05, help="wiggle N(0, wiggle_var), bondlength=1.5/30") +parser.add_argument('-eri_threshold', type=float, default=1e-10, help="loss function threshold only") +parser.add_argument('-rotate_deg', type=float, default=90, help="how many degrees to rotate") + +# models +parser.add_argument('-nn', action="store_true", help="train nn, defaults to GD") +parser.add_argument('-tiny', action="store_true") +parser.add_argument('-small', action="store_true") +parser.add_argument('-base', action="store_true") +parser.add_argument('-medium', action="store_true") +parser.add_argument('-large', action="store_true") +parser.add_argument('-xlarge', action="store_true") + +parser.add_argument("-checkpoint", default=-1, type=int, help="which iteration to save model (default -1 = no saving)") # checkpoint model +parser.add_argument("-resume", default="", help="path to checkpoint pickle file") # checkpoint model + +# inference heatmap plot args +parser.add_argument("-heatmap_step", type=int, default=10) +parser.add_argument("-plot_range", type=int, default=360) +opts = parser.parse_args() + +assert opts.val_bs * opts.heatmap_step == opts.plot_range, "[Temporary dependency] Try adjusting VAL_BS and HEATMAP_STEP so that their product is equal to PLOT_RANGE (by default 360)" + +if opts.tiny or opts.small or opts.base or opts.large or opts.xlarge: opts.nn = True + +if opts.alanine: + mol_str = [[ # 22 atoms (12 hydrogens) => 10 heavy atoms (i.e. larger than QM9). + ["H", ( 2.000 , 1.000, -0.000)], + ["C", ( 2.000 , 2.090, 0.000)], + ["H", ( 1.486 , 2.454, 0.890)], + ["H", ( 1.486 , 2.454, -0.890)], + ["C", ( 3.427 , 2.641, -0.000)], + ["O", ( 4.391 , 1.877, -0.000)], + ["N", ( 3.555 , 3.970, -0.000)], + ["H", ( 2.733 , 4.556, -0.000)], + ["C", ( 4.853 , 4.614, -0.000)], # carbon alpha + ["H", ( 5.408 , 4.316, 0.890)], # hydrogne attached to carbon alpha + ["C", ( 5.661 , 4.221, -1.232)], # carbon beta + ["H", ( 5.123 , 4.521, -2.131)], # hydrogens attached to carbon beta + ["H", ( 6.630 , 4.719, -1.206)], # hydrogens attached to carbon beta + ["H", ( 5.809 , 3.141, -1.241)], # hydrogens attached to carbon beta + ["C", ( 4.713 , 6.129, 0.000)], + ["O", ( 3.601 , 6.653, 0.000)], + ["N", ( 5.846 , 6.835, 0.000)], + ["H", ( 6.737 , 6.359, -0.000)], + ["C", ( 5.846 , 8.284, 0.000)], + ["H", ( 4.819 , 8.648, 0.000)], + ["H", ( 6.360 , 8.648, 0.890)], + ["H", ( 6.360 , 8.648, -0.890)], + ]] + +B, BxNxN, BxNxK = None, None, None +cfg = None +from train import dm_energy + +from transformer import transformer_init +from train import nao +# global cfg +'''Model ViT model embedding #heads #layers #params training throughput +dimension resolution (im/sec) +DeiT-Ti N/A 192 3 12 5M 224 2536 +DeiT-S N/A 384 6 12 22M 224 940 +DeiT-B ViT-B 768 12 12 86M 224 292 +Parameters Layers dmodel +117M 12 768 +345M 24 1024 +762M 36 1280 +1542M 48 1600 +''' +if opts.tiny: # 5M + d_model= 192 + n_heads = 6 + n_layers = 12 +if opts.small: + d_model= 384 + n_heads = 6 + n_layers = 12 +if opts.base: + d_model= 768 + n_heads = 12 + n_layers = 12 +if opts.medium: + d_model= 1024 + n_heads = 16 + n_layers = 24 +if opts.large: + d_model= 1280 + n_heads = 16 + n_layers = 36 +if opts.xlarge: + d_model= 1600 + n_heads = 25 + n_layers = 48 + +if opts.nn: + rnd_key = jax.random.PRNGKey(42) + n_vocab = nao("C", opts.basis) + nao("N", opts.basis) + \ + nao("O", opts.basis) + nao("F", opts.basis) + \ + nao("H", opts.basis) + rnd_key, cfg, params, total_params = transformer_init( + rnd_key, + n_vocab, + d_model =d_model, + n_layers=n_layers, + n_heads =n_heads, + d_ff =d_model*4, + ) + +# vandg = jax.jit(jax.value_and_grad(dm_energy, has_aux=True), backend=opts.backend, static_argnames=("normal", 'nn')) +valf = jax.jit(dm_energy, backend=opts.backend, static_argnames=("normal", 'nn', "cfg", "opts")) + +from train import batched_state +from torch.utils.data import DataLoader, Dataset +class OnTheFlyQM9(Dataset): + # prepares dft tensors with pyscf "on the fly". + # dataloader is very keen on throwing segfaults (e.g. using jnp in dataloader throws segfaul). + # problem: second epoch always gives segfault. + # hacky fix; make __len__ = real_length*num_epochs and __getitem__ do idx%real_num_examples + def __init__(self, opts, nao=294, train=True, num_epochs=10**9, extrapolate=False, init_phi_psi = None): + # only take molecules with use {CNOFH}, nao=nao and spin=0. + import pandas as pd + df = pd.read_pickle("alchemy/processed_atom_9.pickle") # spin=0 and only CNOFH molecules + if nao != -1: df = df[df["nao"]==nao] + # df.sample is not deterministic; moved to pre-processing, so file is shuffled already. + # this shuffling is important, because it makes the last 10 samples iid (used for validation) + #df = df.sample(frac=1).reset_index(drop=True) # is this deterministic? + + if train: self.mol_strs = df["pyscf"].values[:-10] + else: self.mol_strs = df["pyscf"].values[-10:] + #print(df["pyscf"].) # todo: print smile strings + + self.num_epochs = num_epochs + self.opts = opts + self.validation = not train + self.extrapolate = extrapolate + self.init_phi_psi = init_phi_psi + + # self.benzene = [ + # ["C", ( 0.0000, 0.0000, 0.0000)], + # ["C", ( 1.4000, 0.0000, 0.0000)], + # ["C", ( 2.1000, 1.2124, 0.0000)], + # ["C", ( 1.4000, 2.4249, 0.0000)], + # ["C", ( 0.0000, 2.4249, 0.0000)], + # ["C", (-0.7000, 1.2124, 0.0000)], + # ["H", (-0.5500, -0.9526, 0.0000)], + # ["H", (-0.5500, 3.3775, 0.0000)], + # ["H", ( 1.9500, -0.9526, 0.0000)], + # ["H", (-1.8000, 1.2124, 0.0000)], + # ["H", ( 3.2000, 1.2124, 0.0000)], + # ["H", ( 1.9500, 3.3775, 0.0000)] + # ] + # self.waters = [ + # ["O", (-1.464, 0.099, 0.300)], + # ["H", (-1.956, 0.624, -0.340)], + # ["H", (-1.797, -0.799, 0.206)], + # ["O", ( 1.369, 0.146, -0.395)], + # ["H", ( 1.894, 0.486, 0.335)], + # ["H", ( 0.451, 0.165, -0.083)] + # ] + + # if opts.benzene: self.mol_strs = [self.benzene] + # if opts.waters: self.mol_strs = [self.waters] + if opts.alanine: self.mol_strs = mol_str + + if train: self.bs = opts.bs + else: self.bs = opts.val_bs + + def __len__(self): + return len(self.mol_strs)*self.num_epochs + + def __getitem__(self, idx): + return batched_state(self.mol_strs[idx%len(self.mol_strs)], self.opts, self.bs, \ + wiggle_num=0, do_pyscf=self.validation or self.extrapolate, validation=False, \ + extrapolate=self.extrapolate, mol_idx=idx, init_phi_psi = self.init_phi_psi, inference=True, inference_psi_step=opts.heatmap_step) + + +print("loading checkpoint") +weights = pickle.load(open("%s_model.pickle"%opts.resume, "rb")) +print("done loading. ") + +# print("loading adam state") +# adam_state = pickle.load(open("%s_adam_state.pickle"%opts.resume, "rb")) +# print("done") + +# weights, adam_state = jax.device_put(weights), jax.device_put(adam_state) +weights = jax.device_put(weights) + +from train import HashableNamespace + +# make `opts` hashable so that JAX will not complain about the static parameter that is passed as arg +opts = HashableNamespace(opts) + +data = [] +pyscf = [] +# data.append((1,1,344)) +# data.append((2,4,323)) +# data.append((3,3,334)) +# data.append((4,2,331)) + +for phi in range(0, opts.plot_range, opts.heatmap_step): + for psi in range(0, opts.plot_range, opts.val_bs * opts.heatmap_step): + val_qm9 = OnTheFlyQM9(opts, train=False, init_phi_psi=(phi, psi)) + val_state = jax.device_put(val_qm9[0]) + # print("\n^^^^^^^^^^^\nJUST VAL QM9 [0]:", val_qm9[0]) + # print("WHOLE VAL QM9:", val_qm9) + print("VAL_QM9[0].pyscf_E:", val_qm9[0].pyscf_E) + _, (valid_vals, _, vdensity_matrix, vW) = valf(weights, val_state, opts.normal, opts.nn, cfg, opts) + + valid_l = np.abs(valid_vals*HARTREE_TO_EV-val_state.pyscf_E) + valid_E = np.abs(valid_vals*HARTREE_TO_EV) + + print("valid_l: ", valid_l, "\nvalid_E: ", valid_E, "\nphi ", phi, " psi ", psi) + + for i in range(0, opts.val_bs): + data.append((phi, psi + i * opts.heatmap_step, valid_E[i])) + pyscf.append((phi, psi + i * opts.heatmap_step, val_state.pyscf_E[i].item())) + + # data.append((phi, psi, valid_E[0])) + +#data = np.log(np.abs(data)) +import matplotlib.pyplot as plt +from scipy.interpolate import griddata +# Extract phi, psi, and values from the data +phi_values, psi_values, heatmap_values = zip(*data) + +# Define a grid +phi_grid, psi_grid = np.meshgrid(np.linspace(min(phi_values), max(phi_values), 100), + np.linspace(min(psi_values), max(psi_values), 100)) +# Interpolate values on the grid +heatmap_interpolated = griddata((phi_values, psi_values), heatmap_values, (phi_grid, psi_grid), method='cubic', fill_value=0) + + +# Create a filled contour plot +plt.contourf(psi_grid, phi_grid, heatmap_interpolated, cmap='viridis', levels=100) +plt.colorbar(label='Intensity') + +# Set axis labels and title +plt.xlabel('Psi Angle') +plt.ylabel('Phi Angle') +plt.title('2D Heatmap with Interpolation') + +# Save the plot to a PNG file +plt.savefig('heatmap_plot.png') + +# Show the plot +plt.show() + +import pickle + +print("DATA ML", data) +print("DATA PYSCF", pyscf) +# Save data to a pickle file +with open('heatmap_data.pkl', 'wb') as file: + pickle.dump(data, file) + + +# Save pyscf to a pickle file +with open('heatmap_pyscf.pkl', 'wb') as file: + pickle.dump(pyscf, file) \ No newline at end of file diff --git a/pyscf_ipu/direct/inference_heatmap_plot_small_bs.py b/pyscf_ipu/direct/inference_heatmap_plot_small_bs.py new file mode 100644 index 00000000..6bcbfa9d --- /dev/null +++ b/pyscf_ipu/direct/inference_heatmap_plot_small_bs.py @@ -0,0 +1,311 @@ +import pickle +import jax +jax.config.update('jax_enable_x64', True) +import jax.numpy as jnp +import numpy as np + +HARTREE_TO_EV, EPSILON_B3LYP, HYB_B3LYP = 27.2114079527, 1e-20, 0.2 + +import argparse +parser = argparse.ArgumentParser() +parser.add_argument('-basis', type=str, default="sto3g") +parser.add_argument('-level', type=int, default=0) + +# GD options +parser.add_argument('-backend', type=str, default="cpu") +parser.add_argument('-lr', type=float, default=2.5e-4) +parser.add_argument('-steps', type=int, default=100000) +parser.add_argument('-bs', type=int, default=8) +parser.add_argument('-val_bs', type=int, default=8) +parser.add_argument('-mol_repeats', type=int, default=16) # How many time to optimize wrt each molecule. + +# energy computation speedups +parser.add_argument('-foriloop', action="store_true") # whether to use jax.lax.foriloop for sparse_symmetric_eri (faster compile time but slower training. ) +parser.add_argument('-xc_f32', action="store_true") +parser.add_argument('-eri_f32', action="store_true") +parser.add_argument('-eri_bs', type=int, default=8) + +parser.add_argument('-normal', action="store_true") +parser.add_argument('-wandb', action="store_true") +parser.add_argument('-prof', action="store_true") +parser.add_argument('-visualize', action="store_true") +parser.add_argument('-skip', action="store_true", help="skip pyscf test case") + +# dataset +parser.add_argument('-qm9', action="store_true") +parser.add_argument('-benzene', action="store_true") +parser.add_argument('-hydrogens', action="store_true") +parser.add_argument('-water', action="store_true") +parser.add_argument('-waters', action="store_true") +parser.add_argument('-alanine', action="store_true") +parser.add_argument('-states', type=int, default=1) +parser.add_argument('-workers', type=int, default=5) +parser.add_argument('-precompute', action="store_true") # precompute labels; only run once for data{set/augmentation}. + # do noise schedule, start small slowly increase +parser.add_argument('-wiggle_var', type=float, default=0.05, help="wiggle N(0, wiggle_var), bondlength=1.5/30") +parser.add_argument('-eri_threshold', type=float, default=1e-10, help="loss function threshold only") +parser.add_argument('-rotate_deg', type=float, default=90, help="how many degrees to rotate") + +# models +parser.add_argument('-nn', action="store_true", help="train nn, defaults to GD") +parser.add_argument('-tiny', action="store_true") +parser.add_argument('-small', action="store_true") +parser.add_argument('-base', action="store_true") +parser.add_argument('-medium', action="store_true") +parser.add_argument('-large', action="store_true") +parser.add_argument('-xlarge', action="store_true") + +parser.add_argument("-checkpoint", default=-1, type=int, help="which iteration to save model (default -1 = no saving)") # checkpoint model +parser.add_argument("-resume", default="", help="path to checkpoint pickle file") # checkpoint model + +# inference heatmap plot args +parser.add_argument("-heatmap_step", type=int, default=10) +parser.add_argument("-plot_range", type=int, default=360) +opts = parser.parse_args() + +# assert opts.val_bs * opts.heatmap_step == opts.plot_range, "[Temporary dependency] Try adjusting VAL_BS and HEATMAP_STEP so that their product is equal to PLOT_RANGE (by default 360)" +assert (opts.plot_range % (opts.val_bs * opts.heatmap_step)) == 0, "batch * step will not fit within the range with integer number of subranges" +if opts.tiny or opts.small or opts.base or opts.large or opts.xlarge: opts.nn = True + +if opts.alanine: + mol_str = [[ # 22 atoms (12 hydrogens) => 10 heavy atoms (i.e. larger than QM9). + ["H", ( 2.000 , 1.000, -0.000)], + ["C", ( 2.000 , 2.090, 0.000)], + ["H", ( 1.486 , 2.454, 0.890)], + ["H", ( 1.486 , 2.454, -0.890)], + ["C", ( 3.427 , 2.641, -0.000)], + ["O", ( 4.391 , 1.877, -0.000)], + ["N", ( 3.555 , 3.970, -0.000)], + ["H", ( 2.733 , 4.556, -0.000)], + ["C", ( 4.853 , 4.614, -0.000)], # carbon alpha + ["H", ( 5.408 , 4.316, 0.890)], # hydrogne attached to carbon alpha + ["C", ( 5.661 , 4.221, -1.232)], # carbon beta + ["H", ( 5.123 , 4.521, -2.131)], # hydrogens attached to carbon beta + ["H", ( 6.630 , 4.719, -1.206)], # hydrogens attached to carbon beta + ["H", ( 5.809 , 3.141, -1.241)], # hydrogens attached to carbon beta + ["C", ( 4.713 , 6.129, 0.000)], + ["O", ( 3.601 , 6.653, 0.000)], + ["N", ( 5.846 , 6.835, 0.000)], + ["H", ( 6.737 , 6.359, -0.000)], + ["C", ( 5.846 , 8.284, 0.000)], + ["H", ( 4.819 , 8.648, 0.000)], + ["H", ( 6.360 , 8.648, 0.890)], + ["H", ( 6.360 , 8.648, -0.890)], + ]] + +B, BxNxN, BxNxK = None, None, None +cfg = None +from train import dm_energy + +from transformer import transformer_init +from train import nao +# global cfg +'''Model ViT model embedding #heads #layers #params training throughput +dimension resolution (im/sec) +DeiT-Ti N/A 192 3 12 5M 224 2536 +DeiT-S N/A 384 6 12 22M 224 940 +DeiT-B ViT-B 768 12 12 86M 224 292 +Parameters Layers dmodel +117M 12 768 +345M 24 1024 +762M 36 1280 +1542M 48 1600 +''' +if opts.tiny: # 5M + d_model= 192 + n_heads = 6 + n_layers = 12 +if opts.small: + d_model= 384 + n_heads = 6 + n_layers = 12 +if opts.base: + d_model= 768 + n_heads = 12 + n_layers = 12 +if opts.medium: + d_model= 1024 + n_heads = 16 + n_layers = 24 +if opts.large: + d_model= 1280 + n_heads = 16 + n_layers = 36 +if opts.xlarge: + d_model= 1600 + n_heads = 25 + n_layers = 48 + +if opts.nn: + rnd_key = jax.random.PRNGKey(42) + n_vocab = nao("C", opts.basis) + nao("N", opts.basis) + \ + nao("O", opts.basis) + nao("F", opts.basis) + \ + nao("H", opts.basis) + rnd_key, cfg, params, total_params = transformer_init( + rnd_key, + n_vocab, + d_model =d_model, + n_layers=n_layers, + n_heads =n_heads, + d_ff =d_model*4, + ) + +# vandg = jax.jit(jax.value_and_grad(dm_energy, has_aux=True), backend=opts.backend, static_argnames=("normal", 'nn')) +valf = jax.jit(dm_energy, backend=opts.backend, static_argnames=("normal", 'nn', "cfg", "opts")) + +from train import batched_state +from torch.utils.data import DataLoader, Dataset +class OnTheFlyQM9(Dataset): + # prepares dft tensors with pyscf "on the fly". + # dataloader is very keen on throwing segfaults (e.g. using jnp in dataloader throws segfaul). + # problem: second epoch always gives segfault. + # hacky fix; make __len__ = real_length*num_epochs and __getitem__ do idx%real_num_examples + def __init__(self, opts, nao=294, train=True, num_epochs=10**9, extrapolate=False, init_phi_psi = None): + # only take molecules with use {CNOFH}, nao=nao and spin=0. + import pandas as pd + df = pd.read_pickle("alchemy/processed_atom_9.pickle") # spin=0 and only CNOFH molecules + if nao != -1: df = df[df["nao"]==nao] + # df.sample is not deterministic; moved to pre-processing, so file is shuffled already. + # this shuffling is important, because it makes the last 10 samples iid (used for validation) + #df = df.sample(frac=1).reset_index(drop=True) # is this deterministic? + + if train: self.mol_strs = df["pyscf"].values[:-10] + else: self.mol_strs = df["pyscf"].values[-10:] + #print(df["pyscf"].) # todo: print smile strings + + self.num_epochs = num_epochs + self.opts = opts + self.validation = not train + self.extrapolate = extrapolate + self.init_phi_psi = init_phi_psi + + # self.benzene = [ + # ["C", ( 0.0000, 0.0000, 0.0000)], + # ["C", ( 1.4000, 0.0000, 0.0000)], + # ["C", ( 2.1000, 1.2124, 0.0000)], + # ["C", ( 1.4000, 2.4249, 0.0000)], + # ["C", ( 0.0000, 2.4249, 0.0000)], + # ["C", (-0.7000, 1.2124, 0.0000)], + # ["H", (-0.5500, -0.9526, 0.0000)], + # ["H", (-0.5500, 3.3775, 0.0000)], + # ["H", ( 1.9500, -0.9526, 0.0000)], + # ["H", (-1.8000, 1.2124, 0.0000)], + # ["H", ( 3.2000, 1.2124, 0.0000)], + # ["H", ( 1.9500, 3.3775, 0.0000)] + # ] + # self.waters = [ + # ["O", (-1.464, 0.099, 0.300)], + # ["H", (-1.956, 0.624, -0.340)], + # ["H", (-1.797, -0.799, 0.206)], + # ["O", ( 1.369, 0.146, -0.395)], + # ["H", ( 1.894, 0.486, 0.335)], + # ["H", ( 0.451, 0.165, -0.083)] + # ] + + # if opts.benzene: self.mol_strs = [self.benzene] + # if opts.waters: self.mol_strs = [self.waters] + if opts.alanine: self.mol_strs = mol_str + + if train: self.bs = opts.bs + else: self.bs = opts.val_bs + + def __len__(self): + return len(self.mol_strs)*self.num_epochs + + def __getitem__(self, idx): + return batched_state(self.mol_strs[idx%len(self.mol_strs)], self.opts, self.bs, \ + wiggle_num=0, do_pyscf=self.validation or self.extrapolate, validation=False, \ + extrapolate=self.extrapolate, mol_idx=idx, init_phi_psi = self.init_phi_psi, inference=True, inference_psi_step=opts.heatmap_step) + + +print("loading checkpoint") +weights = pickle.load(open("%s_model.pickle"%opts.resume, "rb")) +print("done loading. ") + +# print("loading adam state") +# adam_state = pickle.load(open("%s_adam_state.pickle"%opts.resume, "rb")) +# print("done") + +# weights, adam_state = jax.device_put(weights), jax.device_put(adam_state) +weights = jax.device_put(weights) + +from train import HashableNamespace + +# make `opts` hashable so that JAX will not complain about the static parameter that is passed as arg +opts = HashableNamespace(opts) + +data = [] +pyscf = [] +# data.append((1,1,344)) +# data.append((2,4,323)) +# data.append((3,3,334)) +# data.append((4,2,331)) + +valid_E = None +val_state = None +for phi in range(0, opts.plot_range, opts.heatmap_step): + # psi_start = 0 + # psi_end = psi_start + opts.val_bs * opts.heatmap_step + # while psi_end <= opts.plot_range: + # for psi in range(psi_start, psi_end, opts.heatmap_step): + for psi in range(0, opts.plot_range, opts.val_bs * opts.heatmap_step): + # print(psi, psi_start, psi_end, "<<<<<<<<<<<<<<<<<<") + val_qm9 = OnTheFlyQM9(opts, train=False, init_phi_psi=(phi, psi)) + val_state = jax.device_put(val_qm9[0]) + # print("\n^^^^^^^^^^^\nJUST VAL QM9 [0]:", val_qm9[0]) + # print("WHOLE VAL QM9:", val_qm9) + print("VAL_QM9[0].pyscf_E:", val_qm9[0].pyscf_E) + _, (valid_vals, _, vdensity_matrix, vW) = valf(weights, val_state, opts.normal, opts.nn, cfg, opts) + + valid_l = np.abs(valid_vals*HARTREE_TO_EV-val_state.pyscf_E) + valid_E = np.abs(valid_vals*HARTREE_TO_EV) + + print("valid_l: ", valid_l, "\nvalid_E: ", valid_E, "\nphi ", phi, " psi ", psi) + + for i in range(0, opts.val_bs): + data.append((phi, psi + i * opts.heatmap_step, valid_E[i])) + pyscf.append((phi, psi + i * opts.heatmap_step, val_state.pyscf_E[i].item())) + # psi_start = 0 + psi_end + # psi_end += opts.val_bs * opts.heatmap_step + # data.append((phi, psi, valid_E[0])) + +#data = np.log(np.abs(data)) +import matplotlib.pyplot as plt +from scipy.interpolate import griddata +# Extract phi, psi, and values from the data +phi_values, psi_values, heatmap_values = zip(*data) + +# Define a grid +phi_grid, psi_grid = np.meshgrid(np.linspace(min(phi_values), max(phi_values), 100), + np.linspace(min(psi_values), max(psi_values), 100)) +# Interpolate values on the grid +heatmap_interpolated = griddata((phi_values, psi_values), heatmap_values, (phi_grid, psi_grid), method='cubic', fill_value=0) + + +# Create a filled contour plot +plt.contourf(psi_grid, phi_grid, heatmap_interpolated, cmap='viridis', levels=100) +plt.colorbar(label='Intensity') + +# Set axis labels and title +plt.xlabel('Psi Angle') +plt.ylabel('Phi Angle') +plt.title('2D Heatmap with Interpolation') + +# Save the plot to a PNG file +plt.savefig('heatmap_plot.png') + +# Show the plot +plt.show() + +import pickle + +print("DATA ML", data) +print("DATA PYSCF", pyscf) +# Save data to a pickle file +with open('heatmap_data_bs2.pkl', 'wb') as file: + pickle.dump(data, file) + + +# Save pyscf to a pickle file +with open('heatmap_pyscf_bs2.pkl', 'wb') as file: + pickle.dump(pyscf, file) \ No newline at end of file diff --git a/pyscf_ipu/direct/plot.py b/pyscf_ipu/direct/plot.py new file mode 100644 index 00000000..430dad3a --- /dev/null +++ b/pyscf_ipu/direct/plot.py @@ -0,0 +1,21 @@ +import wandb +from rdkit import Chem +import rdkit +import rdkit.Chem +import rdkit.Chem.AllChem +from rdkit.Geometry.rdGeometry import Point3D +from rdkit.Chem import AllChem +import numpy as np + +def create_rdkit_mol(atom_types, atom_positions): + mol = Chem.RWMol() + for atom_type in atom_types: + atom = Chem.Atom(atom_type) + mol.AddAtom(atom) + conf = Chem.Conformer(len(atom_types)) + for i, pos in enumerate(atom_positions): + if isinstance(pos, np.ndarray): pos = pos.tolist() + point = Point3D(*pos) + conf.SetAtomPosition(i, point) + mol.AddConformer(conf) + return wandb.Molecule.from_rdkit(mol, convert_to_3d_and_optimize=False) diff --git a/pyscf_ipu/direct/plot_heatmap_for_paper.py b/pyscf_ipu/direct/plot_heatmap_for_paper.py new file mode 100644 index 00000000..58fab939 --- /dev/null +++ b/pyscf_ipu/direct/plot_heatmap_for_paper.py @@ -0,0 +1,44 @@ +import pickle +import numpy as np +import matplotlib.pyplot as plt +from scipy.interpolate import griddata + +import argparse +parser = argparse.ArgumentParser() +parser.add_argument('-data_file', type=str) +parser.add_argument('-output_name', type=str, default="default_output.png") +parser.add_argument('-log', type=bool, default=False) +opts = parser.parse_args() + +# Load data from the pickle file +with open(opts.data_file, 'rb') as file: + data_list = pickle.load(file) + + +# Extract phi, psi, and values from the loaded data +phi_values, psi_values, heatmap_values = zip(*data_list) + +if opts.log: + heatmap_values = np.log(np.abs(heatmap_values - np.mean(heatmap_values))) + +print(heatmap_values) +# Create a meshgrid of phi and psi coordinates +phi_coordinates, psi_coordinates = np.meshgrid(np.linspace(min(phi_values), max(phi_values), 100), + np.linspace(min(psi_values), max(psi_values), 100)) + +# Interpolate values on the grid +heatmap_interpolated = griddata((phi_values, psi_values), heatmap_values, (phi_coordinates, psi_coordinates), method='cubic', fill_value=0) + +# Display the 2D matrix as an image +plt.imshow(heatmap_interpolated, cmap='viridis', origin='lower', extent=[min(psi_values), max(psi_values), min(phi_values), max(phi_values)]) +plt.colorbar(label='Intensity') # Add colorbar with label + +# Set axis labels and title +plt.xlabel('Psi Angle') +plt.ylabel('Phi Angle') +plt.title('2D Heatmap from Pickle File') + +# Save the plot to a PNG file +plt.savefig(opts.output_name) + +# Show the plot diff --git a/pyscf_ipu/direct/qm9/download.sh b/pyscf_ipu/direct/qm9/download.sh new file mode 100755 index 00000000..7332f63e --- /dev/null +++ b/pyscf_ipu/direct/qm9/download.sh @@ -0,0 +1,3 @@ +wget https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/gdb9.tar.gz +tar -xvzf gdb9.tar.gz +wget https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/qm9.csv diff --git a/pyscf_ipu/direct/qm9/reproduce.py b/pyscf_ipu/direct/qm9/reproduce.py new file mode 100644 index 00000000..d2b446e3 --- /dev/null +++ b/pyscf_ipu/direct/qm9/reproduce.py @@ -0,0 +1,42 @@ +# reproduce qm9 labels; run download.sh to download qm9 first. +from pyscf import gto, scf, dft, __config__ +import pyscf +import pandas as pd +print(pyscf.__version__) +df = pd.read_csv('qm9.csv') +qm9_energy = df['u0'][0] - df['zpve'][0] +qm9_hlgap = df['gap'][0] + +mol = gto.Mole() +mol.atom = ''' + C -0.0127 1.0858 0.0080 + H 0.0022 -0.0060 0.0020 + H 1.0117 1.4638 0.0003 + H -0.5408 1.4475 -0.8766 + H -0.5238 1.4379 0.9064 +''' +mol.basis = '6-31G(2df,p)' +mol.build() + +# Run B3LYP calculation +method = dft.RKS(mol) +method.verbose = 4 +method.xc = 'B3LYPG' # b3lypG (G as in gaussain) +method.max_cycle = 50 +method.DIIS = pyscf.scf.diis.CDIIS +method.small_rho_cutoff = 1e-10 +method.diis_space = 8 +method.diis_start_cycle = 1 +method.damp = 5e-1 # damping factor +method.conv_tol = 1e-9 +method.conv_tol_grad = None # 1e-9 +method.grids.level = 3 +method.kernel() + +# Get total energy and HOMO-LUMO gap +energy = method.e_tot +homo, lumo = method.mo_energy[method.mo_occ>0].max(), method.mo_energy[method.mo_occ==0].min() +hlgap = lumo - homo + +print('qm9\t %10f %10f'%(qm9_energy, qm9_hlgap)) +print('pyscf\t %10f %10f'%( energy, hlgap)) diff --git a/pyscf_ipu/direct/sparse_symmetric_ERI.py b/pyscf_ipu/direct/sparse_symmetric_ERI.py new file mode 100644 index 00000000..24ba84f6 --- /dev/null +++ b/pyscf_ipu/direct/sparse_symmetric_ERI.py @@ -0,0 +1,308 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +import pyscf +import numpy as np +import jax +jax.config.update('jax_enable_x64', True) +import jax.numpy as jnp +from functools import partial +from icecream import ic +HYB_B3LYP = 0.2 + +def get_i_j(val): + i = (np.sqrt(1 + 8*val.astype(np.uint64)) - 1)//2 # no need for floor, integer division acts as floor. + j = (((val - i) - (i**2 - val))//2) + return i, j + +def _ijkl(value, symmetry, N, f): + #i, j, k, l = value[0].astype(np.uint32), value[1].astype(np.uint32), value[2].astype(np.uint32), value[3].astype(np.uint32) + i, j, k, l = value[0], value[1], value[2], value[3] + return f(i,j,k,l,symmetry,N) +ijkl = jax.vmap(_ijkl, in_axes=(0, None, None, None)) + +def np_ijkl(value, symmetry, N, f): + #i, j, k, l = value[0].astype(np.uint32), value[1].astype(np.uint32), value[2].astype(np.uint32), value[3].astype(np.uint32) + i, j, k, l = value[:, 0], value[:, 1], value[:, 2], value[:, 3] + return f(i,j,k,l,symmetry,N) + + +def num_repetitions_fast(ij, kl): + i, j = get_i_j(ij) + k, l = get_i_j(kl) + + # compute: repetitions = 2^((i==j) + (k==l) + (k==i and l==j or k==j and l==i)) + repetitions = 2**( + np.equal(i,j).astype(np.uint64) + + np.equal(k,l).astype(np.uint64) + + (1 - ((1 - np.equal(k,i) * np.equal(l,j)) * + (1- np.equal(k,j) * np.equal(l,i))).astype(np.uint64)) + ) + return repetitions + +indices_func = lambda i,j,k,l,symmetry,N: jnp.array([i*N+j, j*N+i, i*N+j, j*N+i, k*N+l, l*N+k, k*N+l, l*N+k, + k*N+l, k*N+l, l*N+k, l*N+k, i*N+j, i*N+j, j*N+i, j*N+i, + k*N+j, k*N+i, l*N+j, l*N+i, i*N+l, i*N+k, j*N+l, j*N+k, + i*N+l, j*N+l, i*N+k, j*N+k, k*N+j, l*N+j, k*N+i, l*N+i])[symmetry] + +def _indices_func(i, j, k, l, symmetry, N): + if symmetry == 0: return i * N + j + elif symmetry == 1: return j * N + i + elif symmetry == 2: return i * N + j + elif symmetry == 3: return j * N + i + elif symmetry == 4: return k * N + l + elif symmetry == 5: return l * N + k + elif symmetry == 6: return k * N + l + elif symmetry == 7: return l * N + k + elif symmetry == 8 or symmetry == 9: return k * N + l + elif symmetry == 10 or symmetry == 11: return l * N + k + elif symmetry == 12 or symmetry == 13: return i * N + j + elif symmetry == 14 or symmetry == 15: return j * N + i + elif symmetry == 16: return k * N + j + elif symmetry == 17: return k * N + i + elif symmetry == 18: return l * N + j + elif symmetry == 19: return l * N + i + elif symmetry == 20: return i * N + l + elif symmetry == 21: return i * N + k + elif symmetry == 22: return j * N + l + elif symmetry == 23: return j * N + k + elif symmetry == 24: return i * N + l #j*N+l, i*N+k, j*N+k, + elif symmetry == 25: return j*N+l + elif symmetry == 26: return i*N+k + elif symmetry == 27: return j*N+k + elif symmetry == 28: return k * N + j + elif symmetry == 29: return l * N + j + elif symmetry == 30: return k * N + i + elif symmetry == 31: return l * N + i + + +def sparse_symmetric_einsum(nonzero_distinct_ERI, nonzero_indices, dm, foriloop): + dm = dm.reshape(-1) + diff_JK = jnp.zeros(dm.shape) + N = int(np.sqrt(dm.shape[0])) + + dnums = jax.lax.GatherDimensionNumbers( + offset_dims=(), + collapsed_slice_dims=(0,), + start_index_map=(0,)) + scatter_dnums = jax.lax.ScatterDimensionNumbers( + update_window_dims=(), + inserted_window_dims=(0,), + scatter_dims_to_operand_dims=(0,)) + Z = jnp.zeros((N**2,), dtype=dm.dtype) + + # todo: how much faster if we precompute dm/ss indices? + def iteration(symmetry, vals): + diff_JK = vals + is_K_matrix = (symmetry >= 8) + + def sequentialized_iter(i, vals): + # Generalized J/K computation: does J when symmetry is in range(0,8) and K when symmetry is in range(8,16) + # Trade-off: Using one function leads to smaller always-live memory. + diff_JK = vals + indices = nonzero_indices[i]#.astype(np.int32) # + eris = nonzero_distinct_ERI[i] + + dm_indices = ijkl(indices, symmetry+is_K_matrix*8, N, indices_func).reshape(-1, 1) + #dm_values = jnp.take(dm, dm_indices, axis=0)[:, 0] # for our special case the 50 lines of code reduces to the one line below. + dm_values = jax.lax.gather(dm, dm_indices, dimension_numbers=dnums, slice_sizes=(1,), mode=jax.lax.GatherScatterMode.FILL_OR_DROP) + dm_values = dm_values * eris + + ss_indices = ijkl(indices, symmetry+8+is_K_matrix*8, N, indices_func) .reshape(-1,1) + # diff_JK = diff_JK + jax.lax.segment_sum( ...) # for our special case the 100 lines of code reduces to the one line below. + diff_JK = diff_JK + jax.lax.scatter_add(Z, + ss_indices, dm_values, + scatter_dnums, indices_are_sorted=True, unique_indices=False, mode=jax.lax.GatherScatterMode.FILL_OR_DROP)\ + *(-HYB_B3LYP/2)**is_K_matrix + + return diff_JK + + batches = nonzero_indices.shape[0] + + # forloop makes training slower but compile time faster. + if foriloop: + diff_JK = jax.lax.fori_loop(0, batches, sequentialized_iter, diff_JK) + else: + for i in range(batches): + diff_JK = sequentialized_iter(i, diff_JK) + return diff_JK + + if foriloop: + diff_JK = jax.lax.fori_loop(0, 16, iteration, diff_JK) + else: + for i in range(0, 16): + diff_JK = iteration(i, diff_JK) + #diff_JK = jax.lax.fori_loop(0, 16, iteration, diff_JK) + #return jax.lax.psum(diff_JK, axis_name="p") + return diff_JK.reshape(N, N) + + +def sparse_einsum(nonzero_distinct_ERI, precomputed_indices, dm, foriloop): + dm = dm.reshape(-1) + diff_JK = jnp.zeros(dm.shape) + N = int(np.sqrt(dm.shape[0])) + + dnums = jax.lax.GatherDimensionNumbers( offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) + scatter_dnums = jax.lax.ScatterDimensionNumbers( update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,)) + Z = jnp.zeros((N**2,), dtype=dm.dtype) + + def iteration(symmetry, vals): + diff_JK = vals + is_K_matrix = (symmetry >= 8) + + def sequentialized_iter(i, vals): + # Generalized J/K computation: does J when symmetry is in range(0,8) and K when symmetry is in range(8,16) + # Trade-off: Using one function leads to smaller always-live memory. + diff_JK = vals + eris = nonzero_distinct_ERI[i] + + #dm_values = jnp.take(dm, dm_indices, axis=0)[:, 0] # for our special case the 50 lines of code reduces to the one line below. + dm_indices = precomputed_indices[symmetry, i, 0] + ss_indices = precomputed_indices[symmetry, i, 1] + dm_values = jax.lax.gather(dm, dm_indices, dimension_numbers=dnums, slice_sizes=(1,), mode=jax.lax.GatherScatterMode.FILL_OR_DROP) + dm_values = dm_values * eris + + #ss_indices = ijkl(indices, symmetry+8+is_K_matrix*8, N, indices_func) .reshape(-1,1) + # diff_JK = diff_JK + jax.lax.segment_sum( ...) # for our special case the 100 lines of code reduces to the one line below. + diff_JK = diff_JK + jax.lax.scatter_add(Z, ss_indices, dm_values, + scatter_dnums, indices_are_sorted=True, unique_indices=False, mode=jax.lax.GatherScatterMode.FILL_OR_DROP)\ + *(-HYB_B3LYP/2)**is_K_matrix + + return diff_JK + + batches = precomputed_indices.shape[1] + + # forloop makes training slower but compile time faster. + if foriloop: + diff_JK = jax.lax.fori_loop(0, batches, sequentialized_iter, diff_JK) + else: + for i in range(batches): + diff_JK = sequentialized_iter(i, diff_JK) + return diff_JK + + if foriloop: + diff_JK = jax.lax.fori_loop(0, 16, iteration, diff_JK) + else: + for i in range(0, 16): + diff_JK = iteration(i, diff_JK) + #diff_JK = jax.lax.fori_loop(0, 16, iteration, diff_JK) + #return jax.lax.psum(diff_JK, axis_name="p") + return diff_JK.reshape(N, N) + + + +def precompute_indices(nonzero_indices, N): + + def iteration(symmetry): + is_K_matrix = (symmetry >= 8) + + def sequentialized_iter(i): + # Generalized J/K computation: does J when symmetry is in range(0,8) and K when symmetry is in range(8,16) + # Trade-off: Using one function leads to smaller always-live memory. + indices = nonzero_indices[i] + dm_indices = np_ijkl(indices, symmetry+is_K_matrix*8, N, _indices_func).reshape(-1, 1) + ss_indices = np_ijkl(indices, symmetry+8+is_K_matrix*8, N, _indices_func) .reshape(-1,1) + + return dm_indices, ss_indices + + batches = nonzero_indices.shape[0] + + # forloop makes training slower but compile time faster. + _indices = [None for _ in range(batches)] + for i in range(batches): + _indices[i] = sequentialized_iter(i) + return _indices + + _indices = [None for _ in range(16)] + for i in range(0, 16): + _indices[i] = iteration(i) + return np.array(_indices ) + +if __name__ == "__main__": + import time + import argparse + parser = argparse.ArgumentParser(prog='', description='', epilog='') + parser.add_argument('-backend', default="cpu"), + parser.add_argument('-natm', default=3), + parser.add_argument('-test', action="store_true") + parser.add_argument('-prof', action="store_true") + parser.add_argument('-batches', default=5) + parser.add_argument('-skip', action="store_true") + + args = parser.parse_args() + backend = args.backend + + natm = int(args.natm) + nipu = 1 + + start = time.time() + + mol = pyscf.gto.Mole(atom="".join(f"C 0 {1.54*j} {1.54*i};" for i in range(natm) for j in range(natm))) + #mol = pyscf.gto.Mole(atom="".join(f"C 0 {15.4*j} {15.4*i};" for i in range(1) for j in range(75))) + mol.build() + N = mol.nao_nr() + print("N %i"%mol.nao_nr()) + print("NxN:", (N**2, N**2)) + print("Naive operations: ", N**4*2/10**9, "[Giga]") + if not args.skip: dense_ERI = mol.intor("int2e_sph", aosym="s1") + distinct_ERI = mol.intor("int2e_sph", aosym="s8") + #distinct_ERI[np.abs(distinct_ERI)<1e-9] = 0 # zero out stuff + dm = pyscf.scf.hf.init_guess_by_minao(mol) + scale = HYB_B3LYP/2 + if not args.skip: + J = np.einsum("ijkl,ji->kl", dense_ERI, dm) + K = np.einsum("ijkl,jk->il", dense_ERI, dm) + truth = J - K / 2 * HYB_B3LYP + + nonzero_indices = np.nonzero(distinct_ERI)[0].astype(np.uint64) + nonzero_distinct_ERI = distinct_ERI[nonzero_indices]#.astype(np.float32) + print("Nonzero Operations:", nonzero_indices.size*8*2/10**9, "[Giga]") + ij, kl = get_i_j(nonzero_indices) + rep = num_repetitions_fast(ij, kl) + nonzero_distinct_ERI = nonzero_distinct_ERI / rep + dm = dm.reshape(-1) + diff_JK = np.zeros(dm.shape) + + batches = int(args.batches) + remainder = nonzero_indices.shape[0] % (nipu*batches) + + if remainder != 0: + print(nipu*batches-remainder, ij.shape) + ij = np.pad(ij, ((0,nipu*batches-remainder))) + kl = np.pad(kl, ((0,nipu*batches-remainder))) + nonzero_distinct_ERI = np.pad(nonzero_distinct_ERI, (0,nipu*batches-remainder)) + + ij = ij.reshape(nipu, batches, -1) + kl = kl.reshape(nipu, batches, -1) + nonzero_distinct_ERI = nonzero_distinct_ERI.reshape(nipu, batches, -1) + + i, j = get_i_j(ij.reshape(-1)) + k, l = get_i_j(kl.reshape(-1)) + nonzero_indices = np.vstack([i,j,k,l]).T.reshape(nipu, batches, -1, 4).astype(np.int32) + #nonzero_indices = jax.lax.bitcast_convert_type(nonzero_indices, np.float16) + + #diff_JK = jax.pmap(sparse_symmetric_einsum, in_axes=(0,0,None,None), static_broadcasted_argnums=(3,), backend=backend, axis_name="p")(nonzero_distinct_ERI, nonzero_indices, dm, args.backend) + diff_JK = jax.jit(sparse_symmetric_einsum, static_argnums=(3,), backend=backend)(nonzero_distinct_ERI[0], nonzero_indices[0], dm, args.backend) + #diff_JK = jax.jit(sparse_symmetric_einsum, backend=backend, static_argnums=(3,))(nonzero_distinct_ERI[0], nonzero_indices[0], dm, False) + + indices = precompute_indices(nonzero_indices[0], N) + print(np.max(indices)) # this is just N**2! + indices = indices.astype(np.int16) + print(np.max(indices)) + print(nonzero_distinct_ERI.nbytes/10**9, nonzero_indices.nbytes/10**9, indices.nbytes/10**9) + print(nonzero_distinct_ERI.shape, nonzero_indices.shape, indices.shape) + print(np.max(indices)) + + _diff_JK = jax.jit(sparse_einsum, static_argnums=(3,), backend=backend)(nonzero_distinct_ERI[0], indices, dm, args.backend) + + + if args.skip: + exit() + + diff_JK = diff_JK.reshape(N, N) + print(diff_JK.reshape(-1)[::51]) + print(truth.reshape(-1)[::51]) + print(np.max(np.abs(diff_JK.reshape(-1) - truth.reshape(-1)))) + print(np.max(np.abs(_diff_JK.reshape(-1) - truth.reshape(-1)))) + assert np.allclose(diff_JK, truth, atol=1e-6) + assert np.allclose(_diff_JK, truth, atol=1e-6) + print("PASSED!") + \ No newline at end of file diff --git a/pyscf_ipu/direct/train.py b/pyscf_ipu/direct/train.py new file mode 100644 index 00000000..fee2dbcf --- /dev/null +++ b/pyscf_ipu/direct/train.py @@ -0,0 +1,1518 @@ +import os +os.environ['OMP_NUM_THREADS'] = '8' +import jax +jax.config.update('jax_enable_x64', True) +import jax.numpy as jnp +import numpy as np +import pyscf +import optax +from icecream import ic +from exchange_correlation.b3lyp import b3lyp, _b3lyp, vxc_b3lyp +from tqdm import tqdm +import time +from transformer import transformer, transformer_init +import pandas as pd +import math +from functools import partial +import pickle + +cfg, HARTREE_TO_EV, EPSILON_B3LYP, HYB_B3LYP = None, 27.2114079527, 1e-20, 0.2 + +def T(x): return jnp.transpose(x, (0,2,1)) + +B, BxNxN, BxNxK = None, None, None + +# Only need to recompute: L_inv, grid_AO, grid_weights, H_core, ERI and E_nuc. +def dm_energy(W: BxNxK, state, normal, nn, cfg, opts):#): + if nn: + W = jax.vmap(transformer, in_axes=(None, None, 0, 0, 0), out_axes=(0))(cfg, \ + W, state.ao_types, state.pos.astype(jnp.float32), state.H_core.astype(jnp.float32)) + + W = W.astype(jnp.float64) + # we can interpret state.H_core + W as hamiltonian, and predict hlgap from these! + L_inv_Q: BxNxN = state.L_inv_T @ jnp.linalg.eigh(state.L_inv @ (state.H_core + W) @ state.L_inv_T)[1] # O(B*N*K^2) FLOP O(B*N*K) FLOP/FLIO + density_matrix: BxNxN = 2 * (L_inv_Q*state.mask) @ T(L_inv_Q) # O(B*N*K^2) FLOP/FLIO + E_xc: B = exchange_correlation(density_matrix, state, normal, opts.xc_f32) # O(B*gsize*N^2) FLOP O(gsize*N^2) FLIO + diff_JK: BxNxN = JK(density_matrix, state, normal, opts.foriloop, opts.eri_f32) # O(B*num_ERIs) FLOP O(num_ERIs) FLIO + energies: B = E_xc + state.E_nuc + jnp.sum((density_matrix * (state.H_core + diff_JK/2)).reshape(W.shape[0], -1), axis=-1) + energy: float = jnp.sum(energies) + return energy, (energies, E_xc, density_matrix, W) + +def sparse_mult(values, dm, state, gsize): + in_ = dm.take(state.cols, axis=0) + prod = in_*values[:, None] + return jax.ops.segment_sum(prod, state.rows, gsize) + +def exchange_correlation(density_matrix: BxNxN, state, normal, xc_f32): + _, _, gsize, N = state.grid_AO.shape + B = density_matrix.shape[0] + if normal: + grid_AO_dm = (state.grid_AO[:, 0] @ density_matrix) # (B,gsize,N) @ (B, N, N) = O(B gsize N^2) + rho = jnp.sum(grid_AO_dm.reshape(B, 1, gsize, N) * state.grid_AO, axis=3) # (B,1,gsize,N) * (B,4,gsize,N) = O(B gsize N) + else: + if xc_f32: density_matrix.astype(jnp.float32) + if False: + main: BxGsizexN = state.main_grid_AO @ density_matrix # (1, gsize, N) @ (B, N, N) = O(B gsize N^2) FLOPs and O(gsize*N + N^2 +B * gsize * N) FLIOs + correction: BxGsizexN = jax.vmap(sparse_mult, in_axes=(0,0,None, None))(state.sparse_diffs_grid_AO, density_matrix, state, gsize) + rho_a = jnp.einsum("bpij,bqij->bpi", state.grid_AO, main.reshape(B,1,gsize,N)) + rho_b = jnp.einsum("bpij,bqij->bpi", state.grid_AO, correction.reshape(B,1,gsize,N)) + rho = rho_a - rho_b + else: + grid_AO_dm = (state.grid_AO[:, 0] @ density_matrix) # (B,gsize,N) @ (B, N, N) = O(B gsize N^2) + rho = jnp.sum(grid_AO_dm.reshape(B, 1, gsize, N) * state.grid_AO, axis=3) # (B,1,gsize,N) * (B,4,gsize,N) = O(B gsize N) + rho = rho.astype(jnp.float64) + + + E_xc = jax.vmap(_b3lyp, in_axes=(0, None))(rho, EPSILON_B3LYP).reshape(B, gsize) + E_xc = jnp.sum(rho[:, 0] * state.grid_weights * E_xc, axis=-1).reshape(B) + return E_xc + +def JK(density_matrix, state, normal, jax_foriloop, eri_f32): + if normal: + J = jnp.einsum('bijkl,bji->bkl', state.ERI, density_matrix) + K = jnp.einsum('bijkl,bjk->bil', state.ERI, density_matrix) + diff_JK = J - K / 2 * HYB_B3LYP + else: + from sparse_symmetric_ERI import sparse_symmetric_einsum, sparse_einsum + + if eri_f32: density_matrix = density_matrix.astype(jnp.float32) + + '''diff_JK: BxNxN = jax.vmap(sparse_symmetric_einsum, in_axes=(None, None, 0, None))( + state.nonzero_distinct_ERI[0], + state.nonzero_indices[0], + density_matrix, + jax_foriloop + ) + diff_JK: BxNxN = diff_JK - jax.vmap(sparse_symmetric_einsum, in_axes=(0, None, 0, None))( + state.diffs_ERI, + state.indxs, + density_matrix, + jax_foriloop + )''' + + diff_JK: BxNxN = jax.vmap(sparse_einsum, in_axes=(None, None, 0, None))( + state.nonzero_distinct_ERI[0], + state.precomputed_nonzero_indices, + density_matrix, + jax_foriloop + ) + correction = jax.vmap(sparse_einsum, in_axes=(0, None, 0, None))( + state.diffs_ERI, + state.precomputed_indxs, + density_matrix, + jax_foriloop + ) + diff_JK: BxNxN = diff_JK - correction + + return diff_JK.astype(jnp.float64) + +def nao(atom, basis): + m = pyscf.gto.Mole(atom='%s 0 0 0; %s 0 0 1;'%(atom, atom), basis=basis) + m.build() + return m.nao_nr()//2 + +def batched_state(mol_str, opts, bs, wiggle_num=0, + do_pyscf=True, validation=False, + extrapolate=False, + pad_electrons=45, + pad_diff_ERIs=50000, + pad_distinct_ERIs=120000, + pad_grid_AO=25000, + pad_nonzero_distinct_ERI=200000, + pad_sparse_diff_grid=200000, + mol_idx=42, + init_phi_psi=None, + inference=False, + inference_psi_step=5, # degrees + ): + # pad molecule if using nn. + if not opts.nn: + pad_electrons, pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = \ + -1, -1, -1, -1, -1, -1 + + max_pad_electrons, max_pad_diff_ERIs, max_pad_distinct_ERIs, max_pad_grid_AO, max_pad_nonzero_distinct_ERI, max_pad_sparse_diff_grid = \ + -1, -1, -1, -1, -1, -1 + if opts.benzene and opts.nn: pad_electrons = 30 + if opts.hydrogens: + pad_diff_ERIs = 5000 + pad_distinct_ERIs = 20000 + pad_grid_AO = 2200 + pad_nonzero_distinct_ERI = 20000 + pad_sparse_diff_grid = 20000 + + if opts.qm9: + pad_electrons=60 + '''pad_diff_ERIs=120000 + pad_distinct_ERIs=400000 + pad_grid_AO=50000 + pad_nonzero_distinct_ERI=400000 + pad_sparse_diff_grid=400000''' + #padding_estimate = [37426, 149710, 17010, 140122, 138369] + padding_estimate = [48330, 163222, 17034, 159361, 139505] + padding_estimate = [int(a*1.1) for a in padding_estimate] + pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = padding_estimate + + if opts.alanine: + # todo: (adam) the ERI padding may change when rotating molecule more! + pad_electrons = 70 + # padding is estimated/printed when running; copy those numbers to the list below. + padding_estimate = [210745, 219043, 18084, 193830, 1105268] + # 213973 218912 18084 195723 1105847] + # add 10% + padding_estimate = [int(a*1.1) for a in padding_estimate] + # name variables correctly + pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = padding_estimate + + pad_diff_ERIs *= int(8/opts.eri_bs) + pad_distinct_ERIs *= int(8/opts.eri_bs) + pad_nonzero_distinct_ERI *= int(8/opts.eri_bs) + + if opts.waters: + pad_diff_ERIs, pad_distinct_ERIs, pad_grid_AO, pad_nonzero_distinct_ERI, pad_sparse_diff_grid = [a//3 for a in [ pad_diff_ERIs , pad_distinct_ERIs , pad_grid_AO , pad_nonzero_distinct_ERI , pad_sparse_diff_grid ]] + + mol = build_mol(mol_str, opts.basis) + pad_electrons = min(pad_electrons, mol.nao_nr()) + + # Set seed to ensure different rotation; initially all workers did same rotation! + np.random.seed(mol_idx) + + if opts.waters: + water1_xyz = np.array([mol_str[i][1] for i in range(0,3)]) + water2_xyz = np.array([mol_str[i][1] for i in range(3,6)]) + + if opts.qm9: + atoms = np.array([mol_str[i][1] for i in range(0,3)]) + # pick random atom to permute (of the first 9 heavy ones) + atom_num = int(np.random.uniform(0, 8)) + + if opts.alanine: + # train on [-180, 180], validate [-180, 180] extrapolate [-360, 360]\[180, -180] + # todo: draw picture (in training loop) + if extrapolate and not inference: + phi, psi = [float(a) for a in np.random.uniform(180, 360, 2)] + elif inference: + phi, psi = init_phi_psi + else: + phi, psi = [float(a) for a in np.random.uniform(0, 180, 2)] + + angles = [] + + states = [] + for iteration in range(bs): + import copy + new_str = copy.deepcopy(mol_str) + + if opts.alanine: + from rdkit import Chem + from rdkit.Chem import AllChem + pdb_file = 'alanine.pdb' + molecule = Chem.MolFromPDBFile(pdb_file, removeHs=False) + # tried reading from pdb_block; caused parallel dataloader pickle to break. + AllChem.EmbedMolecule(molecule) + #AllChem.UFFOptimizeMolecule(molecule) + phi_atoms = [4, 6, 8, 14] # indices for phi dihedral + psi_atoms = [6, 8, 14, 16] # indices for psi dihedral''' + + def xyz(atom): return np.array([atom.x, atom.y, atom.z]).reshape(1, 3) + def get_atom_positions(mol): + conf = mol.GetConformer() + return np.concatenate([xyz(conf.GetAtomPosition(i)) for i in range(mol.GetNumAtoms())], axis=0) + + str = [new_str[j][0] for j in range(len(new_str))] + pos = np.concatenate([np.array(new_str[j][1]).reshape(1, 3) for j in range(len(new_str))]) + + # todo: save=wandb.log({"pair": angle1, angle2, NN_energy ) (rotation, NN_energy) for train/val molecule (for val also save PySCF energy) + # only saving angles (angle not paired up with energy) + AllChem.SetDihedralDeg(molecule.GetConformer(), *phi_atoms, phi) + angle = psi + float(np.random.uniform(0, opts.rotate_deg, 1)) # perhaps add 45 and mod 360? + + # todo: check math whether val/extra/train have uniform distribution on their respective domains. + if extrapolate and not inference: # make sure angle is in [] + angle = angle % 180 + 180 # angle should be in [180, 360] + elif inference: + angle = psi + iteration * inference_psi_step # overwrite the angle when in inference mode with fixed, not randomized, step + angle = angle % 360 # angle should be [0, 360] for heatmap + else: # validation + angle = angle % 180 # angle should be [0, 180] + + AllChem.SetDihedralDeg(molecule.GetConformer(), *psi_atoms, angle ) + pos = get_atom_positions(molecule) + angles.append((phi, angle)) + + for j in range(len(new_str)): new_str[j][1] = tuple(pos[j]) + + '''if iteration == 0 and opts.wandb: + from plot import create_rdkit_mol + import wandb + wandb.log({"mol_valid=%s"%validation: create_rdkit_mol(str, pos) })''' + + if opts.waters: # todo: rotate both water molecules and draw x=phi, y=psi. + rotation_matrix = np.linalg.qr(np.random.normal(size=(3,3)))[0] + center = water2_xyz.mean(axis=0) + water_xyz = np.dot(water2_xyz - center, rotation_matrix) + center + + new_str[3][1] = tuple(water_xyz[0]) + new_str[4][1] = tuple(water_xyz[1]) + new_str[5][1] = tuple(water_xyz[2]) + + '''if opts.wandb and iteration == 0: + from plot import create_rdkit_mol + import wandb + str = [new_str[j][0] for j in range(len(new_str))] + pos = np.concatenate([np.array(new_str[j][1]).reshape(1, 3) for j in range(len(new_str))]) + wandb.log({"%s_mol_%i"%({True: "valid", False: "train"}[validation], iteration): create_rdkit_mol(str, pos) })''' + + elif opts.qm9: + # todo: find dihedral to rotate over similar to alanine dipeptide. + # broken; rotate first three atoms around their center of mass + # this breaks molecule; should use dihedral angle as done with the dipeptide. + #rotation_matrix = np.linalg.qr(np.random.normal(size=(3,3)))[0] + #center = atoms.mean(axis=0) + #rotated_atoms = np.dot(atoms - center, rotation_matrix) + center + + # for extrapolation, do even more. + + if iteration == 0 and (validation or extrapolate): + pass + else: + #new_str[0][1] = tuple(atoms[0] + np.random.normal(0, opts.wiggle_var, (3))) + new_str[atom_num][1] = tuple(atoms[atom_num] + np.random.normal(0, opts.wiggle_var, (3))) + #new_str[1][1] = tuple(atoms[1] + np.random.normal(0, opts.wiggle_var, (3))) + #new_str[2][1] = tuple(atoms[2] + np.random.normal(0, opts.wiggle_var, (3))) + + '''if opts.wandb and iteration == 0: + from plot import create_rdkit_mol + import wandb + str = [new_str[j][0] for j in range(len(new_str))] + pos = np.concatenate([np.array(new_str[j][1]).reshape(1, 3) for j in range(len(new_str))]) + wandb.log({"%s_mol_%i"%({True: "valid", False: "train"}[validation], iteration): create_rdkit_mol(str, pos) })''' + + if iteration == 0: + state = init_dft(new_str, opts, do_pyscf=do_pyscf, pad_electrons=pad_electrons) + c, w = state.grid_coords, state.grid_weights + elif iteration <= 1 or not opts.prof: # when profiling create fake molecule to skip waiting + state = init_dft(new_str, opts, c, w, do_pyscf=do_pyscf and iteration < 80, state=state, pad_electrons=pad_electrons) + + states.append(state) + + # If we add energy here we get plot basically! + # todo: save and store in training loop, then we can match with energy + # can't get to work in wandb, but can just use download api and the plot. + '''if opts.alanine and opts.wandb: + for phi, psi in angles: + if not validation: + wandb.log({"phi_train": phi , "psi_train": psi}) + else: + wandb.log({"phi_valid": phi, "psi_valid": psi})''' + state = cats(states) + N = state.N[0] + + # Compute ERI sparsity. + nonzero = [] + for e, i in zip(state.nonzero_distinct_ERI, state.nonzero_indices): + abs = np.abs(e) + indxs = abs < opts.eri_threshold #1e-10 + e[indxs] = 0 + nonzero.append(np.nonzero(e)[0]) + + # Merge nonzero indices and prepare (ij, kl). + # rep is the number of repetitions we include in the sparse representation. + #nonzero_indices = np.union1d(nonzero[0], nonzero[1]) + union = nonzero[0] + for i in range(1, len(nonzero)): + union = np.union1d(union, nonzero[i]) + nonzero_indices = union + from sparse_symmetric_ERI import get_i_j, num_repetitions_fast + ij, kl = get_i_j(nonzero_indices) + rep = num_repetitions_fast(ij, kl) + + batches = opts.eri_bs + es = [] + for e,i in zip(state.nonzero_distinct_ERI, state.nonzero_indices): + nonzero_distinct_ERI = e[nonzero_indices] / rep + remainder = nonzero_indices.shape[0] % (batches) + if remainder != 0: nonzero_distinct_ERI = np.pad(nonzero_distinct_ERI, (0,batches-remainder)) + + nonzero_distinct_ERI = nonzero_distinct_ERI.reshape(batches, -1) + es.append(nonzero_distinct_ERI) + + state.nonzero_distinct_ERI = np.concatenate([np.expand_dims(a, axis=0) for a in es]) + + i, j = get_i_j(ij.reshape(-1)) + k, l = get_i_j(kl.reshape(-1)) + + if remainder != 0: + i = np.pad(i, ((0,batches-remainder))) + j = np.pad(j, ((0,batches-remainder))) + k = np.pad(k, ((0,batches-remainder))) + l = np.pad(l, ((0,batches-remainder))) + nonzero_indices = np.vstack([i,j,k,l]).T.reshape(batches, -1, 4).astype(np.int32) # todo: use int16 or int32 here? + state.nonzero_indices = nonzero_indices + + # batching (w/ same sparsity pattern across batch) allows precomputing all {ss,dm}_indices instead of computing in sparse_sym_eri every iteration. + # function below does this. + # todo: consider removing, didn't get expecting 3x (only 5%; not sure if additional memory/complication justifies). + from sparse_symmetric_ERI import precompute_indices + + if opts.normal: diff_state = None + else: + main_grid_AO = state.grid_AO[:1] + diffs_grid_AO = main_grid_AO - state.grid_AO + rows, cols = np.nonzero(np.max(diffs_grid_AO[:, 0]!=0, axis=0)) + sparse_diffs_grid_AO = diffs_grid_AO[:, 0, rows,cols] + + # use the same sparsity pattern across a batch. + diff_ERIs = state.nonzero_distinct_ERI[:1] - state.nonzero_distinct_ERI + diff_indxs = state.nonzero_indices.reshape(1, batches, -1, 4) + nzr = np.abs(diff_ERIs[1]).reshape(batches, -1) > 1e-10 + + diff_ERIs = diff_ERIs[:, nzr].reshape(bs, -1) + diff_indxs = diff_indxs[:, nzr].reshape(-1, 4) + + remainder = np.sum(nzr) % batches + if remainder != 0: + diff_ERIs = np.pad(diff_ERIs, ((0,0),(0,batches-remainder))) + diff_indxs = np.pad(diff_indxs, ((0,batches-remainder),(0,0))) + + diff_ERIs = diff_ERIs.reshape(bs, batches, -1) + diff_indxs = diff_indxs.reshape(batches, -1, 4) + + precomputed_indxs = precompute_indices(diff_indxs, N).astype(np.int16) + + if pad_diff_ERIs == -1: + state.indxs=diff_indxs + state.diffs_ERI=diff_ERIs + assert False, "deal with precomputed_indxs; only added in else branch below" + else: + max_pad_diff_ERIs = diff_ERIs.shape[2] + # pad ERIs with 0 and indices with -1 so they point to 0. + assert diff_indxs.shape[1] == diff_ERIs.shape[2] + pad = pad_diff_ERIs - diff_indxs.shape[1] + assert pad > 0, (pad_diff_ERIs, diff_indxs.shape[1]) + state.indxs = np.pad(diff_indxs, ((0,0), (0, pad), (0, 0)), 'constant', constant_values=(-1)) + state.diffs_ERI = np.pad(diff_ERIs, ((0,0), (0, 0), (0, pad))) # pad zeros + #print(diff_indxs.shape, precomputed_indxs.shape) + state.precomputed_indxs = np.pad(precomputed_indxs, ((0,0), (0,0),(0,0), (0, pad), (0,0)), 'constant', constant_values=(-1)) + + #if opts.wandb: wandb.log({"pad_diff_ERIs": pad/diff_ERIs.shape[2]}) + + state.rows=rows + state.cols=cols + + state.main_grid_AO=main_grid_AO[:1, 0] + + state.sparse_diffs_grid_AO = sparse_diffs_grid_AO + #state.diffs_grid_AO = diffs_grid_AO # this isn't used for energy eval + + if pad_sparse_diff_grid != -1: + max_pad_sparse_diff_grid = state.rows.shape[0] + assert state.sparse_diffs_grid_AO.shape[1] == state.rows.shape[0] + assert state.sparse_diffs_grid_AO.shape[1] == state.cols.shape[0] + pad = pad_sparse_diff_grid - state.rows.shape[0] + assert pad >= 0, (pad_sparse_diff_grid, state.rows.shape[0]) + state.rows = np.pad(state.rows, (0,pad)) + state.cols = np.pad(state.cols, (0,pad)) + state.sparse_diffs_grid_AO = np.pad(state.sparse_diffs_grid_AO, ((0,0),(0,pad))) + + #if opts.wandb: wandb.log({"pad_sparse_diff_grid": pad/state.sparse_diffs_grid_AO.shape[1]}) + + #state.grid_AO = state.grid_AO[:1] + state.nonzero_distinct_ERI = state.nonzero_distinct_ERI[:1] + + state.nonzero_indices = np.expand_dims(state.nonzero_indices, axis=0) + + # todo: looks like we're padding, then looking for zeros, then padding; this can be simplified. + if pad_distinct_ERIs != -1: + max_pad_distinct_ERIs = state.nonzero_distinct_ERI.shape[2] + assert state.nonzero_distinct_ERI.shape[2] == state.nonzero_indices.shape[2] + pad = pad_distinct_ERIs - state.nonzero_distinct_ERI.shape[2] + assert pad > 0, (pad_distinct_ERIs, state.nonzero_distinct_ERI.shape[2]) + state.nonzero_indices = np.pad(state.nonzero_indices, ((0,0), (0,0), (0, pad), (0,0)), 'constant', constant_values=(-1)) + state.nonzero_distinct_ERI = np.pad(state.nonzero_distinct_ERI, ((0,0), (0,0), (0, pad))) # pad zeros + + #if opts.wandb: wandb.log({"pad_distinct_ERIs": pad/state.nonzero_distinct_ERI.shape[2]}) + + if pad_grid_AO != -1: + max_pad_grid_AO = state.grid_AO.shape[2] + + prev_size = state.grid_AO.shape[2] + assert state.grid_AO.shape[2] == state.grid_weights.shape[1] + assert state.grid_AO.shape[2] == state.grid_coords.shape[1] + assert state.grid_AO.shape[2] == state.main_grid_AO.shape[1] + #assert state.grid_AO.shape[2] == state.diffs_grid_AO.shape[2] + pad = pad_grid_AO - state.grid_AO.shape[2] + assert pad > 0, (pad_grid_AO, state.grid_AO.shape[2]) + state.grid_AO = np.pad(state.grid_AO, ((0,0),(0,0), (0,pad), (0,0))) + state.grid_weights = np.pad(state.grid_weights, ((0,0),(0,pad))) + state.grid_coords = np.pad(state.grid_coords, ((0,0),(0,pad),(0,0))) + state.main_grid_AO = np.pad(state.main_grid_AO, ((0,0),(0,pad),(0,0))) + #state.diffs_grid_AO = np.pad(state.diffs_grid_AO, ((0,0),(0,0),(0,pad),(0,0))) + + #if opts.wandb: + # wandb.log({"pad_grid_AO": pad/state.grid_AO.shape[2], + # "pad_grid_AO_prev": prev_size, + # "pad_grid_AO_pad": pad, + # "pad_grid_AO_target": pad_grid_AO}) + + + # todo: make this into a variable we can control from commandline. + indxs = np.abs(state.nonzero_distinct_ERI ) > opts.eri_threshold #1e-9 + state.nonzero_distinct_ERI = state.nonzero_distinct_ERI[indxs] + state.nonzero_indices = state.nonzero_indices[indxs] + remainder = state.nonzero_indices.shape[0] % batches + + if remainder != 0: + state.nonzero_distinct_ERI = np.pad(state.nonzero_distinct_ERI, (0,batches-remainder)) + state.nonzero_indices = np.pad(state.nonzero_indices, ((0,batches-remainder), (0,0))) + + state.nonzero_distinct_ERI = state.nonzero_distinct_ERI.reshape(1, batches, -1) + state.nonzero_indices = state.nonzero_indices.reshape(1, batches, -1, 4) + + precomputed_nonzero_indices = precompute_indices(state.nonzero_indices[0], N).astype(np.int16) + #print(state.nonzero_indices.shape, precomputed_nonzero_indices.shape) + + if pad_nonzero_distinct_ERI != -1: + max_pad_nonzero_distinct_ERI = state.nonzero_distinct_ERI.shape[2] + + assert state.nonzero_distinct_ERI.shape[2] == state.nonzero_indices.shape[2] + pad = pad_nonzero_distinct_ERI - state.nonzero_distinct_ERI.shape[2] + assert pad >= 0, (pad_nonzero_distinct_ERI, state.nonzero_distinct_ERI.shape[2]) + state.nonzero_distinct_ERI = np.pad(state.nonzero_distinct_ERI, ((0,0),(0,0),(0,pad))) + state.nonzero_indices = np.pad(state.nonzero_indices, ((0,0),(0,0),(0,pad), (0,0)), 'constant', constant_values=(-1)) + + state.precomputed_nonzero_indices = np.pad(precomputed_nonzero_indices, ((0,0), (0,0), (0,0), (0, pad),(0,0)), 'constant', constant_values=(-1)) + #print(state.precomputed_nonzero_indices.shape, state.nonzero_indices.shape) + + #if opts.wandb: wandb.log({"pad_grid_AO": pad/state.grid_AO.shape[2]}) + + state.pad_sizes = np.array([ + max_pad_diff_ERIs, max_pad_distinct_ERIs, max_pad_grid_AO, + max_pad_nonzero_distinct_ERI, max_pad_sparse_diff_grid]) + + if opts.eri_f32: + state.nonzero_distinct_ERI = state.nonzero_distinct_ERI.astype(jnp.float32) + state.diffs_ERI = state.diffs_ERI.astype(jnp.float32) + + if opts.xc_f32: + state.main_grid_AO = state.main_grid_AO.astype(jnp.float32) + state.grid_AO = state.grid_AO.astype(jnp.float32) + state.sparse_diffs_grid_AO = state.sparse_diffs_grid_AO.astype(jnp.float32) + + return state + + +def nanoDFT(mol_str, opts): + print() + # Initialize validation set. + # This consists of DFT tensors initialized with PySCF/CPU. + np.random.seed(42) + + if opts.wandb: + import wandb + if opts.alanine: + run = wandb.init(project='ndft_alanine') + elif opts.qm9: + run = wandb.init(project='ndft_qm9') + else: + run = wandb.init(project='ndft') + opts.name = run.name + + else: + opts.name = "%i"%time.time() + + rnd_key = jax.random.PRNGKey(42) + n_vocab = nao("C", opts.basis) + nao("N", opts.basis) + \ + nao("O", opts.basis) + nao("F", opts.basis) + \ + nao("H", opts.basis) + + global cfg + '''Model ViT model embedding #heads #layers #params training throughput + dimension resolution (im/sec) + DeiT-Ti N/A 192 3 12 5M 224 2536 + DeiT-S N/A 384 6 12 22M 224 940 + DeiT-B ViT-B 768 12 12 86M 224 292 + Parameters Layers dmodel + 117M 12 768 + 345M 24 1024 + 762M 36 1280 + 1542M 48 1600 + ''' + if opts.tiny: # 5M + d_model= 192 + n_heads = 6 + n_layers = 12 + if opts.small: + d_model= 384 + n_heads = 6 + n_layers = 12 + if opts.base: + d_model= 768 + n_heads = 12 + n_layers = 12 + if opts.medium: + d_model= 1024 + n_heads = 16 + n_layers = 24 + if opts.large: + d_model= 1280 + n_heads = 16 + n_layers = 36 + if opts.xlarge: + d_model= 1600 + n_heads = 25 + n_layers = 48 + + if opts.nn: + rnd_key, cfg, params, total_params = transformer_init( + rnd_key, + n_vocab, + d_model =d_model, + n_layers=n_layers, + n_heads =n_heads, + d_ff =d_model*4, + ) + params = params.to_float32() + + if opts.resume: + print("loading checkpoint") + params = pickle.load(open("checkpoints/%s_model.pickle"%opts.resume, "rb")) + print("done loading. ") + + + if opts.nn: + #https://arxiv.org/pdf/1706.03762.pdf see 5.3 optimizer + + # try to mimic karpathy as closely as possible ;) + # https://github.com/karpathy/nanoGPT/blob/master/train.py + # still differs on + # [ ] weight initialization + + def custom_schedule(it, learning_rate=opts.lr, min_lr=opts.lr/10, warmup_iters=2000, lr_decay_iters=600000): # 600k/30 = 20k; so hit mi + #return learning_rate * it / warmup_iters # to allow jax jit? + # allow jax jit + '''if it < warmup_iters: return learning_rate * it / warmup_iters # linearly increase until hit warmup iters. + if it > lr_decay_iters: return min_lr # after decay (600k iterations) go to 10x lower + + # in between, decay learning rate using this function; this is from 2k steps to 600k steps + decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) + return min_lr + coeff * (learning_rate - min_lr)''' + #if it < warmup_iters: return learning_rate * it / warmup_iters + cond1 = (it < warmup_iters) * learning_rate * it / warmup_iters + cond2 = (it > lr_decay_iters) * min_lr + + decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) + coeff = 0.5 * (1.0 + jnp.cos(jnp.pi * decay_ratio)) + cond3 = (it >= warmup_iters) * (it <= lr_decay_iters) * (min_lr + coeff * (learning_rate - min_lr)) + return cond1 + cond2 + cond3 + + adam = optax.chain( + optax.clip_by_global_norm(1), + optax.scale_by_adam(b1=0.9, b2=0.95, eps=1e-12), + optax.add_decayed_weights(0.1),#, configure_decay_mask(params)), + optax.scale_by_schedule(custom_schedule), + optax.scale(-1), + ) + + w = params + + + from torch.utils.data import DataLoader, Dataset + class OnTheFlyQM9(Dataset): + # prepares dft tensors with pyscf "on the fly". + # dataloader is very keen on throwing segfaults (e.g. using jnp in dataloader throws segfaul). + # problem: second epoch always gives segfault. + # hacky fix; make __len__ = real_length*num_epochs and __getitem__ do idx%real_num_examples + def __init__(self, opts, nao=294, train=True, num_epochs=10**9, extrapolate=False): + # only take molecules with use {CNOFH}, nao=nao and spin=0. + df = pd.read_pickle("alchemy/processed_atom_9.pickle") # spin=0 and only CNOFH molecules + if nao != -1: df = df[df["nao"]==nao] + # df.sample is not deterministic; moved to pre-processing, so file is shuffled already. + # this shuffling is important, because it makes the last 10 samples iid (used for validation) + #df = df.sample(frac=1).reset_index(drop=True) # is this deterministic? + + if train: self.mol_strs = df["pyscf"].values[:-10] + else: self.mol_strs = df["pyscf"].values[-10:] + #print(df["pyscf"].) # todo: print smile strings + + self.num_epochs = num_epochs + self.opts = opts + self.validation = not train + self.extrapolate = extrapolate + + self.benzene = [ + ["C", ( 0.0000, 0.0000, 0.0000)], + ["C", ( 1.4000, 0.0000, 0.0000)], + ["C", ( 2.1000, 1.2124, 0.0000)], + ["C", ( 1.4000, 2.4249, 0.0000)], + ["C", ( 0.0000, 2.4249, 0.0000)], + ["C", (-0.7000, 1.2124, 0.0000)], + ["H", (-0.5500, -0.9526, 0.0000)], + ["H", (-0.5500, 3.3775, 0.0000)], + ["H", ( 1.9500, -0.9526, 0.0000)], + ["H", (-1.8000, 1.2124, 0.0000)], + ["H", ( 3.2000, 1.2124, 0.0000)], + ["H", ( 1.9500, 3.3775, 0.0000)] + ] + self.waters = [ + ["O", (-1.464, 0.099, 0.300)], + ["H", (-1.956, 0.624, -0.340)], + ["H", (-1.797, -0.799, 0.206)], + ["O", ( 1.369, 0.146, -0.395)], + ["H", ( 1.894, 0.486, 0.335)], + ["H", ( 0.451, 0.165, -0.083)] + ] + + if opts.benzene: self.mol_strs = [self.benzene] + if opts.waters: self.mol_strs = [self.waters] + if opts.alanine: self.mol_strs = mol_str + + if train: self.bs = opts.bs + else: self.bs = opts.val_bs + + def __len__(self): + return len(self.mol_strs)*self.num_epochs + + def __getitem__(self, idx): + return batched_state(self.mol_strs[idx%len(self.mol_strs)], self.opts, self.bs, \ + wiggle_num=0, do_pyscf=self.validation or self.extrapolate, validation=False, \ + extrapolate=self.extrapolate, mol_idx=idx) + + val_qm9 = OnTheFlyQM9(opts, train=False) + ext_qm9 = OnTheFlyQM9(opts, extrapolate=True) + + # parallel dataloader bug; precompute here is not slow but causes dataloader later to die. + # run once to quickly precompute. + if opts.precompute: + val_state = val_qm9[0] + ext_state = ext_qm9[0] + exit() + + qm9 = OnTheFlyQM9(opts, train=True) + if opts.workers != 0: train_dataloader = DataLoader(qm9, batch_size=1, pin_memory=True, shuffle=False, drop_last=True, num_workers=opts.workers, prefetch_factor=2, collate_fn=lambda x: x[0]) + else: train_dataloader = DataLoader(qm9, batch_size=1, pin_memory=True, shuffle=False, drop_last=True, num_workers=opts.workers, collate_fn=lambda x: x[0]) + pbar = tqdm(train_dataloader) + + + else: + states = [batched_state(mol_str[0], opts, opts.bs, do_pyscf=True)] + [batched_state(mol_str[i], opts, opts.bs, do_pyscf=False) for i in range(opts.states-1)] + class DummyIterator: + def __init__(self, item): self.item = item + def __iter__(self): return self + def __next__(self): return self.item + train_dataloader = DummyIterator(states[0]) + pbar = tqdm(train_dataloader) + w = states[0].init + adam = optax.adabelief(opts.lr) + summary(states[0]) + + vandg = jax.jit(jax.value_and_grad(dm_energy, has_aux=True), backend=opts.backend, static_argnames=("normal", 'nn', "cfg", "opts")) + valf = jax.jit(dm_energy, backend=opts.backend, static_argnames=("normal", 'nn', "cfg", "opts")) + adam_state = adam.init(w) + + if opts.resume: + print("loading adam state") + adam_state = pickle.load(open("checkpoints/%s_adam_state.pickle"%opts.resume, "rb")) + print("done") + + w, adam_state = jax.device_put(w), jax.device_put(adam_state) + + + @partial(jax.jit, backend=opts.backend) + def update(w, adam_state, accumulated_grad): + accumulated_grad = jax.tree_map(lambda x: x / opts.bs, accumulated_grad) + updates, adam_state = adam.update(accumulated_grad, adam_state, w) + w = optax.apply_updates(w, updates) + return w, adam_state + + if opts.wandb: + if not opts.nn: total_params = -1 + wandb.log({'total_params': total_params, 'batch_size': opts.bs, 'lr': opts.lr }) + + min_val, min_dm, mins, valid_str, step, val_state, ext_state = 0, 0, np.ones(opts.bs)*1e6, "", 0, None, None + t0, load_time, train_time, val_time, plot_time = time.time(), 0, 0, 0, 0 + + paddings = [] + states = [] + for iteration, state in enumerate(pbar): + if iteration == 0: summary(state) + state = jax.device_put(state) + + # Estimate max padding. + if iteration < 100: + paddings.append(state.pad_sizes.reshape(1, -1)) + _paddings = np.concatenate(paddings, axis=0) + print(np.max(_paddings, 0)) + + dct = {} + dct["iteraton"] = iteration + + states.append(state) + if len(states) > opts.mol_repeats: states.pop(0) + + + load_time, t0 = time.time()-t0, time.time() + if opts.checkpoint != -1 and iteration % opts.checkpoint == 0: # and iteration > 0: + t0 = time.time() + try: + name = opts.name.replace("-", "_") + path_model = "checkpoints/%s_%i_model.pickle"%(name, iteration) + path_adam = "checkpoints/%s_%i_adam_state.pickle"%(name, iteration) + print("trying to checkpoint to %s and %s"%(path_model, path_adam)) + pickle.dump(jax.device_get(w), open(path_model, "wb")) + pickle.dump(jax.device_get(adam_state), open(path_adam, "wb")) + print("done!") + print("\t-resume \"%s\""%(path_model.replace("_model.pickle", ""))) + except: + print("fail!") + pass + print("tried saving model took %fs"%(time.time()-t0)) + save_time, t0 = time.time()-t0, time.time() + + + + if len(states) < 50: print(len(states)) + + for j, state in enumerate(states): + print(". ", end="", flush=True) + if j == 0: _t0 =time.time() + (val, (vals, E_xc, density_matrix, _W)), grad = vandg(w, state, opts.normal, opts.nn, cfg, opts) + print(",", end="", flush=True) + if j == 0: time_step1 = time.time()-_t0 + + # todo: have hyper parameter that accumulates gradient or takes step? + w, adam_state = update(w, adam_state, grad) + + # todo: rename + global_batch_size = len(states)*opts.bs + if opts.wandb: dct["global_batch_size"] = global_batch_size + + train_time, t0 = time.time()-t0, time.time() + + # plot grad norm + #if iteration % 10 == 0: + # for k,v in accumulated_grad.items(): dct[k + "_norm"] = np.linalg.norm(v .reshape(-1) ) + update_time, t0 = time.time()-t0, time.time() + + if not opts.nn: + str = "error=" + "".join(["%.7f "%(vals[i]*HARTREE_TO_EV-state.pyscf_E[i]) for i in range(2)]) + " [eV]" + str += "pyscf=%.7f us=%.7f"%(state.pyscf_E[0]/HARTREE_TO_EV, vals[0]) + else: + pbar.set_description("train=".join(["%.2f"%i for i in vals[:1]]) + "[Ha] "+ valid_str + "time=%.1f %.1f %.1f %.1f %.1f %.1f"%(load_time, time_step1, train_time, update_time, val_time, plot_time)) + + if opts.wandb: + dct["time_load"] = load_time + dct["time_step1"] = time_step1 + dct["time_train"] = train_time + dct["time_val"] = val_time + plot_iteration = iteration % 10 == 0 + for i in range(0, 2): + if not opts.nn: + dct['train_l%i'%i ] = np.abs(vals[i]*HARTREE_TO_EV-state.pyscf_E[i]) + dct['train_pyscf%i'%i ] = np.abs(state.pyscf_E[i]) + dct['train_E%i'%i ] = np.abs(vals[i]*HARTREE_TO_EV) + if plot_iteration: + dct['img/dm%i'%i] = wandb.Image(np.expand_dims(density_matrix[i], axis=-1)) + dct['img/W%i'%i] = wandb.Image(np.expand_dims(_W[i], axis=-1)) + + step = adam_state[1].count + + plot_time, t0 = time.time()-t0, time.time() + + + + # TODO: Plot molecules and val/ext angles. + if opts.nn and (iteration < 250 or iteration % 10 == 0): + + if val_state is None: val_state = jax.device_put(val_qm9[0]) + _, (valid_vals, _, vdensity_matrix, vW) = valf(w, val_state, opts.normal, opts.nn, cfg, opts) + if ext_state is None: ext_state = jax.device_put(ext_qm9[0]) + _, (ext_vals, _, edensity_matrix, eW) = valf(w, ext_state, opts.normal, opts.nn, cfg, opts) + + lr = custom_schedule(step) + valid_str = "lr=%.3e"%lr + "val=" + "".join(["%.4f "%(valid_vals[i]*HARTREE_TO_EV-val_state.pyscf_E[i]) for i in range(0, 3)]) + " [eV]" + valid_str += "ext=" + "".join(["%.4f "%(ext_vals[i]*HARTREE_TO_EV-ext_state.pyscf_E[i]) for i in range(0, 3)]) + " [eV]" + if opts.wandb: + for i in range(0, opts.val_bs): + dct['valid_l%i'%i ] = np.abs(valid_vals[i]*HARTREE_TO_EV-val_state.pyscf_E[i]) + dct['valid_E%i'%i ] = np.abs(valid_vals[i]*HARTREE_TO_EV) + dct['valid_pyscf%i'%i ] = np.abs(val_state.pyscf_E[i]) + dct['img/val_dm%i'%i] = wandb.Image(np.expand_dims(vdensity_matrix[i], axis=-1)) + dct['img/val_W%i'%i] = wandb.Image(np.expand_dims(vW[i], axis=-1)) + + dct['ext_l%i'%i ] = np.abs(ext_vals[i]*HARTREE_TO_EV-ext_state.pyscf_E[i]) + dct['ext_E%i'%i ] = np.abs(ext_vals[i]*HARTREE_TO_EV) + dct['ext_pyscf%i'%i ] = np.abs(ext_state.pyscf_E[i]) + dct['img/ext_dm%i'%i] = wandb.Image(np.expand_dims(edensity_matrix[i], axis=-1)) + dct['img/ext_W%i'%i] = wandb.Image(np.expand_dims(eW[i], axis=-1)) + + dct["scheduled_lr"] = lr + + + if opts.wandb: + dct["step"] = step + wandb.log(dct) + val_time, t0 = time.time()-t0, time.time() + + val, density_matrix = min_val, min_dm + + exit() + # needs batching + V_xc = jax.grad(exchange_correlation)(density_matrix, state.grid_AO, state.grid_weights) + V_xc = (V_xc + V_xc.T)/2 + diff_JK = get_JK(density_matrix, state.ERI) + H = state.H_core + diff_JK + V_xc + mo_energy, mo_coeff = np.linalg.eigh(state.L_inv @ H @ state.L_inv.T) + mo_coeff = state.L_inv.T @ mo_coeff + + return val, (0, mo_energy, mo_coeff, state.grid_coords, state.grid_weights, density_matrix, H) + + +import chex +@chex.dataclass +class IterationState: + mask: np.array + init: np.array + E_nuc: np.array + L_inv: np.array + L_inv_T: np.array + H_core: np.array + grid_AO: np.array + grid_weights: np.array + grid_coords: np.array + pyscf_E: np.array + N: int + ERI: np.array + nonzero_distinct_ERI: list + nonzero_indices: list + diffs_ERI: np.array + main_grid_AO: np.array + diffs_grid_AO: np.array + indxs: np.array + sparse_diffs_grid_AO: np.array + rows: np.array + cols: np.array + pos: np.array + ao_types: np.array + pad_sizes: np.array + precomputed_nonzero_indices: np.array + precomputed_indxs: np.array + +from pyscf.data.elements import charge as elements_proton +from pyscf.dft import gen_grid, radi + +def treutler_atomic_radii_adjust(mol, atomic_radii): + charges = [elements_proton(x) for x in mol.elements] + rad = np.sqrt(atomic_radii[charges]) + 1e-200 + rr = rad.reshape(-1, 1) * (1. / rad) + a = .25 * (rr.T - rr) + + a[a < -0.5] = -0.5 + a[a > 0.5] = 0.5 + a = jnp.array(a) + + def fadjust(i, j, g): + g1 = g**2 + g1 -= 1. + g1 *= -a[i, j] + g1 += g + return g1 + + return fadjust + + +def inter_distance(coords): + rr = np.linalg.norm(coords.reshape(-1, 1, 3) - coords, axis=2) + rr[np.diag_indices(rr.shape[0])] = 0 + return rr + +def original_becke(g): + g = (3 - g**2) * g * .5 + g = (3 - g**2) * g * .5 + g = (3 - g**2) * g * .5 + return g + +def gen_grid_partition(coords, atom_coords, natm, atm_dist, elements, + atomic_radii, becke_scheme=original_becke,): + ngrids = coords.shape[0] + dc = coords[None] - atom_coords[:, None] + grid_dist = np.sqrt(np.einsum('ijk,ijk->ij', dc, dc)) # [natom, ngrid] + + ix, jx = np.tril_indices(natm, k=-1) + + natm, ngrid = grid_dist.shape + #g_ = -1 / atm_dist.reshape(natm, natm, 1) * (grid_dist.reshape(1, natm, ngrid) - grid_dist.reshape(natm, 1, ngrid)) + g_ = -1 / (atm_dist.reshape(natm, natm, 1) + np.eye(natm).reshape(natm, natm,1)) * (grid_dist.reshape(1, natm, ngrid) - grid_dist.reshape(natm, 1, ngrid)) + #g_ = jnp.array(g_) + + def pbecke_g(i, j): + g = g_[i, j] + charges = [elements_proton(x) for x in elements] + rad = np.sqrt(atomic_radii[charges]) + 1e-200 + rr = rad.reshape(-1, 1) * (1. / rad) + a = .25 * (rr.T - rr) + a[a < -0.5] = -0.5 + a[a > 0.5] = 0.5 + g1 = g**2 + g1 -= 1. + g1 *= -a[i, j].reshape(-1, 1) + g1 += g + return g1 + + g = pbecke_g(ix, jx) + g = np.copy(becke_scheme(g)) + gp2 = (1+g)/2 + gm2 = (1-g)/2 + + t0 = time.time() + #pbecke = f(gm2, gp2, natm, ngrids, ix, jx ) + pbecke = np.ones((natm, ngrids)) + c = 0 + # this goes up to n choose two + for i in range(natm): + for j in range(i): + pbecke[i] *= gm2[c] + pbecke[j] *= gp2[c] + c += 1 + #print("\t", time.time()-t0) + return pbecke + + +def get_partition( + mol, + atom_coords, + atom_grids_tab, + radii_adjust=treutler_atomic_radii_adjust, + atomic_radii=radi.BRAGG_RADII, + becke_scheme=original_becke, + concat=True, state=None +): + t0 = time.time() + atm_dist = inter_distance(atom_coords) # [natom, natom] + + coords_all = [] + weights_all = [] + + # [ ] consider another grid? + for ia in range(mol.natm): + coords, vol = atom_grids_tab[mol.atom_symbol(ia)] + coords = coords + atom_coords[ia] # [ngrid, 3] + pbecke = gen_grid_partition(coords, atom_coords, mol.natm, atm_dist, mol.elements, atomic_radii) # [natom, ngrid] + weights = vol * pbecke[ia] / np.sum(pbecke, axis=0) + coords_all.append(coords) + weights_all.append(weights) + + if concat: + coords_all = np.vstack(coords_all) + weights_all = np.hstack(weights_all) + + coords = (coords_all, weights_all) + return coords_all, weights_all + + +class DifferentiableGrids(gen_grid.Grids): + """Differentiable alternative to the original pyscf.gen_grid.Grids.""" + + def build(self, atom_coords, state=None) : + t0 = time.time() + mol = self.mol + + atom_grids_tab = self.gen_atomic_grids( + mol, self.atom_grid, self.radi_method, self.level, self.prune + ) + + coords, weights = get_partition( + mol, + atom_coords, + atom_grids_tab, + treutler_atomic_radii_adjust, + self.atomic_radii, + original_becke, + state=state, + ) + + self.coords = coords + self.weights = weights + return coords, weights + + +def grids_from_pyscf_mol( + mol: pyscf.gto.mole.Mole, quad_level: int = 1 +) : + g = gen_grid.Grids(mol) + g.level = quad_level + g.build() + grids = jnp.array(g.coords) + weights = jnp.array(g.weights) + return grids, weights + + +def init_dft(mol_str, opts, _coords=None, _weights=None, first=False, do_pyscf=True, state=None, pad_electrons=-1): + #t0 = time.time() + mol = build_mol(mol_str, opts.basis) + if do_pyscf: pyscf_E, pyscf_hlgap, pycsf_forces = reference(mol_str, opts) + else: pyscf_E, pyscf_hlgap, pyscf_forces = np.zeros(1), np.zeros(1), np.zeros(1) + + N = mol.nao_nr() # N=66 for C6H6 (number of atomic **and** molecular orbitals) + n_electrons_half = mol.nelectron//2 # 21 for C6H6 + E_nuc = mol.energy_nuc() # float = 202.4065 [Hartree] for C6H6. TODO(): Port to jax. + + from pyscf import dft + #grids = pyscf.dft.gen_grid.Grids(mol) + grids = DifferentiableGrids(mol) + grids.level = opts.level + #grids.build() + grids.build(np.concatenate([np.array(a[1]).reshape(1, 3) for a in mol._atom]), state=state) + + grid_weights = grids.weights # (grid_size,) = (45624,) for C6H6 + grid_coords = grids.coords + coord_str = 'GTOval_cart_deriv1' if mol.cart else 'GTOval_sph_deriv1' + grid_AO = mol.eval_gto(coord_str, grids.coords, 4) # (4, grid_size, N) = (4, 45624, 9) for C6H6. + + # TODO(): Add integral math formulas for kinetic/nuclear/O/ERI. + kinetic = mol.intor_symmetric('int1e_kin') # (N,N) + nuclear = mol.intor_symmetric('int1e_nuc') # (N,N) + O = mol.intor_symmetric('int1e_ovlp') # (N,N) + L = np.linalg.cholesky(O) + L_inv = np.linalg.inv(L) # (N,N) + + if pad_electrons == -1: + init = np.eye(N)[:, :n_electrons_half] + mask = np.ones((1, n_electrons_half)) + else: + assert pad_electrons > n_electrons_half, (pad_electrons, n_electrons_half) + init = np.eye(N)[:, :pad_electrons] + mask = np.zeros((1, pad_electrons)) + mask[:, :n_electrons_half] = 1 + + if opts.normal: + ERI = mol.intor("int2e_sph") + nonzero_distinct_ERI = np.zeros(1) + nonzero_indices = np.zeros(1) + else: + eri_threshold = 0 + batches = 1 + nipu = 1 + nonzero_distinct_ERI = mol.intor("int2e_sph", aosym="s8") + #ERI = [nonzero_distinct_ERI, nonzero_indices] + #ERI = ERI + ERI = np.zeros(1) + #ERI = mol.intor("int2e_sph") + + def e(x): return np.expand_dims(x, axis=0) + + n_C = nao('C', opts.basis) + n_N = nao('N', opts.basis) + n_O = nao('O', opts.basis) + n_F = nao('F', opts.basis) + n_H = nao('H', opts.basis) + n_vocab = n_C + n_N + n_O + n_F + n_H + start, stop = 0, n_C + c = list(range(n_vocab))[start:stop] + start, stop = stop, stop+n_N + n = list(range(n_vocab))[start:stop] + start, stop = stop, stop+n_O + o = list(range(n_vocab))[start:stop] + start, stop = stop, stop+n_F + f = list(range(n_vocab))[start:stop] + start, stop = stop, stop+n_H + h = list(range(n_vocab))[start:stop] + types = [] + pos = [] + for a, p in mol_str: + if a.lower() == 'h': + types += h + pos += [np.array(p).reshape(1, -1)]*len(h) + elif a.lower() == 'c': + types += c + pos += [np.array(p).reshape(1, -1)]*len(c) + elif a.lower() == 'n': + types += n + pos += [np.array(p).reshape(1, -1)]*len(n) + elif a.lower() == 'o': + types += o + pos += [np.array(p).reshape(1, -1)]*len(o) + elif a.lower() == 'f': + types += f + pos += [np.array(p).reshape(1, -1)]*len(f) + else: raise Exception() + ao_types = np.array(types) + pos = np.concatenate(pos) + pad_sizes = np.zeros(1) + + state = IterationState( + diffs_ERI = np.zeros((1,1)), + main_grid_AO = np.zeros((1,1)), + diffs_grid_AO = np.zeros((1,1)), + indxs = np.zeros((1,1)), + sparse_diffs_grid_AO = np.zeros((1,1)), + rows = np.zeros((1,1)), + cols = np.zeros((1,1)), + pos=e(pos), + ao_types=e(ao_types), + init = e(init), + E_nuc=e(E_nuc), + ERI=e(ERI), + nonzero_distinct_ERI=[nonzero_distinct_ERI], + nonzero_indices=[0], + H_core=e(nuclear+kinetic), + L_inv=e(L_inv), + L_inv_T = e(L_inv.T), + grid_AO=e(grid_AO), + grid_weights=e(grid_weights), + grid_coords=e(grid_coords), + pyscf_E=e(pyscf_E[-1:]), + N=e(mol.nao_nr()), + mask=e(mask), + pad_sizes=e(pad_sizes), + precomputed_nonzero_indices=np.zeros((1,1)), + precomputed_indxs=np.zeros((1,1)), + ) + + + return state + + +def summary(state): + if state is None: return + print("_"*100) + total = 0 + for field_name, field_def in state.__dataclass_fields__.items(): + field_value = getattr(state, field_name) + try: + print("%35s %24s %20s"%(field_name,getattr(field_value, 'shape', None), getattr(field_value, "nbytes", None)/10**9)) + total += getattr(field_value, "nbytes", None)/10**9 + + except: + try: + print("%35s %25s %20s"%(field_name,getattr(field_value[0], 'shape', None), getattr(field_value[0], "nbytes", None)/10**9)) + total += getattr(field_value, "nbytes", None)/10**9 + except: + print("BROKE FOR ", field_name) + + print("%35s %25s %20s"%("-", "total", total)) + try: + print(state.pyscf_E[:, -1]) + except: + pass + print("_"*100) + +def _cat(x,y,name): + if "list" in str(type(x)): + return x + y + else: + return np.concatenate([x,y]) + + +def cat(dc1, dc2, axis=0): + # Use dictionary comprehension to iterate over the dataclass fields + concatenated_fields = { + field: _cat(getattr(dc1, field), getattr(dc2, field), field) + for field in dc1.__annotations__ + } + # Create a new dataclass instance with the concatenated fields + return IterationState(**concatenated_fields) + +def _cats(xs): + if "list" in str(type(xs[0])): + return sum(xs, [])#x + y + else: + return np.concatenate(xs) + + +def cats(dcs): + concatenated_fields = { + field: _cats([getattr(dc, field) for dc in dcs]) + for field in dcs[0].__annotations__ + } + # Create a new dataclass instance with the concatenated fields + return IterationState(**concatenated_fields) + +def grad_elec(weight, grid_AO, eri, s1, h1aos, natm, aoslices, mask, mo_energy, mo_coeff, mol, dm, H): + # Electronic part of RHF/RKS gradients + dm0 = 2 * (mo_coeff*mask) @ mo_coeff.T # (N, N) = (66, 66) for C6H6. + dme0 = 2 * (mo_coeff * mask*mo_energy) @ mo_coeff.T # (N, N) = (66, 66) for C6H6. + + # Code identical to exchange correlation. + rho = jnp.sum( grid_AO[:1] @ dm0 * grid_AO, axis=2) # (10, grid_size) = (10, 45624) for C6H6. + _, vrho, vgamma = vxc_b3lyp(rho, EPSILON_B3LYP) # (grid_size,) (grid_size,) + V_xc = jnp.concatenate([vrho.reshape(1, -1)/2, 4*vgamma.reshape(1, -1)*rho[1:4]], axis=0) # (4, grid_size) + + vmat = grid_AO[1:4].transpose(0, 2, 1) @ jnp.sum(grid_AO[:4] * jnp.expand_dims(weight * V_xc, axis=2), axis=0) # (3, N, N) + aos = jnp.concatenate([jnp.expand_dims(grid_AO[np.array([1,4,5,6])], 0), jnp.expand_dims(grid_AO[np.array([2,5,7,8])], 0), jnp.expand_dims(grid_AO[np.array([3,6,8,9])], 0)], axis=0) # (3, N, N) + V_xc = - vmat - jnp.transpose(jnp.einsum("snpi,np->spi", aos, weight*V_xc), axes=(0,2,1)) @ grid_AO[0] # (3, 4, grid_size, N) + + vj = - jnp.einsum('sijkl,lk->sij', eri, dm0) # (3, N, N) + vk = - jnp.einsum('sijkl,jk->sil', eri, dm0) # (3, N, N) + vhf = V_xc + vj - vk * .5 * HYB_B3LYP # (3, N, N) + + de = jnp.einsum('lxij,ij->lx', h1aos, dm0) # (natm, 3) + for k, ia in enumerate(range(natm)): + p0, p1 = aoslices[ia][2], aoslices[ia][3] + de = de.at[k].add(jnp.einsum('xij,ij->x', vhf[:, p0:p1], dm0[p0:p1]) * 2) + de = de.at[k].add(-jnp.einsum('xij,ij->x', s1[:, p0:p1], dme0[p0:p1]) * 2) + return de + +def grad_nuc(charges, coords): + # Derivatives of nuclear repulsion energy wrt nuclear coordinates + natm = charges.shape[0] + pairwise_charges = charges.reshape(natm, 1) * charges.reshape(1, natm) # (natm, natm) + pairwise_difference = coords.reshape(1, natm, 3) - coords.reshape(natm, 1, 3) # (natm, natm, 3) + pairwise_distances = jnp.linalg.norm(pairwise_difference, axis=2) ** 3 # (natm, natm) + pairwise_distances = jnp.where(pairwise_distances == 0, jnp.inf, pairwise_distances) # (natm, natm) + all = - pairwise_charges.reshape(natm, natm, 1) * pairwise_difference # (natm, natm, 3) + all = all / pairwise_distances.reshape(natm, natm, 1) # (natm, natm, 3) + all = all.at[jnp.diag_indices(natm)].set(0) # (natm, natm, 3) + return jnp.sum(all, axis=0) # (natm, natm) + +def grad(mol, coords, weight, mo_coeff, mo_energy, dm, H): + # Initialize DFT tensors on CPU using PySCF. + ao = pyscf.dft.numint.NumInt().eval_ao(mol, coords, deriv=2) + eri = mol.intor("int2e_ip1") + s1 = - mol.intor('int1e_ipovlp', comp=3) + kin = - mol.intor('int1e_ipkin', comp=3) + nuc = - mol.intor('int1e_ipnuc', comp=3) + + aoslices = mol.aoslice_by_atom() + h1 = kin + nuc + def hcore_deriv(atm_id, aoslices, h1): # <\nabla|1/r|> + _, _, p0, p1 = aoslices[atm_id] + with mol.with_rinv_at_nucleus(atm_id): + vrinv = mol.intor('int1e_iprinv', comp=3) # + vrinv *= -mol.atom_charge(atm_id) + vrinv[:,p0:p1] += h1[:,p0:p1] + return vrinv + vrinv.transpose(0,2,1) + N = h1.shape[1] # (3, N , N) + h1aos = np.zeros((mol.natm, 3, N, N)) + for k, ia in enumerate(range(mol.natm)): + p0, p1 = aoslices[ia,2:] + h1aos[k] = hcore_deriv(ia, aoslices, h1) + + charges = np.zeros((mol.natm)) + coords = np.zeros((mol.natm,3)) + for j in range(mol.natm): + charges[j] = mol.atom_charge(j) + coords[j]= mol.atom_coord(j) + + #_grad_elec = jax.jit(grad_elec, static_argnames=["aoslices", "natm"], backend="cpu") + _grad_elec = grad_elec + _grad_nuc = jax.jit(grad_nuc, backend="cpu") + + return _grad_elec(weight, ao, eri, s1, h1aos, mol.natm, tuple([tuple(a) for a in aoslices.tolist()]), mask, mo_energy, mo_coeff, mol, dm, H) + _grad_nuc(charges, coords) + +def pyscf_reference(mol_str, opts): + from pyscf import __config__ + __config__.dft_rks_RKS_grids_level = opts.level + mol = build_mol(mol_str, opts.basis) + mol.max_cycle = 50 + mf = pyscf.scf.RKS(mol) + mf.max_cycle = 50 + mf.xc = "b3lyp5" + mf.diis_space = 8 + pyscf_energies = [] + pyscf_hlgaps = [] + lumo = mol.nelectron//2 + homo = lumo - 1 + t0 = time.time() + def callback(envs): + pyscf_energies.append(envs["e_tot"]*HARTREE_TO_EV) + hl_gap_hartree = np.abs(envs["mo_energy"][homo] - envs["mo_energy"][lumo]) * HARTREE_TO_EV + pyscf_hlgaps.append(hl_gap_hartree) + print("PYSCF: ", pyscf_energies[-1], "[eV]", time.time()-t0) + mf.callback = callback + mf.kernel() + print("") + if False: + forces = mf.nuc_grad_method().kernel() + else: forces = 0 + return np.array(pyscf_energies), np.array(pyscf_hlgaps), np.array(forces) + +def print_difference(nanoDFT_E, nanoDFT_forces, nanoDFT_logged_E, nanoDFT_hlgap, pyscf_E, pyscf_forces, pyscf_hlgap): + #TODO(HH): rename to match caller variable names + nanoDFT_E = nanoDFT_E*HARTREE_TO_EV + print("pyscf:\t\t%15f"%pyscf_E[-1]) + print("us:\t\t%15f"%nanoDFT_E) + print("diff:\t\t%15f"%np.abs(pyscf_E[-1]-nanoDFT_E)) + print("chemAcc: \t%15f"%0.043) + print("chemAcc/diff: \t%15f"%(0.043/np.abs(pyscf_E[-1]-nanoDFT_E))) + print("") + + # Forces + print() + print("np.max(|nanoDFT_F-PySCF_F|):", np.max(np.abs(nanoDFT_forces-pyscf_forces))) + + norm_X = np.linalg.norm(nanoDFT_forces, axis=1) + norm_Y = np.linalg.norm(pyscf_forces, axis=1) + dot_products = np.sum(nanoDFT_forces * pyscf_forces, axis=1) + cosine_similarity = dot_products / (norm_X * norm_Y) + print("Force cosine similarity:",cosine_similarity) + +def build_mol(mol_str, basis_name): + mol = pyscf.gto.mole.Mole() + mol.build(atom=mol_str, unit="Angstrom", basis=basis_name, spin=0, verbose=0) + return mol + +def reference(mol_str, opts): + import pickle + import hashlib + if opts.skip: return np.zeros(1), np.zeros(1), np.zeros(1) + filename = "precomputed/%s.pkl"%hashlib.sha256((str(mol_str) + str(opts.basis) + str(opts.level)).encode('utf-8')).hexdigest() + print(filename) + if not os.path.exists(filename): + pyscf_E, pyscf_hlgap, pyscf_forces = pyscf_reference(mol_str, opts) + with open(filename, "wb") as file: + pickle.dump([pyscf_E, pyscf_hlgap, pyscf_forces], file) + else: + pyscf_E, pyscf_hlgap, pyscf_forces = pickle.load(open(filename, "rb")) + return pyscf_E, pyscf_hlgap, pyscf_forces + + +class HashableNamespace: + def __init__(self, namespace): + self.__dict__.update(namespace.__dict__) + + def __hash__(self): + # Convert the relevant attributes to a tuple for hashing + return hash(tuple(sorted(self.__dict__.items()))) + + +if __name__ == "__main__": + import os + import argparse + + parser = argparse.ArgumentParser() + # DFT options + parser.add_argument('-basis', type=str, default="sto3g") + parser.add_argument('-level', type=int, default=0) + + # GD options + parser.add_argument('-backend', type=str, default="cpu") + parser.add_argument('-lr', type=float, default=2.5e-4) + parser.add_argument('-steps', type=int, default=100000) + parser.add_argument('-bs', type=int, default=8) + parser.add_argument('-val_bs', type=int, default=8) + parser.add_argument('-mol_repeats', type=int, default=16) # How many time to optimize wrt each molecule. + + # energy computation speedups + parser.add_argument('-foriloop', action="store_true") # whether to use jax.lax.foriloop for sparse_symmetric_eri (faster compile time but slower training. ) + parser.add_argument('-xc_f32', action="store_true") + parser.add_argument('-eri_f32', action="store_true") + parser.add_argument('-eri_bs', type=int, default=8) + + parser.add_argument('-normal', action="store_true") + parser.add_argument('-wandb', action="store_true") + parser.add_argument('-prof', action="store_true") + parser.add_argument('-visualize', action="store_true") + parser.add_argument('-skip', action="store_true", help="skip pyscf test case") + + # dataset + parser.add_argument('-qm9', action="store_true") + parser.add_argument('-benzene', action="store_true") + parser.add_argument('-hydrogens', action="store_true") + parser.add_argument('-water', action="store_true") + parser.add_argument('-waters', action="store_true") + parser.add_argument('-alanine', action="store_true") + parser.add_argument('-states', type=int, default=1) + parser.add_argument('-workers', type=int, default=5) + parser.add_argument('-precompute', action="store_true") # precompute labels; only run once for data{set/augmentation}. + # do noise schedule, start small slowly increase + parser.add_argument('-wiggle_var', type=float, default=0.05, help="wiggle N(0, wiggle_var), bondlength=1.5/30") + parser.add_argument('-eri_threshold', type=float, default=1e-10, help="loss function threshold only") + parser.add_argument('-rotate_deg', type=float, default=90, help="how many degrees to rotate") + + # models + parser.add_argument('-nn', action="store_true", help="train nn, defaults to GD") + parser.add_argument('-tiny', action="store_true") + parser.add_argument('-small', action="store_true") + parser.add_argument('-base', action="store_true") + parser.add_argument('-medium', action="store_true") + parser.add_argument('-large', action="store_true") + parser.add_argument('-xlarge', action="store_true") + + parser.add_argument("-checkpoint", default=-1, type=int, help="which iteration to save model (default -1 = no saving)") # checkpoint model + parser.add_argument("-resume", default="", help="path to checkpoint pickle file") # checkpoint model + opts = parser.parse_args() + + if opts.tiny or opts.small or opts.base or opts.large or opts.xlarge: opts.nn = True + + args_dict = vars(opts) + print(args_dict) + + if opts.qm9: + df = pd.read_pickle("alchemy/atom_9.pickle") + df = df[df["spin"] == 0] # only consider spin=0 + mol_strs = df["pyscf"].values + + # benzene + if opts.benzene: + mol_strs = [[ + ["C", ( 0.0000, 0.0000, 0.0000)], + ["C", ( 1.4000, 0.0000, 0.0000)], + ["C", ( 2.1000, 1.2124, 0.0000)], + ["C", ( 1.4000, 2.4249, 0.0000)], + ["C", ( 0.0000, 2.4249, 0.0000)], + ["C", (-0.7000, 1.2124, 0.0000)], + ["H", (-0.5500, -0.9526, 0.0000)], + ["H", (-0.5500, 3.3775, 0.0000)], + ["H", ( 1.9500, -0.9526, 0.0000)], + ["H", (-1.8000, 1.2124, 0.0000)], + ["H", ( 3.2000, 1.2124, 0.0000)], + ["H", ( 1.9500, 3.3775, 0.0000)] + ]] + # hydrogens + if opts.hydrogens: + mol_strs = [[ + ["H", ( 0.0000, 0.0000, 0.0000)], + ["H", ( 1.4000, 0.0000, 0.0000)], + ]] + if opts.water: + mol_strs = [[ + ["O", ( 0.0000, 0.0000, 0.0000)], + ["H", ( 0.0000, 1.4000, 0.0000)], + ["H", ( 1.4000, 0.0000, 0.0000)], + ]] + if opts.waters: + mol_strs = [[ + ["O", (-1.464, 0.099, 0.300)], + ["H", (-1.956, 0.624, -0.340)], + ["H", (-1.797, -0.799, 0.206)], + ["O", ( 1.369, 0.146, -0.395)], + ["H", ( 1.894, 0.486, 0.335)], + ["H", ( 0.451, 0.165, -0.083)]]] + + elif opts.alanine: + mol_strs = [[ # 22 atoms (12 hydrogens) => 10 heavy atoms (i.e. larger than QM9). + ["H", ( 2.000 , 1.000, -0.000)], + ["C", ( 2.000 , 2.090, 0.000)], + ["H", ( 1.486 , 2.454, 0.890)], + ["H", ( 1.486 , 2.454, -0.890)], + ["C", ( 3.427 , 2.641, -0.000)], + ["O", ( 4.391 , 1.877, -0.000)], + ["N", ( 3.555 , 3.970, -0.000)], + ["H", ( 2.733 , 4.556, -0.000)], + ["C", ( 4.853 , 4.614, -0.000)], # carbon alpha + ["H", ( 5.408 , 4.316, 0.890)], # hydrogne attached to carbon alpha + ["C", ( 5.661 , 4.221, -1.232)], # carbon beta + ["H", ( 5.123 , 4.521, -2.131)], # hydrogens attached to carbon beta + ["H", ( 6.630 , 4.719, -1.206)], # hydrogens attached to carbon beta + ["H", ( 5.809 , 3.141, -1.241)], # hydrogens attached to carbon beta + ["C", ( 4.713 , 6.129, 0.000)], + ["O", ( 3.601 , 6.653, 0.000)], + ["N", ( 5.846 , 6.835, 0.000)], + ["H", ( 6.737 , 6.359, -0.000)], + ["C", ( 5.846 , 8.284, 0.000)], + ["H", ( 4.819 , 8.648, 0.000)], + ["H", ( 6.360 , 8.648, 0.890)], + ["H", ( 6.360 , 8.648, -0.890)], + ]] + + # make opts hashable so that JAX will not complain about the static parameter that is passed as arg + opts = HashableNamespace(opts) + + nanoDFT_E, (nanoDFT_hlgap, mo_energy, mo_coeff, grid_coords, grid_weights, dm, H) = nanoDFT(mol_strs, opts) + + exit() + pyscf_E, pyscf_hlgap, pyscf_forces = reference(mol_str, opts) + nanoDFT_forces = grad(mol, grid_coords, grid_weights, mo_coeff, mo_energy, np.array(dm), np.array(H)) + print_difference(nanoDFT_E, nanoDFT_forces, 0 , nanoDFT_hlgap, pyscf_E, pyscf_forces, pyscf_hlgap) diff --git a/pyscf_ipu/direct/transformer.py b/pyscf_ipu/direct/transformer.py new file mode 100644 index 00000000..9cc2b9d0 --- /dev/null +++ b/pyscf_ipu/direct/transformer.py @@ -0,0 +1,340 @@ +""" Pure-from-the-ground-up transformer, based on https://github.com/vpj/jax_transformer/blob/master/transformer.py """ +import jax +from jax import vmap +import jax.numpy as jnp +from functools import partial +import jax.experimental.host_callback +import math +import numpy as np + +def rand(rng, f, shape, **kwargs): + rng, rng1 = jax.random.split(rng) + return rng, f(rng1, shape, **kwargs) + +def linear_init_uniform(rng: jax.random.KeyArray, in_features: int, out_features: int): + # todo: init as kaparthy + params = ParamsDict() + rnd_range = 1 / in_features**0.5 + rng, params.weight = rand( rng, jax.random.uniform, (in_features, out_features), minval=-rnd_range, maxval=rnd_range,) + params.bias = jnp.zeros((out_features,)) + return rng, params, (in_features, out_features) + +def elementwise_linear_init_identity(shape): return ParamsDict(gain=jnp.ones(shape), bias=jnp.zeros(shape)) + +def linear(params, x: jnp.ndarray): return x @ params.weight + params.bias[None, :] + +def elementwise_linear(params, x: jnp.ndarray): return params.gain[None, :] * x + params.bias[None, :] + +def standardize(x, eps=1e-5): return (x - x.mean()) / (x.std() + eps) + +def transformer_init( + rng: jax.random.KeyArray, + n_vocab: int, + d_model: int, + n_layers: int, + n_heads: int, + d_ff: int, + max_len=4096, +): + total_params = 0 + + # Build config struct for call + config = ParamsDict() + config.heads = n_heads + if True: #flip_pe_coef(): + config.lambda_e = d_model**-0.5 + config.lambda_pe = 1.0 + else: + config.lambda_e = d_model**-0.5 + config.lambda_pe = 1.0 + + # Build initializers for params + params = ParamsDict() + + print("_"*100) + + # Create embedding layer + rng, params.embeddings = rand(rng, jax.random.normal, (n_vocab, d_model)) + total_params += np.prod(params.embeddings.shape) + print("%26s %26s %26s"%("params.embeddings",params.embeddings.shape, np.prod(params.embeddings.shape))) + + rng, params.project_positions, shape = linear_init_uniform(rng, 12, d_model) + total_params += np.prod(shape) + print("%26s %26s %26s"%("params.project_positions",shape, np.prod(shape))) + + # For transformer layers + params.layers = [] + for i in range(n_layers): + layer = ParamsDict() + layer.norm_self_attn = elementwise_linear_init_identity(d_model) + total_params += np.prod(d_model*2) + print("%26s %26s %26s"%("layer%i.norm_self_attn"%i, (d_model,2), np.prod((d_model, 2)))) + + rng, layer.kqv, shape = linear_init_uniform(rng, d_model, d_model*3) + total_params += np.prod(shape) # omitting bias in calculation for now + print("%26s %26s %26s"%("layer%i.kqv"%i, shape, np.prod(shape))) + + layer.norm_ff = elementwise_linear_init_identity(d_model) + total_params += np.prod(d_model*2) + print("%26s %26s %26s"%("layer%i.norm_ff"%i, (d_model,2), np.prod((d_model, 2)))) + + rng, layer.ffn1, shape = linear_init_uniform(rng, d_model, d_ff) + total_params += np.prod(shape) + print("%26s %26s %26s"%("layer%i.ffn1"%i, shape, np.prod(shape))) + + rng, layer.ffn2, shape = linear_init_uniform(rng, d_ff, d_model) + total_params += np.prod(shape) + print("%26s %26s %26s"%("layer%i.ffn2"%i, shape, np.prod(shape))) + + params.layers.append(layer) + + # Final normalization and output layer + print("total: ", total_params) + + return rng, config, params, total_params + + +@partial(jax.jit, static_argnums=0) +def transformer(cfg, params, x: jnp.ndarray, position: jnp.ndarray, H_core: jnp.ndarray): + """ + cfg: Config, from transformer_init, holds hyperparameters + params: Current transformer parameters, initialized in init + x: 1D array of L integers, representing the input sequence + output: L x n_vocab logits + """ + L, = x.shape # x is just 1D. Vmap/pmap will handle batching + + embeddings = cfg.lambda_e * params.embeddings[x, :] # L x Dm + + all_pairs = jnp.linalg.norm(position.reshape(1, -1, 3) - position.reshape(-1, 1, 3), axis=-1) + + # inspired by 3d point cloud transformers; + # nspired by andrew: use trigonometric functions as feature transformations + position = jnp.concatenate([position, jnp.cos(position), jnp.sin(position), jnp.tanh(position)], axis=1) #(N,3) -> (N,12) + positions = linear(params.project_positions, position) # L x Dm + + # Add (learned) positional encodings + x = embeddings + positions # L x Dm + L, Dm = x.shape + nheads = cfg.heads + + def block(x, layer_num, layer): + # Layer-normalize + t1 = vmap(standardize)(x) # L x Dm + t1 = elementwise_linear(layer.norm_self_attn, t1) # L x Dm + + qkv = linear(layer.kqv, t1) + q,k,v = jnp.split(qkv, 3, axis=1) + q = jnp.transpose(q.reshape(L, nheads, Dm//nheads), (1, 0, 2)) + k = jnp.transpose(k.reshape(L, nheads, Dm//nheads), (1, 0, 2)) + v = jnp.transpose(v.reshape(L, nheads, Dm//nheads), (1, 0, 2)) + score = (q @ jnp.transpose(k, (0, 2, 1))) / math.sqrt(Dm) + + # do like graphformer and append position here? + #if layer_num < 6: # doesn't look like it helps + # score += H_core + # score += all_pairs + + attn = jax.nn.softmax(score , axis=1) + x = x + (attn @ v).reshape(L, Dm) + + # Layer-normalize + t2 = vmap(standardize)(x) + t2 = elementwise_linear(layer.norm_ff, t2) # L x Dm + + # Feedforward fully connected + t2 = linear(layer.ffn1, t2) # L x Dm*4 + t2 = jax.nn.gelu(t2) + t2 = linear(layer.ffn2, t2) # L x Dm + + # Residual connection + x = x + t2 + return x, score + + # Apply the transformer layers + # todo: cut jit time by making this jax.lax.foriloop + for layer_num, layer in enumerate(params.layers): + x, score = jax.checkpoint(block)(x, layer_num, layer) + + # todo: if this isn't symmetric eigh gives imaginary eigenvalues? (bad) + M = score[0] # take first attention head + #M = (M + M.T)/2 # make symmetric! + return M + +import types +import json +import jax + +import numbers + +def is_simple_type(x): + return isinstance(x, (numbers.Number, bool, str)) + +@jax.tree_util.register_pytree_node_class +class ParamsDict(types.SimpleNamespace): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def tree_flatten(self): + return jax.tree_util.tree_flatten(self.__dict__, lambda a: a is not self.__dict__) # only flatten one step + + @classmethod + def tree_unflatten(cls, aux, values): + return ParamsDict(**jax.tree_util.tree_unflatten(aux, values)) + + def toJSON(self): + return json.dumps(self, default=lambda o: o.__dict__, + sort_keys=True, indent=4) + + def __hash__(self): + # Should overload setattr to warn if setattr is called after hash has been computed + return hash(tuple(hash(x) for (_,x) in self.__dict__.items())) + + def print(self, path = ''): + for (k,v) in self.items(path): + print(k + ':',v) + + @classmethod + def labels_aux(cls, path, obj): + if isinstance(obj, (list, tuple)) and any(not is_simple_type(x) for x in obj): + for i,vi in enumerate(obj): + yield from cls.labels_aux(f'{path}[{i}]', vi) + elif isinstance(obj, dict): + for (k,v) in obj.items(): + yield from cls.labels_aux(path + '/' + k, v) + elif isinstance(obj, ParamsDict): + yield from cls.labels_aux(path, obj.__dict__) + else: + yield (path, obj) + + def items(self, path = ''): + yield from self.labels_aux(path, self) + + def to_float32(self): + def convert_to_float32(x): + if isinstance(x, jnp.ndarray) and x.dtype == jnp.float64: + return x.astype(jnp.float32) + return x + + # Create a new ParamsDict instance with converted arrays + new_dict = jax.tree_map(convert_to_float32, self.__dict__) + return ParamsDict(**new_dict) + self.__dict__ = jax.tree_map(convert_to_float32, self.__dict__) + + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser() + # DFT options + parser.add_argument('-basis', type=str, default="sto3g") + parser.add_argument('-level', type=int, default=0) + + # GD options + parser.add_argument('-backend', type=str, default="cpu") + parser.add_argument('-lr', type=float, default=2.5e-4) + parser.add_argument('-steps', type=int, default=100000) + parser.add_argument('-bs', type=int, default=8) + parser.add_argument('-val_bs', type=int, default=8) + parser.add_argument('-mol_repeats', type=int, default=16) # How many time to optimize wrt each molecule. + + # energy computation speedups + parser.add_argument('-foriloop', action="store_true") # whether to use jax.lax.foriloop for sparse_symmetric_eri (faster compile time but slower training. ) + parser.add_argument('-xc_f32', action="store_true") + parser.add_argument('-eri_f32', action="store_true") + parser.add_argument('-eri_bs', type=int, default=8) + + parser.add_argument('-normal', action="store_true") + parser.add_argument('-wandb', action="store_true") + parser.add_argument('-prof', action="store_true") + parser.add_argument('-visualize', action="store_true") + parser.add_argument('-skip', action="store_true", help="skip pyscf test case") + + # dataset + parser.add_argument('-qm9', action="store_true") + parser.add_argument('-benzene', action="store_true") + parser.add_argument('-hydrogens', action="store_true") + parser.add_argument('-water', action="store_true") + parser.add_argument('-waters', action="store_true") + parser.add_argument('-alanine', action="store_true") + parser.add_argument('-states', type=int, default=1) + parser.add_argument('-workers', type=int, default=5) + parser.add_argument('-precompute', action="store_true") # precompute labels; only run once for data{set/augmentation}. + # do noise schedule, start small slowly increase + parser.add_argument('-wiggle_var', type=float, default=0.05, help="wiggle N(0, wiggle_var), bondlength=1.5/30") + parser.add_argument('-eri_threshold', type=float, default=1e-10, help="loss function threshold only") + parser.add_argument('-rotate_deg', type=float, default=90, help="how many degrees to rotate") + + # models + parser.add_argument('-nn', action="store_true", help="train nn, defaults to GD") + parser.add_argument('-tiny', action="store_true") + parser.add_argument('-small', action="store_true") + parser.add_argument('-base', action="store_true") + parser.add_argument('-medium', action="store_true") + parser.add_argument('-large', action="store_true") + parser.add_argument('-xlarge', action="store_true") + opts = parser.parse_args() + + # initialize model + # transformer tiny 5M + d_model= 192 + n_heads = 6 + n_layers = 12 + + from train import nao + rnd_key = jax.random.PRNGKey(42) + n_vocab = nao("C", opts.basis) + nao("N", opts.basis) + \ + nao("O", opts.basis) + nao("F", opts.basis) + \ + nao("H", opts.basis) + + rnd_key, cfg, params, total_params = transformer_init( + rnd_key, + n_vocab, + d_model =d_model, + n_layers=n_layers, + n_heads =n_heads, + d_ff =d_model*4, + ) + + + # compute dummy output + from train import batched_state, summary + opts.alanine = True + alanine = [ ["H", ( 2.000 , 1.000, -0.000)], ["C", ( 2.000 , 2.090, 0.000)], ["H", ( 1.486 , 2.454, 0.890)], ["H", ( 1.486 , 2.454, -0.890)], + ["C", ( 3.427 , 2.641, -0.000)], ["O", ( 4.391 , 1.877, -0.000)], ["N", ( 3.555 , 3.970, -0.000)], ["H", ( 2.733 , 4.556, -0.000)], + ["C", ( 4.853 , 4.614, -0.000)], ["H", ( 5.408 , 4.316, 0.890)], ["C", ( 5.661 , 4.221, -1.232)], ["H", ( 5.123 , 4.521, -2.131)], + ["H", ( 6.630 , 4.719, -1.206)], ["H", ( 5.809 , 3.141, -1.241)], ["C", ( 4.713 , 6.129, 0.000)], ["O", ( 3.601 , 6.653, 0.000)], + ["N", ( 5.846 , 6.835, 0.000)], ["H", ( 6.737 , 6.359, -0.000)], ["C", ( 5.846 , 8.284, 0.000)], ["H", ( 4.819 , 8.648, 0.000)], + ["H", ( 6.360 , 8.648, 0.890)], ["H", ( 6.360 , 8.648, -0.890)], ] + state = batched_state(alanine, opts, opts.bs, \ + wiggle_num=0, do_pyscf=False, validation=False, \ + extrapolate=False, mol_idx=0) + summary(state) + + output = jax.jit(jax.vmap(transformer, in_axes=(None, None, 0, 0, 0), out_axes=(0)), + static_argnums=(0,), + backend="cpu")(cfg, \ + params, state.ao_types, state.pos.astype(jnp.float32), state.H_core.astype(jnp.float32)) + + + print(np.sum(output)) # 162.58726108305348 + + + # store model + import pickle + pickle.dump(params, open("checkpoints/example.pickle", "wb")) + + # reload model + new_params = pickle.load(open("checkpoints/example.pickle", "rb")) + + # check that output remains the same + new_output = jax.jit(jax.vmap(transformer, in_axes=(None, None, 0, 0, 0), out_axes=(0)), + static_argnums=(0,), + backend="cpu")(cfg, \ + new_params, state.ao_types, state.pos.astype(jnp.float32), state.H_core.astype(jnp.float32)) + + assert np.allclose(output, new_output) + print("TEST CASE PASSED!") + + + \ No newline at end of file diff --git a/pyscf_ipu/exchange_correlation/b3lyp.py b/pyscf_ipu/exchange_correlation/b3lyp.py index a5780ee7..731c8a43 100644 --- a/pyscf_ipu/exchange_correlation/b3lyp.py +++ b/pyscf_ipu/exchange_correlation/b3lyp.py @@ -14,6 +14,37 @@ def b3lyp(rho, EPSILON_B3LYP=0): rho = jnp.concatenate([jnp.clip(rho[:1], CLIP_RHO_MIN, CLIP_RHO_MAX), rho[1:4]*2]) + rho0 = rho.T[:, 0] + #norms = jnp.linalg.norm(rho[1:], axis=0).T**2+EPSILON_B3LYP + norms = jnp.linalg.norm(rho[1:]+CLIP_RHO_MIN, axis=0).T**2+EPSILON_B3LYP + + def lda(rho0): return jax.vmap(jax.value_and_grad(lambda x: __lda(x)*0.08)) (rho0) + def vwn(rho0): return jax.vmap(jax.value_and_grad(lambda x: __vwn(x)*0.19)) (rho0) + + # disabled gradient checkpointing + #def b88(rho0, norms): return jax.vmap(jax.value_and_grad(lambda rho0, norm: jax.checkpoint(__b88)(rho0, norm)*0.72, (0, 1))) (rho0, norms) + #def lyp(rho0, norms): return jax.vmap(jax.value_and_grad(lambda rho0, norm: jax.checkpoint(__lyp)(rho0, norm)*0.810, (0, 1))) (rho0, norms) + + def b88(rho0, norms): return jax.vmap(jax.value_and_grad(lambda rho0, norm: __b88(rho0, norm)*0.72, (0,1)))(rho0, norms) + def lyp(rho0, norms): return jax.vmap(jax.value_and_grad(lambda rho0, norm: __lyp(rho0, norm)*0.810, (0,1)))(rho0, norms) + + e_xc_lda, v_rho_lda = jax.jit(lda)(rho0) + e_xc_vwn, v_rho_vwn = jax.jit(vwn)(rho0) + e_xc_b88, (v_rho_b88, v_norm_b88) = jax.jit(b88)(rho0, norms) + e_xc_lyp, (v_rho_lyp, v_norm_lyp) = jax.jit(lyp)(rho0, norms) + + e_xc = e_xc_lda + (e_xc_vwn + e_xc_b88 + e_xc_lyp) / rho0 + #v_xc_rho = v_rho_lda*4*rho0 + v_rho_vwn + v_rho_b88 + v_rho_lyp + #v_xc_norms = v_norm_b88 + v_norm_lyp + + return e_xc#, v_xc_rho, v_xc_norms + + + +def vxc_b3lyp(rho, EPSILON_B3LYP=0): + + rho = jnp.concatenate([jnp.clip(rho[:1], CLIP_RHO_MIN, CLIP_RHO_MAX), rho[1:4]*2]) + rho0 = rho.T[:, 0] norms = jnp.linalg.norm(rho[1:], axis=0).T**2+EPSILON_B3LYP @@ -156,4 +187,4 @@ def plot(rho, b, a, g, grad, vnorm=None, name=""): # b is pyscf a is us ax[2].set_yscale("log") ax[2].set_xscale("log") plt.tight_layout() - plt.savefig("%s_3.jpg"%name) + plt.savefig("%s_3.jpg"%name) \ No newline at end of file diff --git a/pyscf_ipu/nanoDFT/batched.py b/pyscf_ipu/nanoDFT/batched.py new file mode 100644 index 00000000..bf5f8394 --- /dev/null +++ b/pyscf_ipu/nanoDFT/batched.py @@ -0,0 +1,648 @@ +import jax +jax.config.update('jax_enable_x64', True) +import jax.numpy as jnp +import numpy as np +import pyscf +import optax +from icecream import ic +from pyscf_ipu.exchange_correlation.b3lyp import b3lyp, vxc_b3lyp +from tqdm import tqdm +import time + +HARTREE_TO_EV, EPSILON_B3LYP, HYB_B3LYP = 27.2114079527, 1e-20, 0.2 + +def T(x): return jnp.transpose(x, (0,2,1)) + +# Only need to recompute: L_inv, grid_AO, grid_weights, H_core, ERI and E_nuc. +def dm_energy(W, state, diff_state, normal): + B, N, k = W.shape + L_inv_Q = state.L_inv_T @ jnp.linalg.qr(W)[0] # O(N^2 * num_electrons * batch) instead of O(N^3 * batch)! + density_matrix = 2 * L_inv_Q @ T(L_inv_Q) + E_xc = exchange_correlation(density_matrix, state, diff_state, normal) + diff_JK = JK(density_matrix, state, diff_state, normal) + energies = jnp.sum((density_matrix * (state.H_core + diff_JK/2)).reshape(B, -1), axis=-1) + E_xc + state.E_nuc + return jnp.sum(energies), (energies, E_xc, density_matrix) + +def exchange_correlation(density_matrix, state, diff_state, normal): + B, _, gsize, N = state.grid_AO.shape + if normal: + grid_AO_dm = (state.grid_AO[:, 0] @ density_matrix) # (B,gsize,N) @ (N, N) = O(B gsize N^2) + rho = jnp.sum(grid_AO_dm * state.grid_AO , axis=3) # (B,1,gsize,N) * (B,4,gsize,N) = O(B gsize N) + else: + def sparse_mult(values, dm): + in_ = dm.take(diff_state.cols, axis=0) + prod = in_*values[:, None] + return jax.ops.segment_sum(prod, diff_state.rows, gsize) + + main = diff_state.main_grid_AO[:1, 0] @ density_matrix # (1, gsize, N) @ (N, N) = O(gsize N^2) + correction = jax.vmap(sparse_mult)(diff_state.sparse_diffs_grid_AO, density_matrix) + grid_AO_dm = (main - correction).reshape(B, 1, gsize, N) + diff = diff_state.main_grid_AO[:1, :] - diff_state.diffs_grid_AO + rho = jnp.sum(grid_AO_dm * diff, axis=3).reshape(B, 4, gsize) + + E_xc = jax.vmap(b3lyp, in_axes=(0,None))(rho, EPSILON_B3LYP).reshape(B, gsize) + E_xc = jnp.sum(rho[:, 0] * state.grid_weights * E_xc, axis=-1).reshape(B) + return E_xc + +def JK(density_matrix, state, diff_state, normal): + if normal: + J = jnp.einsum('bijkl,bji->bkl', state.ERI, density_matrix) + K = jnp.einsum('bijkl,bjk->bil', state.ERI, density_matrix) + diff_JK = J - K / 2 * HYB_B3LYP + else: + from pyscf_ipu.nanoDFT.sparse_symmetric_ERI import sparse_symmetric_einsum + # batched => flops = reads + #diff_JK = jax.vmap(sparse_symmetric_einsum, in_axes=(0, 0, 0))(state.nonzero_distinct_ERI, state.nonzero_indices, density_matrix) + # first + correction_remaining => floats = reads*batch_size + diff_JK = jax.vmap(sparse_symmetric_einsum, in_axes=(None, None, 0))(state.nonzero_distinct_ERI[0], state.nonzero_indices[0], density_matrix) + diff_JK = diff_JK - jax.vmap(sparse_symmetric_einsum, in_axes=(0, None, 0))(diff_state.diffs_ERI, diff_state.indxs, density_matrix) + + return diff_JK + + +def nanoDFT(mol_str, opts, pyscf_E): + # Init DFT tensors on CPU using PySCF. + # Try to re-use grid amongst all points. + state = init_dft(mol_str, opts) + c, w = state.grid_coords, state.grid_weights + print(mol_str[0][1]) + for _ in range(opts.bs-1): + mol_str[0][1] = (mol_str[0][1][0]+0.05, mol_str[0][1][1], mol_str[0][1][2]) + stateB = init_dft(mol_str, opts, c, w) + state = cat(state, stateB) + N = state.N[0] + + summary(state) + + if opts.normal: diff_state = None + else: + main_grid_AO = state.grid_AO[:1] + diffs_grid_AO = main_grid_AO - state.grid_AO + rows, cols = np.nonzero(np.max(diffs_grid_AO[:, 0]!=0, axis=0)) + sparse_diffs_grid_AO = diffs_grid_AO[:, 0, rows,cols] + + diff_ERIs = state.nonzero_distinct_ERI[:1] - state.nonzero_distinct_ERI + diff_indxs = state.nonzero_indices[0].reshape(1, -1, 4) + nzr = np.abs(diff_ERIs[1]).reshape(-1) != 0 + diff_ERIs = diff_ERIs[:, :, nzr] + diff_indxs = diff_indxs[:, nzr] + + diff_state = DiffState(indxs=diff_indxs, + rows=rows, cols=cols, + main_grid_AO=main_grid_AO, sparse_diffs_grid_AO=sparse_diffs_grid_AO, diffs_grid_AO=diffs_grid_AO, diffs_ERI=diff_ERIs) + summary(diff_state) + + if opts.visualize: + pass + + + w = state.init + vandg = jax.jit(jax.value_and_grad( dm_energy, has_aux=True), backend=opts.backend, static_argnames=("normal", )) + + # Build initializers for params + #adam = optax.adam(lr_schedule) + adam = optax.adabelief(opts.lr) + adam_state = adam.init(w) + + min_val = 0 + min_dm = 0 + + pbar = tqdm(range(opts.steps)) + + (val, _), grad = vandg(w, state, diff_state, opts.normal) + + for i in pbar: + #with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True): + (val, (vals, E_xc, density_matrix)), grad = vandg(w, state, diff_state, opts.normal) + updates, adam_state = adam.update(grad, adam_state) + w = optax.apply_updates(w, updates) + #pbar.set_description("energy=%.7f [eV] error=%.7f [eV] (best_error=%.7f[eV])"%(vals*HARTREE_TO_EV, target-vals[0]*HARTREE_TO_EV, target-min_val*HARTREE_TO_EV)) + + if opts.bs == 1: pbar.set_description("error=%.7f [eV] (%.7f %.7f) "%(np.mean(val*HARTREE_TO_EV-state.pyscf_E), val*HARTREE_TO_EV, state.pyscf_E)) + else: + str = "error=" + "".join(["%.7f "%(vals[i]*HARTREE_TO_EV-state.pyscf_E[i]) for i in range(min(5,opts.bs))]) + " [eV]" + #str += "E_xc=" + "".join(["%.7f "%(E_xc[i]*HARTREE_TO_EV) for i in range(opts.bs)]) + " [eV]" + pbar.set_description(str) + if i == 0: print("") + + if val < min_val: + min_val = val + min_dm = density_matrix + + val, density_matrix = min_val, min_dm + + # needs batching + exit() + V_xc = jax.grad(exchange_correlation)(density_matrix, state.grid_AO, state.grid_weights) + V_xc = (V_xc + V_xc.T)/2 + diff_JK = get_JK(density_matrix, state.ERI) + H = state.H_core + diff_JK + V_xc + mo_energy, mo_coeff = np.linalg.eigh(state.L_inv @ H @ state.L_inv.T) + mo_coeff = state.L_inv.T @ mo_coeff + + return val, (0, mo_energy, mo_coeff, state.grid_coords, state.grid_weights, density_matrix, H) + + +import chex +@chex.dataclass +class IterationState: + init: np.array + E_nuc: np.array + mask: np.array + L_inv: np.array + L_inv_T: np.array + H_core: np.array + grid_AO: np.array + grid_weights: np.array + grid_coords: np.array + pyscf_E: np.array + N: int + ERI: np.array + nonzero_distinct_ERI: np.array + nonzero_indices: np.array + +@chex.dataclass +class DiffState: + diffs_ERI: np.array + main_grid_AO: np.array + diffs_grid_AO: np.array + indxs: np.array + sparse_diffs_grid_AO: np.array#jax.experimental.sparse.csr.CSR + rows: np.array + cols: np.array + + +from pyscf.data.elements import charge as elements_proton +from pyscf.dft import gen_grid, radi + +def treutler_atomic_radii_adjust(mol, atomic_radii): + charges = [elements_proton(x) for x in mol.elements] + rad = np.sqrt(atomic_radii[charges]) + 1e-200 + rr = rad.reshape(-1, 1) * (1. / rad) + a = .25 * (rr.T - rr) + + a[a < -0.5] = -0.5 + a[a > 0.5] = 0.5 + a = jnp.array(a) + + def fadjust(i, j, g): + g1 = g**2 + g1 -= 1. + g1 *= -a[i, j] + g1 += g + return g1 + + return fadjust + + +def inter_distance(coords): + rr = np.linalg.norm(coords.reshape(-1, 1, 3) - coords, axis=2) + rr[np.diag_indices(rr.shape[0])] = 0 + return rr + +def original_becke(g): + g = (3 - g**2) * g * .5 + g = (3 - g**2) * g * .5 + g = (3 - g**2) * g * .5 + return g + +def get_partition( + mol, + atom_coords, + atom_grids_tab, + radii_adjust=treutler_atomic_radii_adjust, + atomic_radii=radi.BRAGG_RADII, + becke_scheme=original_becke, + concat=True +): + atm_dist = inter_distance(atom_coords) # [natom, natom] + + def gen_grid_partition(coords): + ngrids = coords.shape[0] + dc = coords[None] - atom_coords[:, None] + grid_dist = np.sqrt(np.einsum('ijk,ijk->ij', dc, dc)) # [natom, ngrid] + + ix, jx = np.tril_indices(mol.natm, k=-1) + + natm, ngrid = grid_dist.shape + g_ = -1 / atm_dist.reshape(natm, natm, 1) * (grid_dist.reshape(1, natm, ngrid) - grid_dist.reshape(natm, 1, ngrid)) + #g_ = jnp.array(g_) + + def pbecke_g(i, j): + g = g_[i, j] + charges = [elements_proton(x) for x in mol.elements] + rad = np.sqrt(atomic_radii[charges]) + 1e-200 + rr = rad.reshape(-1, 1) * (1. / rad) + a = .25 * (rr.T - rr) + a[a < -0.5] = -0.5 + a[a > 0.5] = 0.5 + g1 = g**2 + g1 -= 1. + g1 *= -a[i, j].reshape(-1, 1) + g1 += g + return g1 + + g = pbecke_g(ix, jx) + g = np.copy(becke_scheme(g)) + gp2 = (1+g)/2 + gm2 = (1-g)/2 + + pbecke = jnp.ones((mol.natm, ngrids)) # [natom, ngrid] + pbecke = pbecke.at[ix].mul(gm2) + pbecke = pbecke.at[jx].mul(gp2) + + return pbecke + + coords_all = [] + weights_all = [] + for ia in range(mol.natm): + coords, vol = atom_grids_tab[mol.atom_symbol(ia)] + coords = coords + atom_coords[ia] # [ngrid, 3] + pbecke = gen_grid_partition(coords) # [natom, ngrid] + weights = vol * pbecke[ia] / jnp.sum(pbecke, axis=0) + coords_all.append(coords) + weights_all.append(weights) + + if concat: + coords_all = jnp.vstack(coords_all) + weights_all = jnp.hstack(weights_all) + return coords_all, weights_all + + +class DifferentiableGrids(gen_grid.Grids): + """Differentiable alternative to the original pyscf.gen_grid.Grids.""" + + def build(self, atom_coords) : + mol = self.mol + + atom_grids_tab = self.gen_atomic_grids( + mol, self.atom_grid, self.radi_method, self.level, self.prune + ) + + coords, weights = get_partition( + mol, + atom_coords, + atom_grids_tab, + treutler_atomic_radii_adjust, + self.atomic_radii, + original_becke, + ) + self.coords = coords + self.weights = weights + return coords, weights + + +def grids_from_pyscf_mol( + mol: pyscf.gto.mole.Mole, quad_level: int = 1 +) : + g = gen_grid.Grids(mol) + g.level = quad_level + g.build() + grids = jnp.array(g.coords) + weights = jnp.array(g.weights) + return grids, weights + + +def init_dft(mol_str, opts, _coords=None, _weights=None): + mol = build_mol(mol_str, opts.basis) + pyscf_E, pyscf_hlgap, pycsf_forces = reference(mol_str, opts) + + N = mol.nao_nr() # N=66 for C6H6 (number of atomic **and** molecular orbitals) + n_electrons_half = mol.nelectron//2 # 21 for C6H6 + E_nuc = mol.energy_nuc() # float = 202.4065 [Hartree] for C6H6. TODO(): Port to jax. + + from pyscf import dft + #grids = pyscf.dft.gen_grid.Grids(mol) + grids = DifferentiableGrids(mol) + grids.level = opts.level + #grids.build() + grids.build(np.concatenate([np.array(a[1]).reshape(1, 3) for a in mol._atom])) + + grid_weights = grids.weights # (grid_size,) = (45624,) for C6H6 + grid_coords = grids.coords + coord_str = 'GTOval_cart_deriv1' if mol.cart else 'GTOval_sph_deriv1' + grid_AO = mol.eval_gto(coord_str, grids.coords, 4) # (4, grid_size, N) = (4, 45624, 9) for C6H6. + + # TODO(): Add integral math formulas for kinetic/nuclear/O/ERI. + kinetic = mol.intor_symmetric('int1e_kin') # (N,N) + nuclear = mol.intor_symmetric('int1e_nuc') # (N,N) + O = mol.intor_symmetric('int1e_ovlp') # (N,N) + L = np.linalg.cholesky(O) + L_inv = np.linalg.inv(L) # (N,N) + + + init = np.eye(N)[:, :n_electrons_half] + #I_nxk = init[:, :n_electrons_half] + + mask = np.concatenate([np.ones(n_electrons_half), np.zeros(N-n_electrons_half)]) + if opts.normal: + ERI = mol.intor("int2e_sph") + nonzero_distinct_ERI = np.zeros(1) + nonzero_indices = np.zeros(1) + else: + from pyscf_ipu.nanoDFT.sparse_symmetric_ERI import get_i_j, num_repetitions_fast + eri_threshold = 0 + batches = 1 + nipu = 1 + distinct_ERI = mol.intor("int2e_sph", aosym="s8") + #below_thr = np.abs(distinct_ERI) <= eri_threshold + #distinct_ERI[below_thr] = 0.0 + #ic(distinct_ERI.size, np.sum(below_thr), np.sum(below_thr)/distinct_ERI.size) + #nonzero_indices = np.nonzero(distinct_ERI)[0].astype(np.uint64) + nonzero_indices = np.arange(distinct_ERI.size)# ]np.nonzero(distinct_ERI)[0].astype(np.uint64) + nonzero_distinct_ERI = distinct_ERI[nonzero_indices]#.astype(np.float32) + + ij, kl = get_i_j(nonzero_indices) + rep = num_repetitions_fast(ij, kl) + nonzero_distinct_ERI = nonzero_distinct_ERI / rep + remainder = nonzero_indices.shape[0] % (nipu*batches) + + if remainder != 0: + ij = np.pad(ij, ((0,nipu*batches-remainder))) + kl = np.pad(kl, ((0,nipu*batches-remainder))) + nonzero_distinct_ERI = np.pad(nonzero_distinct_ERI, (0,nipu*batches-remainder)) + + ij = ij.reshape(batches, -1) + kl = kl.reshape(batches, -1) + nonzero_distinct_ERI = nonzero_distinct_ERI.reshape(batches, -1) + + i, j = get_i_j(ij.reshape(-1)) + k, l = get_i_j(kl.reshape(-1)) + nonzero_indices = np.vstack([i,j,k,l]).T.reshape(batches, -1, 4) + + #ERI = [nonzero_distinct_ERI, nonzero_indices] + #ERI = ERI + ERI = np.zeros(1) + #ERI = mol.intor("int2e_sph") + + def e(x): return np.expand_dims(x, axis=0) + + + state = IterationState(init = e(init), + E_nuc=e(E_nuc), + ERI=e(ERI), + nonzero_distinct_ERI=e(nonzero_distinct_ERI), + nonzero_indices=e(nonzero_indices), + mask=e(mask), + H_core=e(nuclear+kinetic), + L_inv=e(L_inv), + L_inv_T = e(L_inv.T), + grid_AO=e(grid_AO), + grid_weights=e(grid_weights), + grid_coords=e(grid_coords), + pyscf_E=e(pyscf_E[-1:]), + N=e(mol.nao_nr()), + ) + + + return state + + +def summary(state): + if state is None: return + print("_"*100) + for field_name, field_def in state.__dataclass_fields__.items(): + field_value = getattr(state, field_name) + try: + print("%20s %20s %20s"%(field_name,getattr(field_value, 'shape', None), getattr(field_value, "nbytes", None)/10**9)) + except: + print("BROKE FOR ", field_name) + try: + print(state.pyscf_E[:, -1]) + except: + pass + print("_"*100) + +def cat(dc1, dc2, axis=0): + # Use dictionary comprehension to iterate over the dataclass fields + concatenated_fields = { + field: jnp.concatenate([getattr(dc1, field), getattr(dc2, field)], axis=axis) + for field in dc1.__annotations__ + } + # Create a new dataclass instance with the concatenated fields + return IterationState(**concatenated_fields) + + +def grad_elec(weight, grid_AO, eri, s1, h1aos, natm, aoslices, mask, mo_energy, mo_coeff, mol, dm, H): + # Electronic part of RHF/RKS gradients + dm0 = 2 * (mo_coeff*mask) @ mo_coeff.T # (N, N) = (66, 66) for C6H6. + dme0 = 2 * (mo_coeff * mask*mo_energy) @ mo_coeff.T # (N, N) = (66, 66) for C6H6. + + # Code identical to exchange correlation. + rho = jnp.sum( grid_AO[:1] @ dm0 * grid_AO, axis=2) # (10, grid_size) = (10, 45624) for C6H6. + _, vrho, vgamma = vxc_b3lyp(rho, EPSILON_B3LYP) # (grid_size,) (grid_size,) + V_xc = jnp.concatenate([vrho.reshape(1, -1)/2, 4*vgamma.reshape(1, -1)*rho[1:4]], axis=0) # (4, grid_size) + + vmat = grid_AO[1:4].transpose(0, 2, 1) @ jnp.sum(grid_AO[:4] * jnp.expand_dims(weight * V_xc, axis=2), axis=0) # (3, N, N) + aos = jnp.concatenate([jnp.expand_dims(grid_AO[np.array([1,4,5,6])], 0), jnp.expand_dims(grid_AO[np.array([2,5,7,8])], 0), jnp.expand_dims(grid_AO[np.array([3,6,8,9])], 0)], axis=0) # (3, N, N) + V_xc = - vmat - jnp.transpose(jnp.einsum("snpi,np->spi", aos, weight*V_xc), axes=(0,2,1)) @ grid_AO[0] # (3, 4, grid_size, N) + + vj = - jnp.einsum('sijkl,lk->sij', eri, dm0) # (3, N, N) + vk = - jnp.einsum('sijkl,jk->sil', eri, dm0) # (3, N, N) + vhf = V_xc + vj - vk * .5 * HYB_B3LYP # (3, N, N) + + de = jnp.einsum('lxij,ij->lx', h1aos, dm0) # (natm, 3) + for k, ia in enumerate(range(natm)): + p0, p1 = aoslices[ia][2], aoslices[ia][3] + de = de.at[k].add(jnp.einsum('xij,ij->x', vhf[:, p0:p1], dm0[p0:p1]) * 2) + de = de.at[k].add(-jnp.einsum('xij,ij->x', s1[:, p0:p1], dme0[p0:p1]) * 2) + return de + +def grad_nuc(charges, coords): + # Derivatives of nuclear repulsion energy wrt nuclear coordinates + natm = charges.shape[0] + pairwise_charges = charges.reshape(natm, 1) * charges.reshape(1, natm) # (natm, natm) + pairwise_difference = coords.reshape(1, natm, 3) - coords.reshape(natm, 1, 3) # (natm, natm, 3) + pairwise_distances = jnp.linalg.norm(pairwise_difference, axis=2) ** 3 # (natm, natm) + pairwise_distances = jnp.where(pairwise_distances == 0, jnp.inf, pairwise_distances) # (natm, natm) + all = - pairwise_charges.reshape(natm, natm, 1) * pairwise_difference # (natm, natm, 3) + all = all / pairwise_distances.reshape(natm, natm, 1) # (natm, natm, 3) + all = all.at[jnp.diag_indices(natm)].set(0) # (natm, natm, 3) + return jnp.sum(all, axis=0) # (natm, natm) + +def grad(mol, coords, weight, mo_coeff, mo_energy, dm, H): + # Initialize DFT tensors on CPU using PySCF. + ao = pyscf.dft.numint.NumInt().eval_ao(mol, coords, deriv=2) + eri = mol.intor("int2e_ip1") + s1 = - mol.intor('int1e_ipovlp', comp=3) + kin = - mol.intor('int1e_ipkin', comp=3) + nuc = - mol.intor('int1e_ipnuc', comp=3) + + mask = np.ones(mol.nao_nr()) + mask[mol.nelectron//2:] = 0 + + aoslices = mol.aoslice_by_atom() + h1 = kin + nuc + def hcore_deriv(atm_id, aoslices, h1): # <\nabla|1/r|> + _, _, p0, p1 = aoslices[atm_id] + with mol.with_rinv_at_nucleus(atm_id): + vrinv = mol.intor('int1e_iprinv', comp=3) # + vrinv *= -mol.atom_charge(atm_id) + vrinv[:,p0:p1] += h1[:,p0:p1] + return vrinv + vrinv.transpose(0,2,1) + N = h1.shape[1] # (3, N , N) + h1aos = np.zeros((mol.natm, 3, N, N)) + for k, ia in enumerate(range(mol.natm)): + p0, p1 = aoslices[ia,2:] + h1aos[k] = hcore_deriv(ia, aoslices, h1) + + charges = np.zeros((mol.natm)) + coords = np.zeros((mol.natm,3)) + for j in range(mol.natm): + charges[j] = mol.atom_charge(j) + coords[j]= mol.atom_coord(j) + + #_grad_elec = jax.jit(grad_elec, static_argnames=["aoslices", "natm"], backend="cpu") + _grad_elec = grad_elec + _grad_nuc = jax.jit(grad_nuc, backend="cpu") + + return _grad_elec(weight, ao, eri, s1, h1aos, mol.natm, tuple([tuple(a) for a in aoslices.tolist()]), mask, mo_energy, mo_coeff, mol, dm, H) + _grad_nuc(charges, coords) + +def pyscf_reference(mol_str, opts): + from pyscf import __config__ + __config__.dft_rks_RKS_grids_level = opts.level + + mol = build_mol(mol_str, opts.basis) + mol.max_cycle = 50 + mf = pyscf.scf.RKS(mol) + mf.max_cycle = 50 + mf.xc = "b3lyp" + mf.diis_space = 8 + pyscf_energies = [] + pyscf_hlgaps = [] + lumo = mol.nelectron//2 + homo = lumo - 1 + def callback(envs): + pyscf_energies.append(envs["e_tot"]*HARTREE_TO_EV) + hl_gap_hartree = np.abs(envs["mo_energy"][homo] - envs["mo_energy"][lumo]) * HARTREE_TO_EV + pyscf_hlgaps.append(hl_gap_hartree) + print("\rPYSCF: ", pyscf_energies[-1] , end="") + mf.callback = callback + mf.kernel() + print("") + forces = mf.nuc_grad_method().kernel() + return np.array(pyscf_energies), np.array(pyscf_hlgaps), np.array(forces) + +def print_difference(nanoDFT_E, nanoDFT_forces, nanoDFT_logged_E, nanoDFT_hlgap, pyscf_E, pyscf_forces, pyscf_hlgap): + #TODO(HH): rename to match caller variable names + nanoDFT_E = nanoDFT_E*HARTREE_TO_EV + print("pyscf:\t\t%15f"%pyscf_E[-1]) + print("us:\t\t%15f"%nanoDFT_E) + print("diff:\t\t%15f"%np.abs(pyscf_E[-1]-nanoDFT_E)) + print("chemAcc: \t%15f"%0.043) + print("chemAcc/diff: \t%15f"%(0.043/np.abs(pyscf_E[-1]-nanoDFT_E))) + print("") + + # Forces + print() + print("np.max(|nanoDFT_F-PySCF_F|):", np.max(np.abs(nanoDFT_forces-pyscf_forces))) + + norm_X = np.linalg.norm(nanoDFT_forces, axis=1) + norm_Y = np.linalg.norm(pyscf_forces, axis=1) + dot_products = np.sum(nanoDFT_forces * pyscf_forces, axis=1) + cosine_similarity = dot_products / (norm_X * norm_Y) + print("Force cosine similarity:",cosine_similarity) + +def build_mol(mol_str, basis_name): + mol = pyscf.gto.mole.Mole() + mol.build(atom=mol_str, unit="Angstrom", basis=basis_name, spin=0, verbose=0) + return mol + +def reference(mol_str, opts): + import pickle + import hashlib + filename = "precomputed/%s.pkl"%hashlib.sha256((str(mol_str) + str(opts.basis) + str(opts.level)).encode('utf-8')).hexdigest() + print(filename) + if not os.path.exists(filename): + pyscf_E, pyscf_hlgap, pyscf_forces = pyscf_reference(mol_str, opts) + with open(filename, "wb") as file: + pickle.dump([pyscf_E, pyscf_hlgap, pyscf_forces], file) + else: + pyscf_E, pyscf_hlgap, pyscf_forces = pickle.load(open(filename, "rb")) + return pyscf_E, pyscf_hlgap, pyscf_forces + + +if __name__ == "__main__": + #jax.config.FLAGS.jax_platform_name = 'cpu' + import os + import argparse + + parser = argparse.ArgumentParser() + # DFT options + parser.add_argument('-basis', type=str, default="sto3g") + parser.add_argument('-level', type=int, default=0) + # GD options + parser.add_argument('-backend', type=str, default="cpu") + parser.add_argument('-lr', type=float, default=1e-3) + parser.add_argument('-steps', type=int, default=200) + parser.add_argument('-bs', type=int, default=2) + + parser.add_argument('-normal', action="store_true") + parser.add_argument('-visualize', action="store_true") + opts = parser.parse_args() + + # benzene + if True: + mol_str = [ + ["C", ( 0.0000, 0.0000, 0.0000)], + ["C", ( 1.4000, 0.0000, 0.0000)], + ["C", ( 2.1000, 1.2124, 0.0000)], + ["C", ( 1.4000, 2.4249, 0.0000)], + ["C", ( 0.0000, 2.4249, 0.0000)], + ["C", (-0.7000, 1.2124, 0.0000)], + ["H", (-0.5500, -0.9526, 0.0000)], + ["H", (-0.5500, 3.3775, 0.0000)], + ["H", ( 1.9500, -0.9526, 0.0000)], + ["H", (-1.8000, 1.2124, 0.0000)], + ["H", ( 3.2000, 1.2124, 0.0000)], + ["H", ( 1.9500, 3.3775, 0.0000)] + ] + else: + mol_str = [ + ["N", (-1.3289 , 1.0488 , -1.5596)], + ["C", ( 0.1286 , 1.0198 , -1.8261)], + ["C", ( 0.3335 , 0.8585 , -3.3268)], + ["O", (-0.0551 , -0.0282 , -4.0649)], + ["O", ( 1.0668 , 1.8338 , -3.9108)], + ["C", ( 0.8906 , -0.1043 , -1.0999)], + ["H", ( 1.9534 , -0.0888 , -1.4126)], + ["H", ( 0.4975 , -1.0987 , -1.3971)], + ["C", ( 0.8078 , 0.0465 , 0.3677)], + ["C", ( 1.5802 , 0.8809 , 1.1516)], + ["N", ( 1.1567 , 0.7746 , 2.4944)], + ["H", ( 1.7094 , 1.0499 , 3.2650)], + ["C", ( 0.1694 , -0.2350 , 2.5662)], + ["C", (-0.0897 , -0.6721 , 1.2403)], + ["C", (-1.0740 , -1.6418 , 1.0106)], + ["H", (-1.2812 , -1.9849 , -0.0088)], + ["C", (-1.7623 , -2.1470 , 2.0948)], + ["H", (-2.5346 , -2.9080 , 1.9416)], + ["C", (-1.4948 , -1.7069 , 3.4060)], + ["H", (-2.0660 , -2.1385 , 4.2348)], + ["C", (-0.5337 , -0.7507 , 3.6638)], + ["H", (-0.3249 , -0.4086 , 4.6819)], + ["H", ( 2.3719 , 1.5631 , 0.8380)], + ["H", (-1.4726 , 1.2086 , -0.5841)], + ["H", (-1.7404 , 0.1740 , -1.8129)], + ["H", ( 0.5299 , 2.0096 , -1.4901)], + ["H", ( 1.1361 , 1.6737 , -4.8470)], + ] + + + #pos = [np.array(a[1]).reshape(1, 1) for a in mol_str] + #distances = map(lambda x: np.linalg.norm(np.array(x[0]) - np.array(x[1])), combinations(coords, 2)) + #return min(distances) + + + + + mol = build_mol(mol_str, opts.basis) + ic(mol.nao_nr()) + ic(mol.nelectron) + + pyscf_E, pyscf_hlgap, pyscf_forces = reference(mol_str, opts) + + nanoDFT_E, (nanoDFT_hlgap, mo_energy, mo_coeff, grid_coords, grid_weights, dm, H) = nanoDFT(mol_str, opts, pyscf_E) + nanoDFT_forces = grad(mol, grid_coords, grid_weights, mo_coeff, mo_energy, np.array(dm), np.array(H)) + + print_difference(nanoDFT_E, nanoDFT_forces, 0 , nanoDFT_hlgap, pyscf_E, pyscf_forces, pyscf_hlgap) diff --git a/pyscf_ipu/nanoDFT/gd.py b/pyscf_ipu/nanoDFT/gd.py new file mode 100644 index 00000000..e906af72 --- /dev/null +++ b/pyscf_ipu/nanoDFT/gd.py @@ -0,0 +1,315 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. +# (assumes newest Jax) +import jax +jax.config.update('jax_enable_x64', True) +import jax.numpy as jnp +import numpy as np +import pyscf +import optax +from icecream import ic +from pyscf_ipu.exchange_correlation.b3lyp import b3lyp, vxc_b3lyp +from tqdm import tqdm + +HARTREE_TO_EV, EPSILON_B3LYP, HYB_B3LYP = 27.2114079527, 1e-20, 0.2 +def orth(x): return jnp.linalg.qr(x)[0] + +def dm_energy(weights: NxK, state): + eigvects: NxK = state.L_inv.T @ orth(weights) + density_matrix: NxN = 2 * eigvects @ eigvects.T + E_xc: float = exchange_correlation(density_matrix, state.grid_AO, state.grid_weights) + diff_JK: NxN = get_JK(density_matrix, state.ERI) + energy: float = jnp.sum(density_matrix * (state.H_core + diff_JK/2)) + E_xc + state.E_nuc + return energy, density_matrix + +def exchange_correlation(density_matrix: NxN, grid_AO: _4xGsizexN, grid_weights: gsize): + grid_AO_dm: _1xGsizexN = jnp.expand_dims(grid_AO[0] @ density_matrix) # O(gsize N^2) flops and gsizeN reads. + mult: _4xGsizexN = grid_AO_dm * grid_AO + rho: _4xGsize = jnp.sum(mult, axis=2) + E_xc: Gsize = b3lyp(rho, EPSILON_B3LYP) + E_xc: float = jnp.sum(rho[0] * grid_weights * E_xc) + return E_xc + +def get_JK(density_matrix: NxN, ERI: NxNxNxN): + J: (N, N) = jnp.einsum('ijkl,ji->kl', ERI, density_matrix) + K: (N, N) = jnp.einsum('ijkl,jk->il', ERI, density_matrix) + return J - (K / 2 * HYB_B3LYP) + +def nanoDFT(mol_str, opts, pyscf_E): + # Init DFT tensors on CPU using PySCF. + mol = build_mol(mol_str, opts.basis) + pyscf_E, pyscf_hlgap, pycsf_forces = reference(mol_str, opts) + + N = mol.nao_nr() + state = init_dft(mol, opts)[0] + target = pyscf_E[-1] + + w = np.eye(N) + np.random.normal(0, 0.01, (N, N)) + + vandg = jax.jit(jax.value_and_grad( dm_energy, has_aux=True), backend=opts.backend) + + # Build initializers for params + adam = optax.adam(opts.lr) + adam_state = adam.init(w) + + pbar = tqdm(range(opts.steps)) + for i in pbar: + (val, density_matrix), grad = vandg(w, state) + updates, adam_state = adam.update(grad, adam_state) + w = optax.apply_updates(w, updates) + pbar.set_description("energy=%.7f [eV] error=%.7f [eV]"%(val*HARTREE_TO_EV, target-val*HARTREE_TO_EV)) + if i == 0: print("") + + V_xc = jax.grad(exchange_correlation)(density_matrix, state.grid_AO, state.grid_weights) + V_xc = (V_xc + V_xc.T)/2 + diff_JK = get_JK(density_matrix, state.ERI) + H = state.H_core + diff_JK + V_xc + mo_energy, mo_coeff = np.linalg.eigh(state.L_inv @ H @ state.L_inv.T) + mo_coeff = state.L_inv.T @ mo_coeff + + return val, (0, mo_energy, mo_coeff, state.grid_coords, state.grid_weights, density_matrix, H) + + +import chex +@chex.dataclass +class IterationState: + E_nuc: np.array + density_matrix: np.array + kinetic: np.array + nuclear: np.array + O: np.array + mask: np.array + L_inv: np.array + L: np.array + H_core: np.array + grid_AO: np.array + grid_weights: np.array + atom_pos: np.array + ERI: np.array + grid_coords: np.array + +def init_dft(mol, opts): + N = mol.nao_nr() + n_electrons_half = mol.nelectron//2 + E_nuc = mol.energy_nuc() + + from pyscf import dft + grids = pyscf.dft.gen_grid.Grids(mol) + grids.level = opts.level + grids.build() + grid_weights = grids.weights + coord_str = 'GTOval_cart_deriv1' if mol.cart else 'GTOval_sph_deriv1' + grid_AO = mol.eval_gto(coord_str, grids.coords, 4) + grid_coords = grids.coords + density_matrix = pyscf.scf.hf.init_guess_by_minao(mol) + + # TODO(): Add integral math formulas for kinetic/nuclear/O/ERI. + kinetic = mol.intor_symmetric('int1e_kin') + nuclear = mol.intor_symmetric('int1e_nuc') + O = mol.intor_symmetric('int1e_ovlp') + L = np.linalg.cholesky(O) + L_inv = np.linalg.inv(L) + + mask = np.concatenate([np.ones(n_electrons_half), np.zeros(N-n_electrons_half)]) + + ERI = mol.intor("int2e_sph") + + state = IterationState(E_nuc=E_nuc, ERI=ERI, grid_coords=grid_coords, + density_matrix=density_matrix, kinetic=kinetic, + nuclear=nuclear, + O=O, + mask=mask, + H_core=nuclear+kinetic, + L_inv=L_inv, L=L, grid_AO=grid_AO, grid_weights=grid_weights, atom_pos=mol.atom_coords()) + + print("DFT Tensor Summary") + for field_name, field_def in state.__dataclass_fields__.items(): + field_value = getattr(state, field_name) + print(f"{field_name}: {getattr(field_value, 'shape', None)}") + + return state, n_electrons_half, E_nuc, N, L_inv, grid_weights, grid_coords, grid_AO + + +def grad_elec(weight, grid_AO, eri, s1, h1aos, natm, aoslices, mask, mo_energy, mo_coeff, mol, dm, H): + # Electronic part of RHF/RKS gradients + dm0 = 2 * (mo_coeff*mask) @ mo_coeff.T # (N, N) = (66, 66) for C6H6. + dme0 = 2 * (mo_coeff * mask*mo_energy) @ mo_coeff.T # (N, N) = (66, 66) for C6H6. + + # Code identical to exchange correlation. + rho = jnp.sum( grid_AO[:1] @ dm0 * grid_AO, axis=2) # (10, grid_size) = (10, 45624) for C6H6. + _, vrho, vgamma = vxc_b3lyp(rho, EPSILON_B3LYP) # (grid_size,) (grid_size,) + V_xc = jnp.concatenate([vrho.reshape(1, -1)/2, 4*vgamma.reshape(1, -1)*rho[1:4]], axis=0) # (4, grid_size) + + vmat = grid_AO[1:4].transpose(0, 2, 1) @ jnp.sum(grid_AO[:4] * jnp.expand_dims(weight * V_xc, axis=2), axis=0) # (3, N, N) + aos = jnp.concatenate([jnp.expand_dims(grid_AO[np.array([1,4,5,6])], 0), jnp.expand_dims(grid_AO[np.array([2,5,7,8])], 0), jnp.expand_dims(grid_AO[np.array([3,6,8,9])], 0)], axis=0) # (3, N, N) + V_xc = - vmat - jnp.transpose(jnp.einsum("snpi,np->spi", aos, weight*V_xc), axes=(0,2,1)) @ grid_AO[0] # (3, 4, grid_size, N) + + vj = - jnp.einsum('sijkl,lk->sij', eri, dm0) # (3, N, N) + vk = - jnp.einsum('sijkl,jk->sil', eri, dm0) # (3, N, N) + vhf = V_xc + vj - vk * .5 * HYB_B3LYP # (3, N, N) + + de = jnp.einsum('lxij,ij->lx', h1aos, dm0) # (natm, 3) + for k, ia in enumerate(range(natm)): + p0, p1 = aoslices[ia][2], aoslices[ia][3] + de = de.at[k].add(jnp.einsum('xij,ij->x', vhf[:, p0:p1], dm0[p0:p1]) * 2) + de = de.at[k].add(-jnp.einsum('xij,ij->x', s1[:, p0:p1], dme0[p0:p1]) * 2) + return de + +def grad_nuc(charges, coords): + # Derivatives of nuclear repulsion energy wrt nuclear coordinates + natm = charges.shape[0] + pairwise_charges = charges.reshape(natm, 1) * charges.reshape(1, natm) # (natm, natm) + pairwise_difference = coords.reshape(1, natm, 3) - coords.reshape(natm, 1, 3) # (natm, natm, 3) + pairwise_distances = jnp.linalg.norm(pairwise_difference, axis=2) ** 3 # (natm, natm) + pairwise_distances = jnp.where(pairwise_distances == 0, jnp.inf, pairwise_distances) # (natm, natm) + all = - pairwise_charges.reshape(natm, natm, 1) * pairwise_difference # (natm, natm, 3) + all = all / pairwise_distances.reshape(natm, natm, 1) # (natm, natm, 3) + all = all.at[jnp.diag_indices(natm)].set(0) # (natm, natm, 3) + return jnp.sum(all, axis=0) # (natm, natm) + +def grad(mol, coords, weight, mo_coeff, mo_energy, dm, H): + # Initialize DFT tensors on CPU using PySCF. + ao = pyscf.dft.numint.NumInt().eval_ao(mol, coords, deriv=2) + eri = mol.intor("int2e_ip1") + s1 = - mol.intor('int1e_ipovlp', comp=3) + kin = - mol.intor('int1e_ipkin', comp=3) + nuc = - mol.intor('int1e_ipnuc', comp=3) + + mask = np.ones(mol.nao_nr()) + mask[mol.nelectron//2:] = 0 + + aoslices = mol.aoslice_by_atom() + h1 = kin + nuc + def hcore_deriv(atm_id, aoslices, h1): # <\nabla|1/r|> + _, _, p0, p1 = aoslices[atm_id] + with mol.with_rinv_at_nucleus(atm_id): + vrinv = mol.intor('int1e_iprinv', comp=3) # + vrinv *= -mol.atom_charge(atm_id) + vrinv[:,p0:p1] += h1[:,p0:p1] + return vrinv + vrinv.transpose(0,2,1) + N = h1.shape[1] # (3, N , N) + h1aos = np.zeros((mol.natm, 3, N, N)) + for k, ia in enumerate(range(mol.natm)): + p0, p1 = aoslices[ia,2:] + h1aos[k] = hcore_deriv(ia, aoslices, h1) + + charges = np.zeros((mol.natm)) + coords = np.zeros((mol.natm,3)) + for j in range(mol.natm): + charges[j] = mol.atom_charge(j) + coords[j]= mol.atom_coord(j) + + #_grad_elec = jax.jit(grad_elec, static_argnames=["aoslices", "natm"], backend="cpu") + _grad_elec = grad_elec + _grad_nuc = jax.jit(grad_nuc, backend="cpu") + + return _grad_elec(weight, ao, eri, s1, h1aos, mol.natm, tuple([tuple(a) for a in aoslices.tolist()]), mask, mo_energy, mo_coeff, mol, dm, H) + _grad_nuc(charges, coords) + +def pyscf_reference(mol_str, opts): + from pyscf import __config__ + __config__.dft_rks_RKS_grids_level = opts.level + + mol = build_mol(mol_str, opts.basis) + mol.max_cycle = 50 + mf = pyscf.scf.RKS(mol) + mf.max_cycle = 50 + mf.xc = "b3lyp" + mf.diis_space = 8 + pyscf_energies = [] + pyscf_hlgaps = [] + lumo = mol.nelectron//2 + homo = lumo - 1 + def callback(envs): + pyscf_energies.append(envs["e_tot"]*HARTREE_TO_EV) + hl_gap_hartree = np.abs(envs["mo_energy"][homo] - envs["mo_energy"][lumo]) * HARTREE_TO_EV + pyscf_hlgaps.append(hl_gap_hartree) + print("\rPYSCF: ", pyscf_energies[-1] , end="") + mf.callback = callback + mf.kernel() + print("") + forces = mf.nuc_grad_method().kernel() + return np.array(pyscf_energies), np.array(pyscf_hlgaps), np.array(forces) + +def print_difference(nanoDFT_E, nanoDFT_forces, nanoDFT_logged_E, nanoDFT_hlgap, pyscf_E, pyscf_forces, pyscf_hlgap): + #TODO(HH): rename to match caller variable names + nanoDFT_E = nanoDFT_E*HARTREE_TO_EV + print("pyscf:\t\t%15f"%pyscf_E[-1]) + print("us:\t\t%15f"%nanoDFT_E) + print("diff:\t\t%15f"%np.abs(pyscf_E[-1]-nanoDFT_E)) + print("chemAcc: \t%15f"%0.043) + print("chemAcc/diff: \t%15f"%(0.043/np.abs(pyscf_E[-1]-nanoDFT_E))) + print("") + + # Forces + print() + print("np.max(|nanoDFT_F-PySCF_F|):", np.max(np.abs(nanoDFT_forces-pyscf_forces))) + + norm_X = np.linalg.norm(nanoDFT_forces, axis=1) + norm_Y = np.linalg.norm(pyscf_forces, axis=1) + dot_products = np.sum(nanoDFT_forces * pyscf_forces, axis=1) + cosine_similarity = dot_products / (norm_X * norm_Y) + print("Force cosine similarity:",cosine_similarity) + +def build_mol(mol_str, basis_name): + mol = pyscf.gto.mole.Mole() + mol.build(atom=mol_str, unit="Angstrom", basis=basis_name, spin=0, verbose=0) + return mol + +def reference(mol_str, opts): + import pickle + import hashlib + import os + os.makedirs("precomputed", exist_ok=True) + filename = "precomputed/%s.pkl"%hashlib.sha256((str(mol_str) + str(opts.basis) + str(opts.level)).encode('utf-8')).hexdigest() + print(filename) + if not os.path.exists(filename): + pyscf_E, pyscf_hlgap, pyscf_forces = pyscf_reference(mol_str, opts) + with open(filename, "wb") as file: + pickle.dump([pyscf_E, pyscf_hlgap, pyscf_forces], file) + else: + pyscf_E, pyscf_hlgap, pyscf_forces = pickle.load(open(filename, "rb")) + return pyscf_E, pyscf_hlgap, pyscf_forces + + +if __name__ == "__main__": + #jax.config.FLAGS.jax_platform_name = 'cpu' + import os + import argparse + + parser = argparse.ArgumentParser() + # DFT options + parser.add_argument('-basis', type=str, default="sto3g") + parser.add_argument('-level', type=int, default=0) + # GD options + parser.add_argument('-backend', type=str, default="cpu") + parser.add_argument('-lr', type=float, default=1e-3) + parser.add_argument('-steps', type=int, default=200) + opts = parser.parse_args() + + # benzene + mol_str = [ + ["C", ( 0.0000, 0.0000, 0.0000)], + ["C", ( 1.4000, 0.0000, 0.0000)], + ["C", ( 2.1000, 1.2124, 0.0000)], + ["C", ( 1.4000, 2.4249, 0.0000)], + ["C", ( 0.0000, 2.4249, 0.0000)], + ["C", (-0.7000, 1.2124, 0.0000)], + ["H", (-0.5500, -0.9526, 0.0000)], + ["H", (-0.5500, 3.3775, 0.0000)], + ["H", ( 1.9500, -0.9526, 0.0000)], + ["H", (-1.8000, 1.2124, 0.0000)], + ["H", ( 3.2000, 1.2124, 0.0000)], + ["H", ( 1.9500, 3.3775, 0.0000)] + ] + + + mol = build_mol(mol_str, opts.basis) + + ic(mol.nao_nr()) + ic(mol.nelectron) + + pyscf_E, pyscf_hlgap, pyscf_forces = reference(mol_str, opts) + + nanoDFT_E, (nanoDFT_hlgap, mo_energy, mo_coeff, grid_coords, grid_weights, dm, H) = nanoDFT(mol_str, opts, pyscf_E) + nanoDFT_forces = grad(mol, grid_coords, grid_weights, mo_coeff, mo_energy, np.array(dm), np.array(H)) + + print_difference(nanoDFT_E, nanoDFT_forces, 0 , nanoDFT_hlgap, pyscf_E, pyscf_forces, pyscf_hlgap)