Skip to content

Commit

Permalink
Added some utilities for rotational averaging
Browse files Browse the repository at this point in the history
  • Loading branch information
ceriottm committed Sep 25, 2024
1 parent 1fe3102 commit 488239d
Show file tree
Hide file tree
Showing 2 changed files with 347 additions and 0 deletions.
316 changes: 316 additions & 0 deletions ipi/engine/forcefields.py
Original file line number Diff line number Diff line change
Expand Up @@ -1371,6 +1371,322 @@ def poll(self):
r["status"] = "Done"


class FFRotations(FFSocket):
"""Forcefield to manipulate models that are not exactly rotationally equivariant.
Can be used to evaluate a different random rotation at each evaluation, or to average
over a regular grid of Euler angles"""

def __init__(
self,
latency=1.0,
offset=0.0,
name="",
pars=None,
dopbc=True,
active=np.array([-1]),
threaded=True,
interface=None,
random=False,
improper=False,
grid=1,
):
super(FFRotations, self).__init__(
self,
latency,
offset,
name,
pars,
dopbc,
active,
threaded,
interface,
)

self.random = random
self.improper = improper
self._rotations = []

if len(fflist) == 0:

Check warning on line 1409 in ipi/engine/forcefields.py

View workflow job for this annotation

GitHub Actions / lint

F821 undefined name 'fflist'
raise ValueError(
"Committee forcefield cannot be initialized from an empty list"
)
self.fflist = fflist

Check warning on line 1413 in ipi/engine/forcefields.py

View workflow job for this annotation

GitHub Actions / lint

F821 undefined name 'fflist'
self.ff_requests = {}
self.baseline_uncertainty = baseline_uncertainty

Check warning on line 1415 in ipi/engine/forcefields.py

View workflow job for this annotation

GitHub Actions / lint

F821 undefined name 'baseline_uncertainty'
self.baseline_name = baseline_name

Check warning on line 1416 in ipi/engine/forcefields.py

View workflow job for this annotation

GitHub Actions / lint

F821 undefined name 'baseline_name'
if len(ffweights) == 0 and self.baseline_uncertainty < 0:

Check warning on line 1417 in ipi/engine/forcefields.py

View workflow job for this annotation

GitHub Actions / lint

F821 undefined name 'ffweights'
ffweights = np.ones(len(fflist))

Check warning on line 1418 in ipi/engine/forcefields.py

View workflow job for this annotation

GitHub Actions / lint

F821 undefined name 'fflist'
elif len(ffweights) == 0 and self.baseline_uncertainty > 0:
ffweights = np.ones(len(fflist) - 1)

Check warning on line 1420 in ipi/engine/forcefields.py

View workflow job for this annotation

GitHub Actions / lint

F821 undefined name 'fflist'
if len(ffweights) != len(fflist) and self.baseline_uncertainty < 0:

Check warning on line 1421 in ipi/engine/forcefields.py

View workflow job for this annotation

GitHub Actions / lint

F821 undefined name 'fflist'
raise ValueError("List of weights does not match length of committee model")
elif len(ffweights) != len(fflist) - 1 and self.baseline_uncertainty > 0:

Check warning on line 1423 in ipi/engine/forcefields.py

View workflow job for this annotation

GitHub Actions / lint

F821 undefined name 'fflist'
raise ValueError("List of weights does not match length of committee model")
if (self.baseline_name == "") != (self.baseline_uncertainty < 0):
raise ValueError(
"Name and the uncertainty of the baseline are not simultaneously defined"
)
self.ffweights = ffweights
self.alpha = alpha

Check warning on line 1430 in ipi/engine/forcefields.py

View workflow job for this annotation

GitHub Actions / lint

F821 undefined name 'alpha'
self.active_thresh = active_thresh
self.active_out = active_out
self.parse_json = parse_json

def bind(self, output_maker):
super(FFCommittee, self).bind(output_maker)
if self.active_thresh > 0:
if self.active_out is None:
raise ValueError(
"Must specify an output file if you want to save structures for active learning"
)
else:
self.active_file = self.output_maker.get_output(self.active_out, "w")

def start(self):
for ff in self.fflist:
ff.start()
super(FFCommittee, self).start()

def queue(self, atoms, cell, reqid=-1):
# launches requests for all of the committee FF objects
ffh = []
for ff in self.fflist:
ffh.append(ff.queue(atoms, cell, reqid))

# creates the request with the help of the base class,
# making sure it already contains a handle to the list of FF
# requests
req = super(FFCommittee, self).queue(
atoms, cell, reqid, template=dict(ff_handles=ffh)
)
req["t_dispatched"] = time.time()
return req

def check_finish(self, r):
"""Checks if all sub-requests associated with a given
request are finished"""
for ff_r in r["ff_handles"]:
if ff_r["status"] != "Done":
return False
return True

def gather(self, r):
"""Collects results from all sub-requests, and assemble the committee of models."""

r["result"] = [
0.0,
np.zeros(len(r["pos"]), float),
np.zeros((3, 3), float),
"",
]

# list of pointers to the forcefield requests. shallow copy so we can remove stuff
com_handles = r["ff_handles"].copy()
if self.baseline_name != "":
# looks for the baseline potential, store its value and drops it from the list
names = [ff.name for ff in self.fflist]

for i, ff_r in enumerate(com_handles):
if names[i] == self.baseline_name:
baseline_pot = ff_r["result"][0]
baseline_frc = ff_r["result"][1]
baseline_vir = ff_r["result"][2]
baseline_xtr = ff_r["result"][3]
com_handles.pop(i)
break

# Gathers the forcefield energetics and extras
pots = []
frcs = []
virs = []
xtrs = []

all_have_frc = True
all_have_vir = True

for ff_r in com_handles:
# if required, tries to extract multiple committe members from the extras JSON string
if "committee_pot" in ff_r["result"][3] and self.parse_json:
pots += ff_r["result"][3]["committee_pot"]
if "committee_force" in ff_r["result"][3]:
frcs += ff_r["result"][3]["committee_force"]
ff_r["result"][3].pop("committee_force")
else:
# if the commitee doesn't have forces, just add the mean force from this model
frcs.append(ff_r["result"][1])
warning("JSON committee doesn't have forces", verbosity.medium)
all_have_frc = False

if "committee_virial" in ff_r["result"][3]:
virs += ff_r["result"][3]["committee_virial"]
ff_r["result"][3].pop("committee_virial")
else:
# if the commitee doesn't have virials, just add the mean virial from this model
virs.append(ff_r["result"][2])
warning("JSON committee doesn't have virials", verbosity.medium)
all_have_vir = False

else:
pots.append(ff_r["result"][0])
frcs.append(ff_r["result"][1])
virs.append(ff_r["result"][2])

pots = np.array(pots)
if len(pots) != len(frcs) and len(frcs) > 1:
raise ValueError(
"If the committee returns forces, we need *all* components"
)
frcs = np.array(frcs).reshape(len(frcs), -1)

if len(pots) != len(virs) and len(virs) > 1:
raise ValueError(
"If the committee returns virials, we need *all* components"
)
virs = np.array(virs).reshape(-1, 3, 3)

xtrs.append(ff_r["result"][3])

# Computes the mean energetics
mean_pot = np.mean(pots, axis=0)
mean_frc = np.mean(frcs, axis=0)
mean_vir = np.mean(virs, axis=0)

# Rescales the committee energetics so that their standard deviation corresponds to the error
rescaled_pots = np.asarray(
[mean_pot + self.alpha * (pot - mean_pot) for pot in pots]
)
rescaled_frcs = np.asarray(
[mean_frc + self.alpha * (frc - mean_frc) for frc in frcs]
)
rescaled_virs = np.asarray(
[mean_vir + self.alpha * (vir - mean_vir) for vir in virs]
)

# Calculates the error associated with the committee
var_pot = np.var(rescaled_pots, ddof=1)
std_pot = np.sqrt(var_pot)

if self.baseline_name != "":
if not (all_have_frc and all_have_vir):
raise ValueError(
"Cannot use weighted baseline without a force ensemble"
)

# Computes the additional component of the energetics due to a position
# dependent weight. This is based on the assumption that V_committee is
# a correction over the baseline, that V = V_baseline + V_committe, that
# V_baseline has an uncertainty given by baseline_uncertainty,
# and V_committee the committee error. Then
# V = V_baseline + s_b^2/(s_c^2+s_b^2) V_committe

s_b2 = self.baseline_uncertainty**2

nmodels = len(pots)
uncertain_frc = (
self.alpha**2
* np.sum(
[
(pot - mean_pot) * (frc - mean_frc)
for pot, frc in zip(pots, frcs)
],
axis=0,
)
/ (nmodels - 1)
)
uncertain_vir = (
self.alpha**2
* np.sum(
[
(pot - mean_pot) * (vir - mean_vir)
for pot, vir in zip(pots, virs)
],
axis=0,
)
/ (nmodels - 1)
)

# Computes the final average energetics
final_pot = baseline_pot + mean_pot * s_b2 / (s_b2 + var_pot)
final_frc = (
baseline_frc
+ mean_frc * s_b2 / (s_b2 + var_pot)
- 2.0 * mean_pot * s_b2 / (s_b2 + var_pot) ** 2 * uncertain_frc
)
final_vir = (
baseline_vir
+ mean_vir * s_b2 / (s_b2 + var_pot)
- 2.0 * mean_pot * s_b2 / (s_b2 + var_pot) ** 2 * uncertain_vir
)

# Sets the output of the committee model.
r["result"][0] = final_pot
r["result"][1] = final_frc
r["result"][2] = final_vir
else:
# Sets the output of the committee model.
r["result"][0] = mean_pot
r["result"][1] = mean_frc
r["result"][2] = mean_vir

r["result"][3] = {
"committee_pot": rescaled_pots,
"committee_uncertainty": std_pot,
}

if all_have_frc:
r["result"][3]["committee_force"] = rescaled_frcs.reshape(
len(rescaled_pots), -1
)
if all_have_vir:
r["result"][3]["committee_virial"] = rescaled_virs.reshape(
len(rescaled_pots), -1
)

if self.baseline_name != "":
r["result"][3]["baseline_pot"] = (baseline_pot,)
r["result"][3]["baseline_force"] = (baseline_frc,)
r["result"][3]["baseline_virial"] = ((baseline_vir.flatten()),)
r["result"][3]["baseline_extras"] = (baseline_xtr,)
r["result"][3]["wb_mixing"] = (s_b2 / (s_b2 + var_pot),)

# "dissolve" the extras dictionaries into a list
for k in xtrs[0].keys():
if ("committee_" + k) in r["result"][3].keys():
raise ValueError(
"Name clash between extras key "
+ k
+ " and default committee extras"
)
r["result"][3][("committee_" + k)] = []
for x in xtrs:
r["result"][3][("committee_" + k)].append(x[k])

if self.active_thresh > 0.0 and std_pot > self.active_thresh:
dumps = json.dumps(
{
"position": list(r["pos"]),
"cell": list(r["cell"][0].flatten()),
"uncertainty": std_pot,
}
)
self.active_file.write(dumps)

# releases the requests from the committee FF
for ff, ff_r in zip(self.fflist, r["ff_handles"]):
ff.release(ff_r)

def poll(self):
"""Polls the forcefield object to check if it has finished."""

with self._threadlock:
for r in self.requests:
if r["status"] != "Done" and self.check_finish(r):
r["t_finished"] = time.time()
self.gather(r)
r["result"][0] -= self.offset
r["status"] = "Done"


class PhotonDriver:
"""
Photon driver for a single cavity mode
Expand Down
31 changes: 31 additions & 0 deletions ipi/utils/mathtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,37 @@ def euler_zxz_to_matrix(theta, v, w):
return rotation_matrix


def roots_legendre(L):
"""Replicates scipy.special.roots_legendre using only numpy"""

legendre_poly = np.polynomial.legendre.Legendre.basis(L)
roots = np.polynomial.legendre.legroots(legendre_poly.coef)
legendre_poly_deriv = legendre_poly.deriv()

# Calculate weights using the formula
weights = 2 / ((1 - roots**2) * (legendre_poly_deriv(roots) ** 2))

return roots, weights


def get_rotation_quadrature(L):
matrices, weights = [], []
for theta_index in range(0, 2 * L - 1):
for w_index in range(0, 2 * L - 1):
theta = 2 * np.pi * theta_index / (2 * L - 1)
w = 2 * np.pi * w_index / (2 * L - 1)
roots_legendre_now, weights_now = roots_legendre(L)
all_v = np.arccos(roots_legendre_now)
for v, weight in zip(all_v, weights_now):
weights.append(weight)
angles = [theta, v, w]
rotation = R.from_euler("zxz", angles, degrees=False)
rotation_matrix = rotation.as_matrix()
matrices.append(rotation_matrix)

return matrices, weights


def random_rotation(prng, improper=True):
"""Generates a (uniform) random rotation matrix"""

Expand Down

0 comments on commit 488239d

Please sign in to comment.