Skip to content

Commit

Permalink
Refactor ZrnHeAgeSpherical to store and resuse damage matrices, etc.
Browse files Browse the repository at this point in the history
  • Loading branch information
brenhinkeller committed Nov 12, 2022
1 parent 12cc2cd commit 98f1372
Show file tree
Hide file tree
Showing 7 changed files with 379 additions and 285 deletions.
147 changes: 76 additions & 71 deletions examples/ZrnHeInversionVartCryst.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@
dTmax = 10.0 # Maximum reheating/burial per model timestep (to prevent numerical overflow)

# Other model parameters
CrystAgeMax_Ma = 4000.0 # Ma -- forbid anything older than this
TCryst = 400.0 # Temperature (in C)
CrystAgeMax = 4000.0 # Ma -- forbid anything older than this
TStart = 400.0 # Temperature (in C)
dr = 1 # Radius step, in microns

diffusionparams = (;
Expand Down Expand Up @@ -82,58 +82,61 @@

# Populate local variables from data frame with specified options
Halfwidth = data.Halfwidth_um
U_ppm = data.U238_ppm
Th_ppm = data.Th232_ppm
HeAge_Ma = data.HeAge_Ma_raw
HeAge_Ma_sigma = data.HeAge_Ma_sigma_raw
CrystAge_Ma = data.CrystAge_Ma
Uppm = data.U238_ppm
Thppm = data.Th232_ppm
HeAge = data.HeAge_Ma_raw
HeAge_sigma = data.HeAge_Ma_sigma_raw
CrystAge = data.CrystAge_Ma

CrystAge_Ma[CrystAge_Ma .> CrystAgeMax_Ma] .= CrystAgeMax_Ma
tCryst = ceil(maximum(CrystAge_Ma)/dt) * dt
CrystAge[CrystAge .> CrystAgeMax] .= CrystAgeMax
tStart = ceil(maximum(CrystAge)/dt) * dt

tSteps = Array{Float64,1}(0+dt/2 : dt : tCryst-dt/2)
tSteps = Array{Float64,1}(0+dt/2 : dt : tStart-dt/2)
ageSteps = reverse(tSteps)
ntSteps = length(tSteps) # Number of time steps
eU = U_ppm+.238*Th_ppm # Used only for plotting
eU = Uppm+.238*Thppm # Used only for plotting

## --- Test proscribed t-T paths with Neoproterozoic exhumation step

# Generate T path to test
# # Generate T path to test
# Tr = 150
# T0 = 30
# agePoints = Float64[tCryst, tCryst*29/30, 720, 580, 250, 0] # Age (Ma)
# TPoints = Float64[TCryst, Tr+T0, Tr+T0, T0, 70, 10] # Temp. (C)
# agePoints = Float64[tStart, tStart*29/30, 720, 580, 250, 0] # Age (Ma)
# TPoints = Float64[TStart, Tr+T0, Tr+T0, T0, 70, 10] # Temp. (C)
# TSteps = linterp1s(agePoints,TPoints,ageSteps)

# Plot t-T path
#plot(ageSteps,TSteps,xflip=true)

# Calculate model ages
#CalcHeAges = Array{Float64}(undef, size(HeAge_Ma))
#@time pr = DamageAnnealing(dt,tSteps,TSteps)
#@time for i=1:length(Halfwidth)
# first_index = 1 + floor(Int64,(tCryst - CrystAge_Ma[i])/dt)
# CalcHeAges[i] = ZrnHeAgeSpherical(dt,ageSteps[first_index:end],TSteps[first_index:end],pr[first_index:end,first_index:end],Halfwidth[i],dr,U_ppm[i],Th_ppm[i],diffusionparams)
#end

# Plot Comparison of results
#p2 = plot(eU, CalcHeAges, seriestype=:scatter,label="Model")
#plot!(p2, eU, HeAge_Ma, yerror=HeAge_Ma_sigma*2, seriestype=:scatter, label="Data")
#xlabel!(p2,"eU"); ylabel!(p2,"Age (Ma)")
#display(p2)

# Check log likelihood
#ll = sum(-(CalcHeAges - HeAge_Ma).^2 ./ (2 .* HeAge_Ma_sigma.^2))
#ll = sum(-(CalcHeAges - HeAge_Ma).^2 ./ (2 .* AnnealedSigma.^2))
#
# # Plot t-T path
# plot(ageSteps,TSteps,xflip=true)
#
# # Calculate model ages
# CalcHeAges = Array{Float64}(undef, size(HeAge))
# @time pr = DamageAnnealing(dt,tSteps,TSteps)
# zircons = Array{Zircon{Float64}}(undef, length(Halfwidth))
# @time for i=1:length(zircons)
# # Iterate through each grain, calculate the modeled age for each
# first_index = 1 + floor(Int64,(tStart - CrystAge[i])/dt)
# zircons[i] = Zircon(Halfwidth[i], dr, Uppm[i], Thppm[i], dt, ageSteps[first_index:end])
# CalcHeAges[i] = HeAgeSpherical(zircons[i], @views(TSteps[first_index:end]), @views(pr[first_index:end,first_index:end]), diffusionparams)
# end
#
# # Plot Comparison of results
# p2 = plot(eU, CalcHeAges, seriestype=:scatter,label="Model")
# plot!(p2, eU, HeAge, yerror=HeAge_sigma*2, seriestype=:scatter, label="Data")
# xlabel!(p2,"eU"); ylabel!(p2,"Age (Ma)")
# display(p2)
#
# # Check log likelihood
# ll = sum(-(CalcHeAges - HeAge).^2 ./ (2 .* HeAge_sigma.^2))
# ll = sum(-(CalcHeAges - HeAge).^2 ./ (2 .* AnnealedSigma.^2))


## --- Invert for maximum likelihood t-T path

# Boundary conditions (10C at present and 650 C at the time of zircon
# formation). Optional: Apppend required unconformities at base of Cambrian,
# or elsewhere
boundary_agePoints = Float64[0, tCryst] # Ma
boundary_TPoints = Float64[0, TCryst] # Degrees C
boundary_agePoints = Float64[0, tStart] # Ma
boundary_TPoints = Float64[0, TStart] # Degrees C

# This is where the "transdimensional" part comes in
nPoints = 0
Expand All @@ -143,7 +146,7 @@

# Fill some intermediate points to give the MCMC something to work with
Tr = 250 # Residence temperature of initial proposal (value should not matter too much)
for t in [tCryst/30,tCryst/4,tCryst/2,tCryst-tCryst/4,tCryst-tCryst/30]
for t in [tStart/30,tStart/4,tStart/2,tStart-tStart/4,tStart-tStart/30]
global nPoints += 1
agePoints[nPoints] = t # Ma
TPoints[nPoints] = Tr # Degrees C
Expand All @@ -152,26 +155,28 @@
# # (Optional) Start with something close to the expected path
#Tr = 150
#T0 = 30
#agePoints[1:4] = Float64[tCryst*29/30, 720, 510, 250]) # Age (Ma)
#agePoints[1:4] = Float64[tStart*29/30, 720, 510, 250]) # Age (Ma)
#TPoints[1:4] = Float64[ Tr+T0, Tr+T0, T0, 70]) # Temp. (C)
#nPoints+=4

function MCMC_vartcryst(nPoints, maxPoints, agePoints, TPoints, unconf_agePoints, unconf_TPoints, boundary_agePoints, boundary_TPoints, simannealparams, diffusionparams)
# Calculate model ages for initial proposal
TSteps = linterp1s([view(agePoints, 1:nPoints) ; boundary_agePoints ; unconf_agePoints],
[view(TPoints, 1:nPoints) ; boundary_TPoints ; unconf_TPoints], ageSteps)
CalcHeAges = Array{Float64}(undef, size(HeAge_Ma))
CalcHeAges = Array{Float64}(undef, size(HeAge))
pr = DamageAnnealing(dt,tSteps,TSteps) # Damage annealing history
for i=1:length(Halfwidth)
zircons = Array{Zircon{Float64}}(undef, length(Halfwidth))
for i=1:length(zircons)
# Iterate through each grain, calculate the modeled age for each
first_index = 1 + floor(Int64,(tCryst - CrystAge_Ma[i])/dt)
CalcHeAges[i] = ZrnHeAgeSpherical(dt,ageSteps[first_index:end],TSteps[first_index:end],pr[first_index:end,first_index:end],Halfwidth[i],dr,U_ppm[i],Th_ppm[i], diffusionparams)
first_index = 1 + floor(Int64,(tStart - CrystAge[i])/dt)
zircons[i] = Zircon(Halfwidth[i], dr, Uppm[i], Thppm[i], dt, ageSteps[first_index:end])
CalcHeAges[i] = HeAgeSpherical(zircons[i], @views(TSteps[first_index:end]), @views(pr[first_index:end,first_index:end]), diffusionparams)
end
AnnealedSigma = simannealsigma.(1, HeAge_Ma_sigma; params=simannealparams)
UnAnnealedSigma = simannealsigma.(nsteps, HeAge_Ma_sigma; params=simannealparams)
AnnealedSigma = simannealsigma.(1, HeAge_sigma; params=simannealparams)
UnAnnealedSigma = simannealsigma.(nsteps, HeAge_sigma; params=simannealparams)

# Log-likelihood for initial proposal
ll = normpdf_ll(HeAge_Ma, AnnealedSigma, CalcHeAges)
ll = normpdf_ll(HeAge, AnnealedSigma, CalcHeAges)
if simplified
ll -= log(nPoints)
end
Expand All @@ -188,15 +193,15 @@
CalcHeAgesₚ = similar(CalcHeAges)

# Distributions to populate
HeAgeDist = Array{Float64}(undef, length(HeAge_Ma), nsteps)
HeAgeDist = Array{Float64}(undef, length(HeAge), nsteps)
TStepDist = Array{Float64}(undef, ntSteps, nsteps)
llDist = Array{Float64}(undef, nsteps)
nDist = zeros(Int, nsteps)
acceptanceDist = zeros(Bool, nsteps)

# Standard deviations of Gaussian proposal distributions for temperature and time
t_sigma = tCryst/60
T_sigma = TCryst/60
t_sigma = tStart/60
T_sigma = TStart/60

# Proposal probabilities (must sum to 1)
move = 0.64
Expand Down Expand Up @@ -227,17 +232,17 @@
if agePointsₚ[k] < dt
# Don't let any point get too close to 0
agePointsₚ[k] += (dt - agePointsₚ[k])
elseif agePointsₚ[k] > (tCryst - dt)
# Don't let any point get too close to tCryst
agePointsₚ[k] -= (agePointsₚ[k] - (tCryst - dt))
elseif agePointsₚ[k] > (tStart - dt)
# Don't let any point get too close to tStart
agePointsₚ[k] -= (agePointsₚ[k] - (tStart - dt))
end
# Move the Temperature of one model point
if TPointsₚ[k] < 0
# Don't allow T<0
TPointsₚ[k] = 0
elseif TPointsₚ[k] > TCryst
# Don't allow T>TCryst
TPointsₚ[k] = TCryst
elseif TPointsₚ[k] > TStart
# Don't allow T>TStart
TPointsₚ[k] = TStart
end

# Interpolate proposed t-T path
Expand All @@ -255,8 +260,8 @@
# Birth: add a new model point
nPointsₚ += 1
for i=1:maxattempts # Try maxattempts times to satisfy the reheating rate limit
agePointsₚ[nPointsₚ] = rand()*tCryst
TPointsₚ[nPointsₚ] = rand()*TCryst
agePointsₚ[nPointsₚ] = rand()*tStart
TPointsₚ[nPointsₚ] = rand()*TStart

# Interpolate proposed t-T path
TStepsₚ = linterp1s([view(agePointsₚ, 1:nPointsₚ) ; boundary_agePoints ; unconf_agePointsₚ],
Expand Down Expand Up @@ -288,8 +293,8 @@
# Allow the present temperature to vary from 0 to 10 degrees C
boundary_TPointsₚ[1] = 0+rand()*10
else
# Allow the initial temperature to vary from TCryst to TCryst-50 C
boundary_TPointsₚ[2] = TCryst-rand()*50
# Allow the initial temperature to vary from TStart to TStart-50 C
boundary_TPointsₚ[2] = TStart-rand()*50
end
if length(unconf_agePointsₚ) > 0
# If there's an imposed unconformity, adjust within parameters
Expand Down Expand Up @@ -317,15 +322,15 @@

# Calculate model ages for each grain
DamageAnnealing!(pr, dt, tSteps, TStepsₚ)
for i=1:length(Halfwidth)
first_index = 1 + floor(Int64,(tCryst - CrystAge_Ma[i])/dt)
@views CalcHeAgesₚ[i] = ZrnHeAgeSpherical(dt,ageSteps[first_index:end],TStepsₚ[first_index:end],pr[first_index:end,first_index:end],Halfwidth[i],dr,U_ppm[i],Th_ppm[i],diffusionparams)
for i=1:length(zircons)
first_index = 1 + floor(Int64,(tStart - CrystAge[i])/dt)
CalcHeAgesₚ[i] = HeAgeSpherical(zircons[i], @views(TStepsₚ[first_index:end]), @views(pr[first_index:end,first_index:end]), diffusionparams)
end

# Calculate log likelihood of proposal
AnnealedSigma .= simannealsigma.(n, HeAge_Ma_sigma; params=simannealparams)
llₚ = normpdf_ll(HeAge_Ma, AnnealedSigma, CalcHeAgesₚ)
llₗ = normpdf_ll(HeAge_Ma, AnnealedSigma, CalcHeAges) # Recalulate last one too with new AnnealedSigma
AnnealedSigma .= simannealsigma.(n, HeAge_sigma; params=simannealparams)
llₚ = normpdf_ll(HeAge, AnnealedSigma, CalcHeAgesₚ)
llₗ = normpdf_ll(HeAge, AnnealedSigma, CalcHeAges) # Recalulate last one too with new AnnealedSigma
if simplified # slightly penalize more complex t-T paths
llₚ -= log(nPointsₚ)
llₗ -= log(nPoints)
Expand All @@ -351,7 +356,7 @@
end

# Record results for analysis and troubleshooting
llDist[n] = normpdf_ll(HeAge_Ma, UnAnnealedSigma, CalcHeAges) # Recalculated to constant baseline
llDist[n] = normpdf_ll(HeAge, UnAnnealedSigma, CalcHeAges) # Recalculated to constant baseline
nDist[n] = nPoints # Distribution of # of points
HeAgeDist[:,n] = CalcHeAges # Distribution of He ages

Expand All @@ -366,11 +371,11 @@
@time (TStepDist, HeAgeDist, nDist, llDist, acceptanceDist) = MCMC_vartcryst(nPoints, maxPoints, agePoints, TPoints, unconf_agePoints, unconf_TPoints, boundary_agePoints, boundary_TPoints, simannealparams, diffusionparams)

# # Save results using JLD
@save string(name, ".jld") ageSteps tSteps TStepDist burnin nsteps TCryst tCryst
@save string(name, ".jld") ageSteps tSteps TStepDist burnin nsteps TStart tStart

# Plot sample age-eU correlations
h = scatter(eU,HeAgeDist[:,burnin:burnin+50], label="")
plot!(h, eU, HeAge_Ma, yerror=HeAge_Ma_sigma, seriestype=:scatter, label="Data")
plot!(h, eU, HeAge, yerror=HeAge_sigma, seriestype=:scatter, label="Data")
xlabel!(h,"eU (ppm)"); ylabel!(h,"Age (Ma)")
savefig(h,string(name,"Age-eU.pdf"))

Expand All @@ -381,14 +386,14 @@

# Resize the post-burnin part of the stationary distribution
TStepDistResized = Array{Float64}(undef, 2001, size(TStepDist,2)-burnin)
xq = collect(range(0,tCryst,length=2001))
xq = collect(range(0,tStart,length=2001))
for i=1:size(TStepDist,2)-burnin
TStepDistResized[:,i] = linterp1s(tSteps,TStepDist[:,i+burnin],xq)
end

# Calculate composite image
tTimage = zeros(ceil(Int, TCryst)*2, size(TStepDistResized,1))
yq = collect(0:0.5:TCryst)
tTimage = zeros(ceil(Int, TStart)*2, size(TStepDistResized,1))
yq = collect(0:0.5:TStart)
for i=1:size(TStepDistResized,1)
hist = fit(Histogram,TStepDistResized[i,:],yq,closed=:right)
tTimage[:,i] = hist.weights
Expand Down
1 change: 1 addition & 0 deletions src/Thermochron.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ module Thermochron
using LoopVectorization
using ProgressMeter: @showprogress

include("minerals.jl")
include("ZrnHe.jl")
include("inversion.jl")

Expand Down
Loading

0 comments on commit 98f1372

Please sign in to comment.