Skip to content

Commit

Permalink
improve performance and accuracy of find_optimal_shrinkage
Browse files Browse the repository at this point in the history
  • Loading branch information
biona001 committed Jan 16, 2024
1 parent 0ffcd7d commit b2ba2ee
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 32 deletions.
7 changes: 6 additions & 1 deletion src/ghostbasil_parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,17 @@ function ghostbasil_parallel(
Si = result["D"][LD_keep_idx, LD_keep_idx]
Σi = result["Sigma"][LD_keep_idx, LD_keep_idx]
zscore_tmp = @view(zscores[GWAS_keep_idx])

# shrinkage for consistency (only use reps for better speed)
t21 += @elapsed if LD_shrinkage
γ = find_optimal_shrinkage(Σi, zscore_tmp)
Σi_reps = result["Sigma_reps"]
zscore_tmp_reps = zscore_tmp[result["group_reps"]]
γ = find_optimal_shrinkage(Σi_reps, zscore_tmp_reps)
γ_mean += γ
Σi = (1 - γ)*Σi + γ*I
Si = (1 - γ)*Si + (m+1)/m*γ*I
end

# sample ghost knockoffs
Random.seed!(seed)
t22 += @elapsed Zko_train = sample_mvn_efficient(Σi, Si, m + 1)
Expand Down
20 changes: 8 additions & 12 deletions src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,23 +45,19 @@ function find_matching_indices(a::AbstractVector, b::AbstractVector)
end

# eq 24 of https://journals.plos.org/plosgenetics/article?id=10.1371/journal.pgen.1010299
"""
    neg_mvn_logl_under_null(evals, evecs, z::AbstractVector, γ::Number)

Return `sum(log.(λ)) + w' * Diagonal(1 ./ λ) * w`, where `w = evecs' * z` and
`λ = (1 - γ) .* evals .+ γ`.

When `evals` and `evecs` are the eigenvalues and eigenvectors of a symmetric
matrix `Σ`, `λ` holds the eigenvalues of `(1 - γ) * Σ + γ * I`, so the result is
`logdet` of that shrunken matrix plus the quadratic form of `z` in its inverse.
"""
function neg_mvn_logl_under_null(evals, evecs, z::AbstractVector, γ::Number)
    # eigenvalues after shrinking toward the identity
    shrunk = @. (1 - γ) * evals + γ
    # coordinates of z in the eigenvector basis
    rotated = evecs' * z
    logdet_term = sum(log, shrunk)
    quad_term = dot(rotated, Diagonal(1 ./ shrunk), rotated)
    return logdet_term + quad_term
end
"""
    find_optimal_shrinkage(Σ::AbstractMatrix, z::AbstractVector)

Return the shrinkage level `γ` in `[0, 1]` that minimizes
`neg_mvn_logl_under_null(evals, evecs, z, γ)`, where `evals, evecs` is the
eigendecomposition of `Symmetric(Σ)`. The one-dimensional search uses Brent's
method.
"""
function find_optimal_shrinkage(Σ::AbstractMatrix, z::AbstractVector)
    # Factorize once, outside the objective, so that evaluating a candidate γ
    # never has to refactorize a matrix.
    evals, evecs = eigen(Symmetric(Σ))
    opt = optimize(
        γ -> neg_mvn_logl_under_null(evals, evecs, z, γ),
        0, 1.0, Brent()
    )
    return opt.minimizer
end
"""
    neg_mvn_logl_under_null(Σ::AbstractMatrix, z::AbstractVector, γ::Number)

Evaluate the two-argument `neg_mvn_logl_under_null` on the shrunken matrix
`(1 - γ) * Σ + γ * I`.
"""
function neg_mvn_logl_under_null(Σ::AbstractMatrix, z::AbstractVector, γ::Number)
    return neg_mvn_logl_under_null((1 - γ) * Σ + γ * I, z)
end
"""
    neg_mvn_logl_under_null(Σ::AbstractMatrix, z::AbstractVector)

Return `0.5 * (logdet(Σ) + z' * inv(Σ) * z)`, i.e. the negative log-density of
`z` under a zero-mean multivariate normal with covariance `Σ`, without the
additive constant that does not depend on `Σ`.

`Σ` is factorized with `cholesky(Symmetric(Σ))`, so it must be positive
definite; otherwise `cholesky` throws.
"""
function neg_mvn_logl_under_null(Σ::AbstractMatrix, z::AbstractVector)
    L = cholesky(Symmetric(Σ))
    u = zeros(length(z))
    # Solve with the transposed upper factor held in L.factors, so that
    # dot(u, u) equals z' * inv(Σ) * z.
    ldiv!(u, UpperTriangular(L.factors)', z) # non-allocating ldiv!(u, L.L, z)
    # The log-determinant and the quadratic form must carry the same 1/2
    # factor; weighting them differently shifts the minimizer over Σ.
    return 0.5 * (logdet(L) + dot(u, u))
end

# counts number of Z scores that can be matched to LD panel
# ~400 seconds is running on typed SNPs only
Expand Down
33 changes: 14 additions & 19 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,26 +23,21 @@ using DataFrames
end

@testset "LD shrinkage" begin
# Reference implementation used to cross-check the Cholesky-based version:
# computes -0.5 * logdet(Σ) - z' * inv(Σ) * z by forming the explicit inverse.
function mvn_logl_under_null_naive(Σ::AbstractMatrix, z::AbstractVector)
    return -0.5logdet(Symmetric(Σ)) - dot(z, inv(Symmetric(Σ)), z)
end
# Cholesky-based evaluation of -0.5 * logdet(L) - u' * u, where u is obtained by
# solving against the transposed upper factor stored in L.factors.
# `u` is a work buffer that is overwritten with the solve result.
function mvn_logl_under_null(L::Cholesky, z::AbstractVector, u=zeros(length(z)))
    lower_factor = UpperTriangular(L.factors)'
    ldiv!(u, lower_factor, z)
    quad_term = dot(u, u)
    return -0.5logdet(L) - quad_term
end

# The naive (explicit inverse) and Cholesky-based evaluations must agree
# on a random symmetric matrix built as x' * x.
p = 1000
z = randn(p)
x = randn(p, p)
Sigma = Symmetric(x' * x)
L = cholesky(Sigma)
a = mvn_logl_under_null_naive(Sigma, z)
b = mvn_logl_under_null(L, z)
@test a ≈ b

γ = GhostKnockoffGWAS.find_optimal_shrinkage(Sigma, z)
Σ = [1.0 -0.06603158126487506 -0.0833880805108651 -0.07423250848805772 -0.08520692260952517 -0.3018926160412726 -0.07228886018105062 -0.05163600016697706 -0.06831703956672884 0.1119291270044286;
-0.06603158126487506 1.0 -0.025855482863549748 -0.023593031867989712 -0.024785186697126534 0.17970487929732645 -0.0217395069925557 -0.015763968744450328 -0.0248758148261822 -0.01685099072623645;
-0.0833880805108651 -0.025855482863549748 1.0 -0.02953703088587616 -0.036090226242879775 -0.13347090146200266 -0.027912450562371478 -0.018630494841470058 -0.02981323926900486 -0.02095834431260501;
-0.07423250848805772 -0.023593031867989712 -0.02953703088587616 1.0 -0.03235265862653292 -0.12511846555868517 -0.026570109459531203 -0.0178496231359085 -0.027896176152173473 -0.01715817220843155;
-0.08520692260952517 -0.024785186697126534 -0.036090226242879775 -0.03235265862653292 1.0 0.2459334205306294 -0.030729029543925132 -0.023387666784556588 -0.03616407173066363 -0.02240494855276279;
-0.3018926160412726 0.17970487929732645 -0.13347090146200266 -0.12511846555868517 0.2459334205306294 1.0 0.19677801401533077 -0.04917243296590054 0.21936244494341658 -0.0015389401383875156;
-0.07228886018105062 -0.0217395069925557 -0.027912450562371478 -0.026570109459531203 -0.030729029543925132 0.19677801401533077 1.0 -0.020079411765085316 -0.027283609452267554 -0.019604609586933476;
-0.05163600016697706 -0.015763968744450328 -0.018630494841470058 -0.0178496231359085 -0.023387666784556588 -0.04917243296590054 -0.020079411765085316 1.0 -0.020149157087303162 -0.013591443637752606;
-0.06831703956672884 -0.0248758148261822 -0.02981323926900486 -0.027896176152173473 -0.03616407173066363 0.21936244494341658 -0.027283609452267554 -0.020149157087303162 1.0 -0.015041050757637644;
0.1119291270044286 -0.01685099072623645 -0.02095834431260501 -0.01715817220843155 -0.02240494855276279 -0.0015389401383875156 -0.019604609586933476 -0.013591443637752606 -0.015041050757637644 1.0]
z = [-0.0830535851088237, -2.25068631502435, 1.49776720532444, 1.65978366398296, 1.06042628327983, -1.41526767195536, -0.522515227438188, -0.140465835528732, 0.484580620845268, -0.981466593482507]
γ = GhostKnockoffGWAS.find_optimal_shrinkage(Σ, z)
γtrue = 6.149509305168035e-15
# the minimizer must be a valid shrinkage level
@test 0 ≤ γ ≤ 1
# γtrue is ~1e-15, so the default relative tolerance of ≈ would demand
# near bit-exact agreement; compare with an absolute tolerance instead.
@test γ ≈ γtrue atol=1e-8
end

@testset "ghostbasil C++ solver" begin
Expand Down

0 comments on commit b2ba2ee

Please sign in to comment.