Skip to content

Commit

Permalink
Merge pull request #923 from Roeya/master
Browse files Browse the repository at this point in the history
improve show of confusion matrix by finding optimal column width (cw)
  • Loading branch information
ablaom authored Aug 15, 2023
2 parents 7cea929 + 8a182dc commit f808e98
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 21 deletions.
23 changes: 13 additions & 10 deletions src/measures/confusion_matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,27 +165,30 @@ splitw(w::Int) = (sp1 = div(w, 2); sp2 = w - sp1; (sp1, sp2))
function Base.show(stream::IO, m::MIME"text/plain", cm::ConfusionMatrixObject{C}
) where C
width = displaysize(stream)[2]
cw = 13
mincw = ceil(Int, 12/C)
cw = max(length(string(maximum(cm.mat))),maximum(length.(cm.labels)),mincw)
firstcw = max(length(string(maximum(cm.mat))),maximum(length.(cm.labels)),9)
textlim = 9
totalwidth = cw * (C+1) + C + 2
totalwidth = firstcw + cw * C + C + 2
width < totalwidth && (show(stream, m, cm.mat); return)

iob = IOBuffer()
wline = s -> write(iob, s * "\n")
splitcw = s -> (w = cw - length(s); splitw(w))
splitfirstcw = s -> (w = firstcw - length(s); splitw(w))
cropw = s -> length(s) > textlim ? s[1:prevind(s, textlim)] * "" : s

# 1.a top box
" "^(cw+1) * "" * ""^((cw + 1) * C - 1) * "" |> wline
" "^(firstcw+1) * "" * ""^((cw + 1) * C - 1) * "" |> wline
gt = "Ground Truth"
w = (cw + 1) * C - 1 - length(gt)
sp1, sp2 = splitw(w)
" "^(cw+1) * "" * " "^sp1 * gt * " "^sp2 * "" |> wline
" "^(firstcw+1) * "" * " "^sp1 * gt * " "^sp2 * "" |> wline
# 1.b separator
"" * ""^cw * "" * (""^cw * "")^(C-1) * ""^cw * "" |> wline
"" * ""^firstcw * "" * (""^cw * "")^(C-1) * ""^cw * "" |> wline
# 2.a description line
pr = "Predicted"
sp1, sp2 = splitcw(pr)
sp1, sp2 = splitfirstcw(pr)
partial = "" * " "^sp1 * pr * " "^sp2 * ""
for c in 1:C
# max = 10
Expand All @@ -195,12 +198,12 @@ function Base.show(stream::IO, m::MIME"text/plain", cm::ConfusionMatrixObject{C}
end
partial |> wline
# 2.b separating line
"" * ""^cw * "" * (""^cw * "")^(C-1) * (""^cw * "") |> wline
"" * ""^firstcw * "" * (""^cw * "")^(C-1) * (""^cw * "") |> wline
# 2.c line by line
for c in 1:C
# line
s = cm.labels[c] |> cropw
sp1, sp2 = splitcw(s)
sp1, sp2 = splitfirstcw(s)
partial = "" * " "^sp1 * s * " "^sp2 * ""
for r in 1:C
e = string(cm[c, r])
Expand All @@ -210,11 +213,11 @@ function Base.show(stream::IO, m::MIME"text/plain", cm::ConfusionMatrixObject{C}
partial |> wline
# separator
if c < C
"" * ""^cw * "" * (""^cw * "")^(C-1) * (""^cw * "") |> wline
"" * ""^firstcw * "" * (""^cw * "")^(C-1) * (""^cw * "") |> wline
end
end
# 2.d final line
"" * ""^cw * "" * (""^cw * "")^(C-1) * (""^cw * "") |> wline
"" * ""^firstcw * "" * (""^cw * "")^(C-1) * (""^cw * "") |> wline
write(stream, take!(iob))
end

Expand Down
22 changes: 11 additions & 11 deletions test/measures/confusion_matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,17 +78,17 @@ end
Base.show(iob, MIME("text/plain"), MLJBase._confmat(ŷ, y))
siob = String(take!(iob))
@test strip(siob) == strip("""
┌─────────────────────────────────────────┐
Ground Truth
┌─────────────┼─────────────┬─────────────┬─────────────┤
Predicted1 2 │ 3
├─────────────┼─────────────┼─────────────┼─────────────┤
1 3 0 │ 0
├─────────────┼─────────────┼─────────────┼─────────────┤
2 0 3 │ 0
├─────────────┼─────────────┼─────────────┼─────────────┤
3 0 0 │ 3
└─────────────┴─────────────┴─────────────┴─────────────┘""")
──────────────┐
Ground Truth │
┌─────────┼────┬────────┤
Predicted1 2 │ 3
├─────────┼────┼────────┤
│ 1 3 0 │ 0
├─────────┼────┼────────┤
│ 2 0 3 │ 0
├─────────┼────┼────────┤
│ 3 0 0 │ 3
└─────────┴────┴────────┘""")
end

@testset "ConfusionMatrix measure" begin
Expand Down

0 comments on commit f808e98

Please sign in to comment.