Skip to content

Commit

Permalink
Merge Natural-Splines
Browse files Browse the repository at this point in the history
add simple natural spline fitting
  • Loading branch information
francescoalemanno authored Sep 27, 2022
2 parents c9c511b + 50d85ba commit eb43244
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "KissSmoothing"
uuid = "23b0397c-cd08-4270-956a-157331f0528f"
authors = ["Francesco Alemanno <[email protected]>"]
version = "1.0.3"
version = "1.0.4"

[deps]
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand Down
65 changes: 64 additions & 1 deletion src/KissSmoothing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,69 @@ function fit_rbf(
)
RBF(evalPhi(xv, cp) \ yv, collect(cp))
end




function basis_d(x, n1, nK)
x1p = max(x - n1, zero(x))
x2p = max(x - nK, zero(x))
return ((x1p)^3 - (x2p)^3) / (nK - n1)
end

function basis_N(x, xi, k::Int)
K = length(xi)
if k<1 || k>K
error("order must be between 1 and K = length(xi)")
end
if k == K - 1
return one(x)
end
sx = (x - xi[1]) / (xi[end] - xi[1])
if k == K
return sx
end
nxi_k = (xi[k] - xi[1]) / (xi[end] - xi[1])
nxi_em1 = (xi[end-1] - xi[1]) / (xi[end] - xi[1])
return basis_d(sx, nxi_k, 1) - basis_d(sx, nxi_em1, xi[end])
end

"""
fit_nspline(xv::Vector, yv::Vector, cp::Vector)
fit natural cubic splines basis function according to:
`xv` : array N, N number of training points
`yv` : array N, N number of training points
`cp` : array K, K number of control points
returns a callable function.
"""
function fit_nspline(
x::AbstractVector{Float64},
y::AbstractVector{Float64},
xi::AbstractVector{Float64},
)
issorted(xi) || error("Knots \"xi\" must be sorted.")
N = length(x)
K = length(xi)
M = zeros(N, K)
scal = 1 / sqrt(N)
for i = 1:N, j = 1:K
M[i, j] = basis_N(x[i], xi, j) / scal
end
C = M \ (y ./ scal)
function fn(x)
s = zero(x)
for i in eachindex(C)
s += basis_N(x, xi, i) * C[i]
end
return s
end
end


export denoise, fit_rbf, RBF
export denoise, fit_rbf, RBF, fit_nspline
end # module
12 changes: 12 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,15 @@ end
@test error < 0.0002
end
end

@testset "Fit NSpline" begin
for μ in LinRange(-100,100,5)
t = LinRange(0,2pi,150)
y = sin.(t).+μ
fn = fit_nspline(t,y,LinRange(0,2pi,50))
pred_y = fn.(t)
error = sqrt(sum(abs2, pred_y .- y)/length(t))
@test error < 0.0002
end
@test_throws ErrorException KissSmoothing.basis_N(Float64[],Float64[],1)
end

0 comments on commit eb43244

Please sign in to comment.