Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
albop committed May 20, 2024
1 parent d7d1806 commit ff17529
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 37 deletions.
5 changes: 3 additions & 2 deletions misc/dev_float32.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
using Dolo

root_dir = pkgdir(Dolo)
model = include("$(root_dir)/examples/ymodels/rbc_mc.jl")
model32 = include("$(root_dir)/misc/rbc_float32.jl")

dm = Dolo.discretize(model, Dict(:endo=>[10000000]) )

dm32 = Dolo.discretize(model, Dict(:endo=>[10000000]) )

Dolo.convert
41 changes: 33 additions & 8 deletions misc/rbc_float32.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,22 +56,47 @@ model = let
end


dmodel = Dolo.discretize(model)
# now these are orphans
model32 = Dolo.convert_precision(Float32,model)


model32 = Dolo.convert_precision(Float32,model)
dmodel32 = Dolo.discretize(model32)
function Dolo.transition(model::typeof(model32), s::NamedTuple, x::NamedTuple, M::NamedTuple)

(;δ, ρ) = model.calibration

# Z = e.Z
K = s.k * (1-δ) + x.i

wksp = Dolo.time_iteration_workspace(dmodel32)
(;k=K,) ## This is only the endogenous state

itps = wksp.φ.itp
end

(;x0,r0, φ) = wksp

Dolo.F(dmodel32, x0, φ)

function intermediate(model::typeof(model32),s::NamedTuple, x::NamedTuple)

p = model.calibration

y = exp(s.z)*(s.k^p.α)*(x.n^(1-p.α))
w = (1-p.α)*y/x.n
rk = p.α*y/s.k
c = y - x.i
return ( (; y, c, rk, w))

end


function arbitrage(model::typeof(model32), s::NamedTuple, x::NamedTuple, S::NamedTuple, X::NamedTuple)

p = model.calibration

y = intermediate(model, s, x)
Y = intermediate(model, S, X)
res_1 = p.χ*(x.n^p.η)*(y.c^p.σ) - y.w
res_2 = (p.β*(y.c/Y.c)^p.σ)*(1 - p.δ + Y.rk) - 1

return ( (;res_1, res_2) )

end

Dolo.time_iteration(model)
model32
10 changes: 8 additions & 2 deletions src/algos/time_iteration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ function time_iteration(model::DYModel,
lam = 0.5

local η_0 = NaN
local η
convergence = false
iterations = T

Expand Down Expand Up @@ -246,6 +247,7 @@ function time_iteration(model::DYModel,

ε_n = norm(r0)
if ε_n<tol_ε
iterations = t
break
end

Expand All @@ -261,6 +263,7 @@ function time_iteration(model::DYModel,

ε_b = norm(r0)
if ε_b<ε_n
iterations = t
break
end
end
Expand All @@ -280,6 +283,10 @@ function time_iteration(model::DYModel,

verbose ? append!(log; verbose=verbose, it=t-1, err=ε, sa=η_0, lam=gain, elapsed=elapsed) : nothing

if η < tol_η
iterations = t
break
end
η_0 = η


Expand Down Expand Up @@ -314,7 +321,7 @@ function time_iteration(model::DYModel,
φ,
iterations,
tol_η,
η_0,
η,
log,
ti_trace
)
Expand All @@ -329,7 +336,6 @@ function newton(model, workspace=newton_workspace(model);

(;x0, x1, x2, dx, r0, J, φ, T, memn) = workspace


for t=1:K

Dolo.fit!(φ, x0)
Expand Down
6 changes: 3 additions & 3 deletions src/funs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,9 @@ end

# Compatibility calls

(f::DFun)(x::Float64) = f(SVector(x))
(f::DFun)(x::Float64, y::Float64) = f(SVector(x,y))
(f::DFun)(x::Vector{SVector{d,Float64}}) where d = [f(e) for e in x]
(f::DFun)(x::Real) = f(SVector(x))
(f::DFun)(x::Real, y::Real) = f(SVector(x,y))
(f::DFun)(x::Vector{SVector{d,<:Real}}) where d = [f(e) for e in x]


ndims(df::DFun) = ndims(df.domain)
Expand Down
18 changes: 10 additions & 8 deletions src/garray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ struct GArray{G,U}
end

const GVector{G,T} = GArray{G,T}
const GDist{G} = GArray{G, Vector{Float64}}
const GDist{G,Tf} = GArray{G, Vector{Tf}}

GDist(g::G, v::Vector{Float64}) where G= GArray{G,Vector{Float64}}(g, v)
GDist(g::G, v::Vector{Tf}) where G where Tf = GArray{G,Vector{Tf}}(g, v)


norm(v::GArray) = maximum(u->maximum(abs, u), v.data)
Expand Down Expand Up @@ -45,7 +45,10 @@ end
eltype(g::GArray{G,T}) where G where T = eltype(T)

# warning: these functions don't copy any data
ravel(g::GArray) = reinterpret(Float64, g.data)
function ravel(g::GArray)
Tf = eltype(g.data[1])
reinterpret(Tf, g.data)
end
unravel(g::GArray, x) = GArray(
g.grid,
reinterpret(eltype(g), x)
Expand Down Expand Up @@ -113,22 +116,21 @@ import Base: *, \, +, -, /
*(x::Number, A::GArray{G,T}) where G where T = GArray(A.grid, x .* A.data)


*(A::GArray{G,Vector{T}}, x::SVector{q, Float64}) where G where T <:SMatrix{p, q, Float64, n} where p where q where n = GArray(A.grid, [M*x for M in A.data])
# *(A::GArray{G,Vector{T}}, x::SLArray{Tuple{q}, Float64, 1, q, names}) where G where T <:SMatrix{p, q, Float64, n} where p where q where n where names = A*SVector(x...)
*(A::GArray{G,Vector{T}}, x::SVector{q, Tf}) where G where T <:SMatrix{p, q, Tf, n} where p where q where n where Tf = GArray(A.grid, [M*x for M in A.data])


*(A::GArray{G,T}, B::AbstractArray{Float64}) where G where T <:SMatrix{p, q, Float64, n} where p where q where n =
*(A::GArray{G,T}, B::AbstractArray{Tf}) where G where T <:SMatrix{p, q, Tf, n} where p where q where n where Tf =
ravel(
GArray(
A.grid,
A.data .* reinterpret(SVector{q, Float64}, B)
A.data .* reinterpret(SVector{q, Tf}, B)
)
)


import Base: convert

function Base.convert(::Type{Matrix}, A::GArray{G,Vector{T}}) where G where T <:SMatrix{p, q, Float64, k} where p where q where k
function Base.convert(::Type{Matrix}, A::GArray{G,Vector{T}}) where G where T <:SMatrix{p, q, Tf, k} where p where q where k where Tf
N = length(A.data)
n0 = N*p
n1 = N*q
Expand Down
15 changes: 10 additions & 5 deletions src/grids.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ abstract type ASGrid{d} <: AGrid{d} end

import Base: eltype, iterate, size

eltype(cg::AGrid{d}) where d = SVector{d, Float64}
# eltype(cg::AGrid{d}) where d = SVector{d, Float64}
ndims(cg::AGrid{d}) where d = d

struct CGrid{d,Tf} <: AGrid{d}
Expand Down Expand Up @@ -90,7 +90,8 @@ from_linear(g::PGrid{G1, G2, d}, n) where G1 where G2 where d = let x=divrem(n-1
getindex(g::PGrid{G1, G2, d}, n::Int) where G1 where G2 where d = getindex(g, from_linear(g, n)...)

function getindex(g::PGrid{G1, G2, d}, i::Int64, j::Int64) where G1<:SGrid{d1} where G2<:CGrid{d2} where d where d1 where d2
SVector{d,Float64}(g.grids[1][i]..., g.grids[2][j]...)
Tf = eltype(g)
SVector{d,Tf}(g.grids[1][i]..., g.grids[2][j]...)
end


Expand Down Expand Up @@ -177,17 +178,21 @@ end


function Base.iterate(g::PGrid{G1, G2, d}) where G1 where G2 where d
T = eltype(g)
x = g.grids[1][1]
y = g.grids[2][1]
return (SVector{d, Float64}(x...,y...),(y,1,1))
return (SVector{d, T}(x...,y...),(y,1,1))
end

function Base.iterate(g::PGrid{G1,G2,d},state) where G1 where G2 where d

T = eltype(g)

y,i,j=state
if i<length(g.grids[1])
i += 1
x = g.grids[1][i]
return (SVector{d,Float64}(x..., y...), (y,i,j))
return (SVector{d,T}(x..., y...), (y,i,j))
else
if j==length(g.grids[2])
return nothing
Expand All @@ -196,7 +201,7 @@ function Base.iterate(g::PGrid{G1,G2,d},state) where G1 where G2 where d
i = 1
x = g.grids[1][i]
y = g.grids[2][j]
return (SVector{d,Float64}(x..., y...), (y,i,j))
return (SVector{d,T}(x..., y...), (y,i,j))
end
end
end
Expand Down
6 changes: 3 additions & 3 deletions src/splines/interp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using StaticArrays
import LoopVectorization: VectorizationBase
import Base: getindex

getindex(A::Vector{Float64}, i::VectorizationBase.Vec{4,Int64}) = VectorizationBase.Vec{4, Float64}(A[i(1)], A[i(2)], A[i(3)], A[i(4)])
getindex(A::Vector{Tf}, i::VectorizationBase.Vec{4,Int64}) where Tf = VectorizationBase.Vec{4, Tf}(A[i(1)], A[i(2)], A[i(3)], A[i(4)])


# ## TODO : rewrite the following
Expand Down Expand Up @@ -44,7 +44,7 @@ matextract(v::AbstractArray{T,1}, i) where T = SArray{Tuple{2}, T, 1, 2}(
v[i+1]
)

function interp(ranges::NTuple{d, Tuple{Float64, Float64, Int64}}, values::AbstractArray{T,d}, x::SVector{d, U}) where d where T where U
function interp(ranges::NTuple{d, Tuple{Tf, Tf, Int64}}, values::AbstractArray{T,d}, x::SVector{d, U}) where d where T where U where Tf

a = SVector( (e[1] for e in ranges)... )
b = SVector( (e[2] for e in ranges)... )
Expand All @@ -67,7 +67,7 @@ function interp(ranges::NTuple{d, Tuple{Float64, Float64, Int64}}, values::Abstr

end

function interp(ranges::NTuple{d, Tuple{Float64, Float64, Int64}}, values::AbstractArray{T,d}, x::Vararg{U}) where d where T where U
function interp(ranges::NTuple{d, Tuple{Tf, Tf, Int64}}, values::AbstractArray{T,d}, x::Vararg{U}) where d where T where U where Tf
xx = SVector(x...)
interp(ranges, values, xx)
end
12 changes: 6 additions & 6 deletions src/splines/splines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ function interpolant_cspline(a, b, orders, V)

coefs = filter_coeffs(a, b, orders, V)

function fun(s::Array{Float64,2})
function fun(s::Array{Tf,2}) where Tf
return eval_UC_spline(a, b, orders, coefs, s)
end

function fun(p::Float64...)
function fun(p::Tf...) where Tf
return fun([p...]')
end

Expand All @@ -45,7 +45,7 @@ end



function prefilter(ranges::NTuple{d,Tuple{Float64,Float64,i}}, V::AbstractArray{T, d}, ::Val{3}) where d where i<:Int where T
function prefilter(ranges::NTuple{d,Tuple{Tf,Tf,i}}, V::AbstractArray{T, d}, ::Val{3}) where d where i<:Int where T where Tf
θ = zeros(eltype(V), ((e[3]+2) for e in ranges)...)
ind = tuple( (2:(e[3]+1) for e in ranges )...)
θ[ind...] = V
Expand All @@ -54,17 +54,17 @@ end
end


function prefilter!::AbstractArray{T, d}, grid::NTuple{d,Tuple{Float64,Float64,i}}, V::AbstractArray{T, d}, ::Val{3}) where d where i<:Int where T
function prefilter!::AbstractArray{T, d}, grid::NTuple{d,Tuple{Tf,Tf,i}}, V::AbstractArray{T, d}, ::Val{3}) where d where i<:Int where T where Tf
splines.prefilter!(θ)
end

function prefilter(ranges::NTuple{d,Tuple{Float64,Float64,i}}, V::AbstractArray{T, d}, ::Val{1}) where d where i<:Int where T
function prefilter(ranges::NTuple{d,Tuple{Tf,Tf,i}}, V::AbstractArray{T, d}, ::Val{1}) where d where i<:Int where T where Tf
θ = copy(V)
return θ
end


function prefilter!::AbstractArray{T, d}, grid::NTuple{d,Tuple{Float64,Float64,i}}, V::AbstractArray{T, d}, ::Val{1}) where d where i<:Int where T
function prefilter!::AbstractArray{T, d}, grid::NTuple{d,Tuple{Tf,Tf,i}}, V::AbstractArray{T, d}, ::Val{1}) where d where i<:Int where T where Tf
θ .= V
end

Expand Down

0 comments on commit ff17529

Please sign in to comment.