Merge pull request #2 from cesmix-mit/refactor
Merge refactor
joannajzou authored Jul 11, 2024
2 parents e7a0b62 + a9c6794 commit cbae89e
Showing 43 changed files with 2,131 additions and 863 deletions.
4 changes: 3 additions & 1 deletion Project.toml
@@ -6,14 +6,17 @@ version = "0.1.0"
[deps]
AtomisticQoIs = "895e25ce-6034-4689-a3ba-4ac45d83446c"
AtomsBase = "a963bdd2-2df7-4f54-a1ee-49d51e6be12a"
AtomsCalculators = "a3e0e189-c65a-42c1-833c-339540406eb1"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
InteratomicPotentials = "a9efe35a-c65d-452d-b8a8-82646cd5cb04"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Maxvol = "4cc553b9-be87-484b-81d9-b5ae2a4e3958"
Molly = "aa0f7f06-fcc0-5ec4-a7f3-a573f33f9c4c"
Polynomials = "f27b6e38-b328-58d1-80ce-0feddd5e7a45"
PotentialLearning = "82b0a93c-c2e3-44bc-a418-f0f89b0ae5c2"
ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
SpecialPolynomials = "a25cea48-d430-424a-8ee7-0d3ad3742e9e"
@@ -28,7 +31,6 @@ UnitfulAtomic = "a7773ee8-282e-5fa2-be4e-bd808c38a91a"
[compat]
AtomsBase = "0.3"
Distributions = "0.25"
Molly = "0.18.3"
julia = "1.9"

[extras]
23 changes: 21 additions & 2 deletions README.md
@@ -4,6 +4,25 @@
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://cesmix-mit.github.io/Cairn.jl/dev/)
[![Build Status](https://github.com/cesmix-mit/Cairn.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/cesmix-mit/Cairn.jl/actions/workflows/CI.yml?query=branch%3Amain)
[![Build Status](https://travis-ci.com/cesmix-mit/Cairn.jl.svg?branch=main)](https://travis-ci.com/cesmix-mit/Cairn.jl)
[![Build Status](https://ci.appveyor.com/api/projects/status/github/cesmix-mit/Cairn.jl?svg=true)](https://ci.appveyor.com/project/cesmix-mit/Cairn-jl)
<!-- [![Build Status](https://ci.appveyor.com/api/projects/status/github/cesmix-mit/Cairn.jl?svg=true)](https://ci.appveyor.com/project/cesmix-mit/Cairn-jl)
[![Coverage](https://codecov.io/gh/cesmix-mit/Cairn.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/cesmix-mit/Cairn.jl)
[![Coverage](https://coveralls.io/repos/github/cesmix-mit/Cairn.jl/badge.svg?branch=main)](https://coveralls.io/github/cesmix-mit/Cairn.jl?branch=main)
[![Coverage](https://coveralls.io/repos/github/cesmix-mit/Cairn.jl/badge.svg?branch=main)](https://coveralls.io/github/cesmix-mit/Cairn.jl?branch=main) -->

Cairn.jl is a toolkit of active learning algorithms for training machine-learning interatomic potentials (ML-IPs) used in molecular dynamics simulation.

Cairn.jl is built as an extension to [Molly.jl](https://github.com/JuliaMolSim/Molly.jl), implementing enhanced MD samplers, and it interfaces with other packages for molecular simulation in the Julia ecosystem developed by [CESMIX](https://github.com/cesmix-mit) and [JuliaMolSim](https://github.com/JuliaMolSim).

Active learning algorithms build efficient training datasets that maximally improve the accuracy of a scientific machine learning model. These algorithms are iterative, looping through the following steps (a minimal sketch of the loop is given after the list):

1. **Data generation.** A system's potential energy landscape is sampled by generating trajectories of molecular configurations, simulated from Newton's equations of motion or modifications thereof. Users can choose between standard MD integrators, such as Langevin dynamics or velocity Verlet, and enhanced sampling algorithms, such as uncertainty-driven dynamics ([UDD](https://www.nature.com/articles/s43588-023-00406-5)), Stein repulsive Langevin dynamics, or Stein variational molecular dynamics. These methods are specified under the abstract type `Simulator`.

2. **Trigger for retraining.** Sampling is terminated and retraining is triggered when the trajectory has met a certain criterion. A "fixed trigger" calls for retraining after a fixed number of simulation steps. "Adaptive triggers" are based on metrics of uncertainty, from Gaussian process or ensemble-based estimates of variance; metrics of extrapolation, based on a MaxVol vector basis; or metrics of diversity, such as a determinantal point process (DPP) inclusion probability. These methods are specified under the abstract type `ActiveLearningTrigger`.

3. **Data subset selection and labelling.** A subset of the data from the simulated trajectory is selected, labelled with reference calculations, and appended to the training set. The most basic selection is a random subset of the trajectory. "Adaptive" selections can be made from data that exceed a threshold or data chosen by a subset selection algorithm, such as MaxVol, k-means clustering, or DPPs. These methods are specified under the abstract type `SubsetSelector`.

4. **Model updating.** The machine learning model is retrained on the augmented dataset according to the choice of objective function defined by the abstract type `LinearProblem`. These methods live in the package [PotentialLearning.jl](https://github.com/cesmix-mit/PotentialLearning.jl).
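
The loop shared by these four steps can be sketched in plain Julia. The helper names below (`run_sampler!`, `trigger_fired`, `select_subset`, `label_and_retrain!`) are hypothetical placeholders, not the Cairn.jl API; only the control flow is meant to mirror the steps above.

```julia
# Hypothetical sketch of the active-learning loop; each placeholder stands in
# for a concrete `Simulator`, `ActiveLearningTrigger`, `SubsetSelector`, and
# `LinearProblem` solve supplied by the user.
function active_learning_loop!(mlip, sys; n_iters = 10, max_steps = 10_000)
    for iter in 1:n_iters
        traj = Any[]
        for step in 1:max_steps
            push!(traj, run_sampler!(sys, mlip))   # 1. data generation (MD / enhanced sampling)
            trigger_fired(traj, mlip) && break     # 2. fixed or adaptive retraining trigger
        end
        subset = select_subset(traj, mlip)         # 3. subset selection (random, MaxVol, DPP, ...)
        label_and_retrain!(mlip, subset)           # 4. reference labelling and model update
    end
    return mlip
end
```

In Cairn.jl, these roles are filled by concrete subtypes of `Simulator`, `ActiveLearningTrigger`, `SubsetSelector`, and `LinearProblem`, composed by the user.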


For a technical manual on the package, see the [docs](cesmix-mit.github.io/Cairn.jl/). For a demo, see the Jupyter notebooks in the `examples` folder.


219 changes: 219 additions & 0 deletions examples/himmelblau_train.jl
@@ -0,0 +1,219 @@
using Cairn
using LinearAlgebra, Random, Statistics, StatsBase, Distributions
using PotentialLearning
using Molly, AtomsCalculators
using AtomisticQoIs
using SpecialPolynomials, SpecialFunctions

include("./src/makie/makie.jl")
include("./examples/utils.jl")



## define models ------------------------------------------------------------------
# choose reference model
ref = Himmelblau()
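# (the Himmelblau function serves as a 2D toy potential energy surface with four local minima)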

# define main support
limits = [[-6.5,6.5],[-6,6]]
# limits = [[-3.5,1.5],[-1.5,3.5]]
coord_grid = coord_grid_2d(limits, 0.1)
ctr_lvls = 0:25:400

# PCE properties
basisfam = Jacobi{0.5,0.5}
order = 5
pce0 = PolynomialChaos(order, 2, basisfam, xscl=limits)

# grid over main support
coords_eval = potential_grid_2d(ref, limits, 0.1, cutoff = 400)
sys_eval = define_ens(ref, coords_eval)

# use grid to define uniform quadrature points
ξ = [ustrip.(Vector(coords)) for coords in coords_eval]
GQint = GaussQuadrature(ξ, ones(length(ξ))./length(ξ))
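# (these uniform-weight quadrature points are reused below by `FisherDivergence` to compare model and reference Gibbs densities)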

# plot
f0, ax0 = plot_contours_2d(ref, coord_grid; fill=true, lvls=ctr_lvls)
coordmat = reduce(hcat, get_values(coords_eval))'
scatter!(ax0, coordmat[:,1], coordmat[:,2], color=:red, markersize=5, label="test points")
axislegend(ax0)
f0

# plot density
f, _ = plot_density(ref, coord_grid, GQint)


# reference: train to test set
# pce = deepcopy(pce0)
# lp = learn!(sys_eval, ref, pce, [1000,1], false; e_flag=true, f_flag=true)
# p = define_gibbs_dist(ref)
# q = define_gibbs_dist(pce, θ=lp.β)
# fish = FisherDivergence(GQint)
# fd_best = compute_divergence(p, q, fish)


## training set 1: grid over main support ---------------------------------------
# sample from grid
coords1 = potential_grid_2d(ref, limits, 0.2, cutoff = 400)
trainset1 = define_ens(deepcopy(pce0), coords1)

# plot
f0, ax0 = plot_contours_2d(ref, coord_grid; fill=true, lvls=ctr_lvls)
coordmat = reduce(hcat, get_values(coords1))'
scatter!(ax0, coordmat[:,1], coordmat[:,2], color=:red, markersize=5, label="train set 1")
axislegend(ax0)
f0



## training set 2: samples from Langevin MD -------------------------------------
# Langevin simulator
sim_langevin = OverdampedLangevin(
    dt=0.002u"ps",
    temperature=500.0u"K",
    friction=4.0u"ps^-1",
)

x0arr = [[4.5, -2], [-3.5,3], [-3.5,-3]]
sys_langevin = Vector(undef, 3)
for (i,x0) in enumerate(x0arr)
    sys0 = define_sys(
        ref,
        x0,
        loggers=(coords=CoordinateLogger(100; dims=2),),
    )
    # simulate
    sys2 = deepcopy(sys0)
    simulate!(sys2, sim_langevin, 1_000_000)
    sys_langevin[i] = sys2
end


# subselect train data from the trajectory
n = [1335, 669, 669]
coords2 = [[sys_langevin[j].loggers.coords.history[i][1] for i=1:n[j]] for j=1:3]
coords2 = reduce(vcat, coords2)
trainset2 = define_ens(deepcopy(pce0), coords2)

# plot
f, ax = plot_contours_2d(ref, coord_grid; fill=true, lvls=ctr_lvls)
coordmat = reduce(hcat, get_values(coords2))'
scatter!(ax, coordmat[:,1], coordmat[:,2], color=:red, markersize=5, label="train set 2")
axislegend(ax)
f



## training set 3: samples from high-T MD -------------------------------------
# high-temp Langevin simulator
sim_highT = OverdampedLangevin(
    dt=0.002u"ps",
    temperature=2000.0u"K",
    friction=4.0u"ps^-1",
)
# simulate, continuing from the last `sys0` defined in the Langevin loop above
# (note: `sys0` is only in scope here when those loop assignments are global, e.g. when run in a notebook)
sys3 = deepcopy(sys0)
simulate!(sys3, sim_highT, 2_000_000)
# f = plot_md_trajectory(sys3, coord_grid, fill=false, lvls=ctr_lvls, showpath=false)

# subselect train data from the trajectory
id = StatsBase.sample(1:length(sys3.loggers.coords.history), length(coords1), replace=false)
coords3 = [sys3.loggers.coords.history[i][1] for i in id]
trainset3 = define_ens(deepcopy(pce0), coords3)

# plot
f, ax = plot_contours_2d(ref, coord_grid; fill=true, lvls=ctr_lvls)
coordmat = reduce(hcat, get_values(coords3))'
scatter!(ax, coordmat[:,1], coordmat[:,2], color=:red, markersize=5, label="train set 3")
axislegend(ax)
f


# train with changing weight λ --------------------------------------------------------------
λarr = [1e-4, 1e-3, 1e-2, 1e-1, 1, 1e1, 1e2, 1e3, 1e4]
trainsets = [trainset1, trainset2, trainset3]
p = define_gibbs_dist(ref)
fish = FisherDivergence(GQint)


# store results
param_dict = Dict("ts$j" => Dict(
        "E" => zeros(length(pce0.basis)),
        "F" => zeros(length(pce0.basis)),
        "EF" => Vector{Vector}(undef, length(λarr)),
    ) for j = 1:length(trainsets)
)

err_dict = Dict("ts$j" => Dict(
        "E" => 0.0,
        "F" => 0.0,
        "EF" => zeros(length(λarr)),
    ) for j = 1:length(trainsets)
)

fd_dict = Dict("ts$j" => Dict(
        "E" => 0.0,
        "F" => 0.0,
        "EF" => zeros(length(λarr)),
    ) for j = 1:length(trainsets)
)


# train on E or F only (UnivariateLinearProblem)
for (j,ts) in enumerate(trainsets)
    # E objective
    println("train set $j, E only")
    pce = deepcopy(pce0)
    lpe = learn!(ts, ref, pce; e_flag=true, f_flag=false)
    q = define_gibbs_dist(pce, θ=lpe.β)
    fd_dict["ts$j"]["E"] = compute_divergence(p, q, fish)
    param_dict["ts$j"]["E"] = lpe.β

    # F objective
    println("train set $j, F only")
    pce = deepcopy(pce0)
    lpf = learn!(ts, ref, pce; e_flag=false, f_flag=true)
    q = define_gibbs_dist(pce, θ=lpf.β)
    fd_dict["ts$j"]["F"] = compute_divergence(p, q, fish)
    param_dict["ts$j"]["F"] = lpf.β
end

# train on EF (CovariateLinearProblem)
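# (assumption: in the `learn!` call below, `[λ, 1]` weights the energy vs. force residuals in the combined objective)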
for (i,λ) in enumerate(λarr)
    for (j,ts) in enumerate(trainsets)

        # EF objective
        println("train set $j, EF (λ=$λ)")
        pce = deepcopy(pce0)
        lpef = learn!(ts, ref, pce, [λ, 1], false; e_flag=true, f_flag=true)
        q = define_gibbs_dist(pce, θ=lpef.β)
        fd_dict["ts$j"]["EF"][i] = compute_divergence(p, q, fish)
        param_dict["ts$j"]["EF"][i] = lpef.β
    end
end



# plot results
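# F-only and E-only results are plotted at pseudo-λ values 1e-5 and 1e5 (tick labels "F" and "E")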
λlab = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 1e1, 1e2, 1e3, 1e4, 1e5]
f = Figure(resolution=(550,450))
ax = Axis(f[1,1],
    xlabel="λ",
    ylabel="Fisher divergence",
    title="Model Error vs. Weight λ",
    xscale=log10,
    yscale=log10,
    xticks=(λlab, ["F", "1e-4", "1e-3", "1e-2", "1e-1", "1", "1e1", "1e2", "1e3", "1e4", "E"]))

for j = 1:3
    fd_all = reduce(vcat, [[fd_dict["ts$j"]["F"]], fd_dict["ts$j"]["EF"], [fd_dict["ts$j"]["E"]]])
    scatterlines!(ax, λlab, fd_all, label="train set $j")
end
axislegend(ax, position=:lt)
f

# evaluate the surrogate with the energy-only parameters from train set 2
pce = deepcopy(pce0)
pce.params = param_dict["ts2"]["E"]
ctr_lvls2 = -20:5:50 # for forces
f, _ = plot_contours_2d(pce, coord_grid, fill=true, lvls=ctr_lvls)