diff --git a/pyscf_ipu/nanoDFT/nanoDFT.py b/pyscf_ipu/nanoDFT/nanoDFT.py index b50e3ad..f6678e3 100644 --- a/pyscf_ipu/nanoDFT/nanoDFT.py +++ b/pyscf_ipu/nanoDFT/nanoDFT.py @@ -601,7 +601,10 @@ def nanoDFT_options( if mol_str is None: exit(1) - print(f"Minimum interatomic distance: {utils.min_interatomic_distance(mol_str)}") # TODO: dies for --mol_str methane + m_i_d = utils.min_interatomic_distance(mol_str) + print(f"Minimum interatomic distance: {m_i_d}") # TODO: dies for --mol_str methane + if m_i_d < 0.7 or m_i_d > 1.8: + print("WARNING: the coordinates may be expressed in units other than angstrom") args = locals() mol_str = args["mol_str"] diff --git a/pyscf_ipu/nanoDFT/utils.py b/pyscf_ipu/nanoDFT/utils.py index f9e1197..1d89f50 100644 --- a/pyscf_ipu/nanoDFT/utils.py +++ b/pyscf_ipu/nanoDFT/utils.py @@ -3,6 +3,7 @@ import sys import h5py import pubchempy +import pyscf import numpy as np from itertools import combinations from operator import itemgetter @@ -63,7 +64,7 @@ def get_mol_str_pubchem(entry: str): if len(compound) == 0: compound = pubchempy.get_compounds(entry, 'cid', record_type='2d') else: # if not, we assume it is a name - print(f"Searching in PubChem for compound with name '{entry}'") + print(f"Searching in PubChem by name") compound = pubchempy.get_compounds(entry, 'name', record_type='3d') if len(compound) == 0: compound = pubchempy.get_compounds(entry, 'name', record_type='2d') @@ -71,7 +72,7 @@ def get_mol_str_pubchem(entry: str): if len(compound) > 1: print("INFO: PubChem returned more than one compound; using the first...", file=sys.stderr) elif len(compound) == 0: - print(f"No compound found with the name '{entry}' in PubChem") + print(f"No compound found in PubChem by CID or by name") return None print(f"Found compound: {compound[0].synonyms[0]}") for a in compound[0].atoms: @@ -138,9 +139,11 @@ def process_mol_str(mol_str: str): elif mol_str in spice_amino_acids: mol_str = get_mol_str_spice_aa(mol_str) else: - mol_str = get_mol_str_pubchem(mol_str) - - return mol_str + mol_str_pubchem = get_mol_str_pubchem(mol_str) + if mol_str_pubchem is not None: + return mol_str_pubchem + + return pyscf.gto.mole.Mole().format_atom(mol_str) def min_interatomic_distance(mol_str):