Skip to content

Commit d6c74f7

Browse files
committed
Support promolecules with single atom
1 parent abf1008 commit d6c74f7

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

atomdb/promolecule.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
import numpy as np
2525
from scipy.optimize import linprog
2626

27-
from atomdb.utils import DEFAULT_DATAPATH, DEFAULT_REMOTE, DEFAULT_DATASET, MULTIPLICITIES
2827
from atomdb.periodic import element_number, element_symbol
2928
from atomdb.species import load
29+
from atomdb.utils import DEFAULT_DATAPATH, DEFAULT_DATASET, DEFAULT_REMOTE, MULTIPLICITIES
3030

3131
__all__ = [
3232
"Promolecule",
@@ -566,6 +566,10 @@ def make_promolecule(
566566
Promolecule instance.
567567
568568
"""
569+
# Convert single coord [x, y, z] to list of coords [[x, y, z]]
570+
coords = np.asarray(coords, dtype=float)
571+
if coords.ndim == 1:
572+
coords = coords.reshape(1, -1)
569573
# Check coordinate units
570574
if units is None or units.lower() == "bohr":
571575
coords = [coord / 1 for coord in coords]
@@ -574,9 +578,12 @@ def make_promolecule(
574578
else:
575579
raise ValueError(f"Invalid `units` parameter '{units}'; " "must be 'bohr' or 'angstrom'")
576580

581+
# Convert single atnum to list of atnums [atnum]
582+
if isinstance(atnums, (Integral, str)):
583+
atnums = [atnums]
577584
# Get atomic symbols/numbers from inputs
578-
atoms = [element_symbol(atom) for atom in atnums]
579585
atnums = [element_number(atom) for atom in atnums]
586+
atoms = [element_symbol(atom) for atom in atnums]
580587

581588
# Handle default charge parameters
582589
if charges is None:

0 commit comments

Comments
 (0)