Skip to content
This repository has been archived by the owner on Jun 14, 2023. It is now read-only.

Commit

Permalink
Apply JuliaFormatter in preparation of #26
Browse files Browse the repository at this point in the history
  • Loading branch information
maxmouchet committed Aug 28, 2020
1 parent 9f505cf commit c3417a1
Show file tree
Hide file tree
Showing 10 changed files with 76 additions and 32 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pkg> test HMMBase
5. Format the code using [JuliaFormatter](https://github.com/domluna/JuliaFormatter.jl):
```julia
julia> using JuliaFormatter
julia> format(".", margin=100)
julia> format(".")
```

6. Commit your changes and submit the PR!
Expand Down
2 changes: 0 additions & 2 deletions benchmark/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,3 @@ function mkbench!(f, base, path)
grp = mkgrp!(base, path[1:end-1])
grp[path[end]] = f()
end


5 changes: 3 additions & 2 deletions docs/literate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ using Literate
using Random

config = Dict(
:binder_root_url => "https://mybinder.org/v2/gh/maxmouchet/HMMBase.jl/master?filepath=",
:repo_root_url => "https://github.com/maxmouchet/HMMBase.jl/blob/master"
:binder_root_url =>
"https://mybinder.org/v2/gh/maxmouchet/HMMBase.jl/master?filepath=",
:repo_root_url => "https://github.com/maxmouchet/HMMBase.jl/blob/master",
)

function literate_documenter(inputdir, outputdir)
Expand Down
4 changes: 1 addition & 3 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@ end
makedocs(
sitename = "HMMBase",
modules = [HMMBase],
format = Documenter.HTML(
assets = ["assets/goatcounter.js"]
),
format = Documenter.HTML(assets = ["assets/goatcounter.js"]),
pages = [
"index.md",
"Manual" => ["basics.md", "models.md", "algorithms.md", "utilities.md"],
Expand Down
18 changes: 13 additions & 5 deletions src/hmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,17 @@ struct HMM{F,T} <: AbstractHMM{F}
HMM{F,T}(a, A, B) where {F,T} = assert_hmm(a, A, B) && new(a, A, B)
end

HMM(a::AbstractVector{T}, A::AbstractMatrix{T}, B::AbstractVector{<:Distribution{F}}) where {F,T} =
HMM{F,T}(a, A, B)
HMM(
a::AbstractVector{T},
A::AbstractMatrix{T},
B::AbstractVector{<:Distribution{F}},
) where {F,T} = HMM{F,T}(a, A, B)

HMM(A::AbstractMatrix{T}, B::AbstractVector{<:Distribution{F}}) where {F,T} =
HMM{F,T}(ones(size(A)[1]) / size(A)[1], A, B)

function HMM(a::AbstractVector{T}, A::AbstractMatrix{T}, B::AbstractMatrix) where {T}
B = map(i -> Categorical(B[i,:]), 1:size(B,1))
B = map(i -> Categorical(B[i, :]), 1:size(B, 1))
HMM{Univariate,T}(a, A, B)
end

Expand Down Expand Up @@ -87,7 +90,8 @@ issquare(A::AbstractMatrix) = size(A, 1) == size(A, 2)
Return true if `A` is square and its rows sums to 1.
"""
istransmat(A::AbstractMatrix) = issquare(A) && all([isprobvec(A[i, :]) for i = 1:size(A, 1)])
istransmat(A::AbstractMatrix) =
issquare(A) && all([isprobvec(A[i, :]) for i = 1:size(A, 1)])

==(h1::AbstractHMM, h2::AbstractHMM) = (h1.a == h2.a) && (h1.A == h2.A) && (h1.B == h2.B)

Expand Down Expand Up @@ -162,7 +166,11 @@ function rand(rng::AbstractRNG, hmm::AbstractHMM{Univariate}, z::AbstractVector{
y
end

function rand(rng::AbstractRNG, hmm::AbstractHMM{Multivariate}, z::AbstractVector{<:Integer})
function rand(
rng::AbstractRNG,
hmm::AbstractHMM{Multivariate},
z::AbstractVector{<:Integer},
)
y = Matrix{Float64}(undef, length(z), size(hmm, 2))
for t in eachindex(z)
y[t, :] = rand(rng, hmm.B[z[t]])
Expand Down
2 changes: 1 addition & 1 deletion src/messages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ function backward(a::AbstractVector, A::AbstractMatrix, LL::AbstractMatrix; logl
backwardlog!(m, c, a, A, LL)
m, sum(c)
end

"""
forward(hmm, observations; robust) -> (Vector, Float)
Expand Down
12 changes: 8 additions & 4 deletions test/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,22 @@ using Test
Random.seed!(2019)

@testset "< v1.1" begin
hmm = HMM([0.9 0.1; 0.1 0.9], [Normal(0,1), Normal(10,1)])
hmm = HMM([0.9 0.1; 0.1 0.9], [Normal(0, 1), Normal(10, 1)])
z, y = rand(hmm, 2500, seq = true)
@test forward(hmm, y, logl = true) == forward(hmm, y, logl = false) == forward(hmm, y)
@test backward(hmm, y, logl = true) == backward(hmm, y, logl = false) == backward(hmm, y)
@test posteriors(hmm, y, logl = true) == posteriors(hmm, y, logl = false) == posteriors(hmm, y)
@test backward(hmm, y, logl = true) ==
backward(hmm, y, logl = false) ==
backward(hmm, y)
@test posteriors(hmm, y, logl = true) ==
posteriors(hmm, y, logl = false) ==
posteriors(hmm, y)
@test viterbi(hmm, y, logl = false) == viterbi(hmm, y, logl = true) == z
@test likelihoods(hmm, y) == exp.(loglikelihoods(hmm, y))
@test likelihoods(hmm, y, logl = true) == loglikelihoods(hmm, y)
end

@testset "< v1.0" begin
hmm = HMM([0.9 0.1; 0.1 0.9], [Normal(0,1), Normal(10,1)])
hmm = HMM([0.9 0.1; 0.1 0.9], [Normal(0, 1), Normal(10, 1)])
z, y = rand(hmm, 2500, seq = true)
@test n_parameters(hmm) == nparams(hmm)
@test log_likelihoods(hmm, y) == loglikelihoods(hmm, y)
Expand Down
5 changes: 4 additions & 1 deletion test/integration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ Random.seed!(2019)
hmms = [
HMM([0.9 0.1; 0.1 0.9], [Normal(10, 1), Gamma(1, 1)]),
HMM([0.9 0.1; 0.1 0.9], [Categorical([0.1, 0.2, 0.7]), Categorical([0.5, 0.5])]),
HMM([0.9 0.1; 0.1 0.9], [MvNormal([0.0, 0.0], [1.0, 1.0]), MvNormal([10.0, 10.0], [1.0, 1.0])]),
HMM(
[0.9 0.1; 0.1 0.9],
[MvNormal([0.0, 0.0], [1.0, 1.0]), MvNormal([10.0, 10.0], [1.0, 1.0])],
),
]

@testset "Integration $(typeof(hmm))" for hmm in hmms, T in [0, 1, 1000]
Expand Down
14 changes: 10 additions & 4 deletions test/pyhsmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,21 @@ end
y = rand(hmm, 2500)
LL = likelihoods(hmm, y, logl = true)

ref =
pyhsmm.internals.hmm_states.HMMStatesPython._messages_forwards_normalized(hmm.A, hmm.a, LL)
ref = pyhsmm.internals.hmm_states.HMMStatesPython._messages_forwards_normalized(
hmm.A,
hmm.a,
LL,
)
res = forward(hmm.a, hmm.A, LL, logl = true)

@test all(res[1] .≈ ref[1])
@test res[2] ref[2]

ref =
pyhsmm.internals.hmm_states.HMMStatesPython._messages_backwards_normalized(hmm.A, hmm.a, LL)
ref = pyhsmm.internals.hmm_states.HMMStatesPython._messages_backwards_normalized(
hmm.A,
hmm.a,
LL,
)
res = backward(hmm.a, hmm.A, LL, logl = true)

@test all(res[1] .≈ ref[1])
Expand Down
44 changes: 35 additions & 9 deletions test/unit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,11 @@ end

@testset "Base (5)" begin
# Emission matrix constructor
hmm1 = HMM([0.9 0.1; 0.1 0.9], [0. 0.5 0.5; 0.25 0.25 0.5])
hmm2 = HMM([0.9 0.1; 0.1 0.9], [Categorical([0., 0.5, 0.5]), Categorical([0.25, 0.25, 0.5])])
hmm1 = HMM([0.9 0.1; 0.1 0.9], [0.0 0.5 0.5; 0.25 0.25 0.5])
hmm2 = HMM(
[0.9 0.1; 0.1 0.9],
[Categorical([0.0, 0.5, 0.5]), Categorical([0.25, 0.25, 0.5])],
)
@test hmm1 == hmm2
end

Expand All @@ -103,13 +106,26 @@ end
# Wrong trans. matrix
@test_throws ArgumentError HMM(ones(2, 2), [Normal(); Normal()])
# Wrong trans. matrix dimensions
@test_throws ArgumentError HMM([0.8 0.1 0.1; 0.1 0.1 0.8], [Normal(0, 1), Normal(10, 1)])
@test_throws ArgumentError HMM(
[0.8 0.1 0.1; 0.1 0.1 0.8],
[Normal(0, 1), Normal(10, 1)],
)
# Wrong number of distributions
@test_throws ArgumentError HMM([0.8 0.2; 0.1 0.9], [Normal(0, 1), Normal(10, 1), Normal()])
@test_throws ArgumentError HMM(
[0.8 0.2; 0.1 0.9],
[Normal(0, 1), Normal(10, 1), Normal()],
)
# Wrong distributions size
@test_throws ArgumentError HMM([0.8 0.2; 0.1 0.9], [MvNormal(randn(3)), MvNormal(randn(10))])
@test_throws ArgumentError HMM(
[0.8 0.2; 0.1 0.9],
[MvNormal(randn(3)), MvNormal(randn(10))],
)
# Wrong initial state
@test_throws ArgumentError HMM([0.1; 0.1], [0.9 0.1; 0.1 0.9], [Normal(0, 1), Normal(10, 1)])
@test_throws ArgumentError HMM(
[0.1; 0.1],
[0.9 0.1; 0.1 0.9],
[Normal(0, 1), Normal(10, 1)],
)
# Wrong initial state length
@test_throws ArgumentError HMM(
[0.1; 0.1; 0.8],
Expand All @@ -134,7 +150,7 @@ end
@test permutedims(dists2[1]) [1.0 0.0] * (hmm2.A^1000)
@test permutedims(dists2[2]) [0.0 1.0] * (hmm2.A^1000)

@test dists3[1] [15/53, 25/53, 13/53]
@test dists3[1] [15 / 53, 25 / 53, 13 / 53]
end

@testset "Messages (1)" begin
Expand Down Expand Up @@ -288,9 +304,19 @@ end
d = JSON.parse(json(hmm2))
@test_broken from_dict(HMM{Multivariate,Float64}, MvNormal, d) == hmm2

hmm3 = HMM([0.9 0.1; 0.1 0.9], [MixtureModel([Normal(0, 1)]), MixtureModel([Normal(5, 2), Normal(10, 1)], [0.25, 0.75])])
hmm3 = HMM(
[0.9 0.1; 0.1 0.9],
[
MixtureModel([Normal(0, 1)]),
MixtureModel([Normal(5, 2), Normal(10, 1)], [0.25, 0.75]),
],
)
d = JSON.parse(json(hmm3))
@test_broken from_dict(HMM{Univariate,Float64}, MixtureModel{Univariate,Continuous,Normal,Float64}, d) == hmm3
@test_broken from_dict(
HMM{Univariate,Float64},
MixtureModel{Univariate,Continuous,Normal,Float64},
d,
) == hmm3

# MixtureModel <-> HMM (stationnary distribution)
a = [0.4, 0.6]
Expand Down

0 comments on commit c3417a1

Please sign in to comment.