-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrde.jl
83 lines (69 loc) · 2.3 KB
/
rde.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
using Pkg
Pkg.activate(@__DIR__)
# Choose the experiment sub directory: the first command line argument is used
# when one is given, otherwise the user is asked for it on standard input.
subdir = if isempty(ARGS)
    print("Please enter sub directory to run RDE in: ")
    readline()
else
    first(ARGS)
end

# Keep asking until the name resolves to an existing directory located next to
# this script (the path is checked relative to @__DIR__).
while true
    global subdir  # the loop body rebinds the top-level variable
    isdir(joinpath(@__DIR__, subdir)) && break
    print("Invalid directory $subdir. Please enter sub directory to run RDE in: ")
    subdir = readline()
end
using PyCall

# Absolute path of the selected experiment directory. Everything below resolves
# `subdir` against the location of this script, never against the directory the
# script happened to be launched from.
experiment_dir = joinpath(@__DIR__, subdir)

# Put the experiment directory at the front of Python's module search path so
# that `pyimport("rde")` below can find the module stored there.
# The string form `."path"` returns the raw PyObject (no conversion to a Julia
# copy), so the push really mutates Python's `sys.path`; it replaces the
# deprecated `pyimport("sys")["path"]` attribute syntax.
pushfirst!(PyVector(pyimport("sys")."path"), experiment_dir)

import FrankWolfe

# NOTE(review): the file name looks like a misspelling of "custom_oracles.jl";
# it is kept as-is because it must match the file on disk — confirm.
# Presumably this file provides `NonNegKSparseLMO`, which the main loop uses.
include("custom_oralces.jl")

# load indices, rates, max_iter
# (the main loop of this script additionally reads `fw_arguments`, which is not
# defined anywhere in this file — presumably it comes from config.jl as well)
include(joinpath(experiment_dir, "config.jl"))

# Work inside the experiment directory. The path is anchored at @__DIR__ so the
# script also works when started from a different working directory, matching
# how the directory was validated and how config.jl is located.
cd(experiment_dir)

# Get the Python side of RDE
rde = pyimport("rde")
# Main experiment loop: for every sample index, run Frank-Wolfe once per rate
# and hand the results back to the Python `rde` module for storage.
# `indices`, `rates` and `fw_arguments` are not defined in this file; they are
# expected to be in scope already (presumably from the included config.jl).
for idx in indices
    # Load data sample and distortion functional
    x, fname = rde.get_data_sample(idx)
    f, df, node, pred = rde.get_distortion(x)

    # Wrap objective and gradient functions.
    # The iterate is converted to a plain `Vector{eltype(x)}` before it is
    # passed to the Python callables; `convert` returns the argument unchanged
    # when it already has that type, so no explicit type check is needed.
    func(s) = f(convert(Vector{eltype(x)}, s))

    # In-place gradient: writes `df(s)` into `storage` and returns `storage`.
    function grad!(storage, s)
        g = df(convert(Vector{eltype(x)}, s))
        storage .= g
        return storage
    end

    # One row per rate, one column per entry of the sample.
    all_s = zeros(eltype(x), (length(rates), length(x)))

    # `enumerate` supplies the row that belongs to each rate directly. Looking
    # the row up by value (`indexin`) rescans `rates` every iteration and maps
    # repeated rates onto the same row, leaving other rows at zero.
    for (row, rate) in enumerate(rates)
        # Run FrankWolfe
        println("Running sample $idx with rate $rate")

        # All-zero starting point with the element type and length of `x`.
        s0 = zeros(eltype(x), length(x))

        # `NonNegKSparseLMO` is not defined in this file (presumably it comes
        # from the included custom oracle file).
        lmo = NonNegKSparseLMO(rate, 1.0)

        # Solver variants the original author kept as alternatives for this call:
        #   FrankWolfe.away_frank_wolfe
        #   FrankWolfe.blended_conditional_gradient
        #   FrankWolfe.lazified_conditional_gradient
        @time s, v, primal, dual_gap = FrankWolfe.frank_wolfe(
            func,
            grad!,
            lmo,
            s0;
            fw_arguments...
        )

        # reset adaptive step size if necessary
        # (the same line-search object is passed to every solver call through
        # `fw_arguments`, so its `factor` field is set back to 0 between runs)
        if fw_arguments.line_search isa FrankWolfe.MonotonousNonConvexStepSize
            fw_arguments.line_search.factor = 0
        end

        # Store single rate result
        all_s[row, :] = s
        rde.store_single_result(s, idx, fname, rate)
    end

    # Store multiple rate results
    rde.store_collected_results(all_s, idx, node, pred, fname, rates)
end