Skip to content

Commit

Permalink
Added silhouette plot
Browse files Browse the repository at this point in the history
  • Loading branch information
tobydriscoll committed Nov 18, 2019
1 parent cbcc6a4 commit f430829
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 0 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.12.0"
Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DataValues = "e7dc6d0d-1eca-5fa6-8ad6-5aecde8b7ea5"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b"
Expand Down
3 changes: 3 additions & 0 deletions src/StatsPlots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ using Widgets, Observables
import Observables: AbstractObservable, @map, observe
import Widgets: @nodeps
import DataStructures: OrderedDict
using Distances: pairwise, Euclidean
import Clustering: Hclust, nnodes
using Clustering: ClusteringResult, silhouettes, assignments, counts
using Interpolations
import MultivariateStats: MDS, eigvals, projection, principalvars,
principalratio, transform
Expand All @@ -40,5 +42,6 @@ include("bar.jl")
include("dendrogram.jl")
include("andrews.jl")
include("ordinations.jl")
include("silhouetteplot.jl")

end # module
83 changes: 83 additions & 0 deletions src/silhouetteplot.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""
    silhouetteplot(C, X[, D]; ...)

Make a silhouette plot to assess the quality of a clustering. `C` must be a
`ClusteringResult` (see the `Clustering` package), and `X` is a matrix in which each
column represents a data point. If supplied, `D` should be a distance matrix (as in
`Distances`); otherwise, pairwise Euclidean distances between the columns of `X` are used.

Each data point has a silhouette score between -1 and 1 indicating how unambiguously the
point belongs to its assigned cluster. These are sorted within each cluster and portrayed
using horizontal bars. Also shown is a dashed line at the average score. Typically a
high-quality clustering has significant numbers of bars within each cluster that cross the
line, and few negative scores overall.

See also: [`Clustering.silhouettes`](@ref), [`Distances`](@ref).

# Examples
```
using Clustering, Distances, StatsPlots
# random dataset with 3-ish clusters in 5 dimensions
X = hcat([rand(5,1) .+ 0.2*randn(5, 200) for _=1:3]...)
D = pairwise(Euclidean(),X,dims=2)
# cluster the data points (columns of X), not the distance matrix
R = kmeans(X, 3; maxiter=200, display=:iter)
silhouetteplot(R,X,D)
```
"""
silhouetteplot

@userplot SilhouettePlot

# Plot recipe behind `silhouetteplot(R, X[, D])`.
#
# Positional arguments (read from `h.args`):
#   1. `R` - a `ClusteringResult`.
#   2. `X` - an array of data; only used to compute Euclidean pairwise distances
#            (along `dims=2`) when no distance matrix is given.
#   3. `D` - optional distance matrix passed to `silhouettes`.
#
# Emits, per cluster, one filled staircase polygon of the silhouette scores sorted in
# decreasing order plus one annotation-only series carrying the cluster number, and
# finally a dashed vertical line at the overall average score.
#
# Throws `ArgumentError` when the arguments do not have the shapes described above.
@recipe function f(h::SilhouettePlot)
    # Validate user input explicitly; `@assert` may be compiled out and is meant for
    # internal invariants only.
    narg = length(h.args)
    narg > 1 || throw(ArgumentError("At least two arguments are required."))
    R = h.args[1]
    R isa ClusteringResult ||
        throw(ArgumentError("First argument must be a ClusteringResult."))
    X = h.args[2]
    X isa AbstractArray || throw(ArgumentError("Second argument must be an array."))
    if narg > 2
        D = h.args[3]
        D isa AbstractMatrix ||
            throw(ArgumentError("Third argument must be a distance matrix."))
    else
        D = pairwise(Euclidean(), X, dims=2)
    end

    a = assignments(R)   # assignments to clusters
    c = counts(R)        # cluster sizes
    k = length(c)        # number of clusters
    n = sum(c)           # number of points overall

    s = silhouettes(R, D)

    # Settings for the axes
    legend --> false
    yflip := true
    xlims := [min(-0.1, minimum(s)), 1]
    # y ticks used to show cluster boundaries, and labels to show the sizes
    yticks := (cumsum([0; c]), ["0"; ["+$z" for z in c]])

    # Generate the polygons for each cluster, stacked along the y axis.
    offset = 0
    for i in 1:k
        idx = (a .== i)                 # members of cluster i
        si = sort(s[idx], rev=true)
        @series begin
            linealpha --> 0
            seriestype := :shape
            label := "$i"
            # Staircase outline: start on the x = 0 axis, step through each score
            # twice (top and bottom edge of its bar), and return to x = 0.
            x = [0; repeat(si, inner=2); 0]
            y = offset .+ repeat(0:c[i], inner=2)
            x, y
        end
        # text label to the left of the bars, vertically centred on the cluster
        @series begin
            linealpha := 0
            series_annotations := [Plots.text("$i", :center, :middle, 9)]
            [-0.04], [offset + c[i] / 2]
        end
        offset += c[i]
    end

    # Dashed line for overall average.
    savg = sum(s) / n
    @series begin
        linecolor := :black
        linestyle := :dash
        label := ""
        [savg, savg], [0, n]
    end
end

0 comments on commit f430829

Please sign in to comment.