Skip to content

Commit

Permalink
Fix to sparse rbf-kernel. Tests adjustet so they match same results a…
Browse files Browse the repository at this point in the history
…s libsvm
  • Loading branch information
timkaas committed Nov 15, 2024
1 parent cd1a621 commit aab20f3
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 3 deletions.
10 changes: 8 additions & 2 deletions src/svm/kernel/rbf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,14 @@ impl KernelSparse for Rbf {
a = a_iter.next();
b = b_iter.next();
}
(Some((i_a, _)), Some((i_b, _))) if i_a < i_b => a = a_iter.next(),
(Some((i_a, _)), Some((i_b, _))) if i_a > i_b => b = b_iter.next(),
(Some((i_a, x)), Some((i_b, _))) if i_a < i_b => {
sum += x*x;
a = a_iter.next();
},
(Some((i_a, _)), Some((i_b, y))) if i_a > i_b => {
sum += y*y;
b = b_iter.next();
},
_ => break f64::from((-self.gamma * sum).exp()),
}
}
Expand Down
32 changes: 32 additions & 0 deletions tests/data_sparse/m_csvm_rbf_prob_out
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
0
0
0
0
1
0
1
1
2
2
0
2
3
3
3
0
1
4
4
4
5
5
5
5
6
6
6
6
7
7
7
7
33 changes: 33 additions & 0 deletions tests/data_sparse/m_csvm_rbf_prob_prob_out
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
labels 0 1 2 3 4 5 6 7
7 0.119484 0.120108 0.12088 0.123288 0.12371 0.126253 0.130415 0.135863
7 0.119484 0.120107 0.120879 0.123288 0.12371 0.126252 0.130415 0.135865
7 0.119484 0.120107 0.12088 0.123288 0.12371 0.126252 0.130415 0.135864
7 0.119484 0.120108 0.12088 0.123288 0.123709 0.126252 0.130415 0.135864
7 0.120216 0.120267 0.121451 0.123233 0.123533 0.126403 0.129041 0.135855
7 0.120037 0.120292 0.120978 0.1234 0.123886 0.126525 0.129511 0.135371
7 0.120164 0.120175 0.121315 0.123701 0.123774 0.126652 0.128601 0.135617
7 0.120054 0.119921 0.120982 0.123738 0.123607 0.126231 0.130299 0.135169
7 0.123219 0.123662 0.119398 0.125328 0.12257 0.126503 0.129271 0.130051
7 0.121667 0.121901 0.119337 0.123205 0.124482 0.124412 0.129198 0.135798
7 0.120861 0.121267 0.120471 0.123197 0.123393 0.126492 0.130892 0.133427
7 0.122582 0.1229 0.1193 0.12324 0.12462 0.125775 0.128051 0.133532
7 0.125739 0.125779 0.12413 0.115941 0.125509 0.125653 0.128085 0.129163
7 0.124854 0.125285 0.124085 0.117005 0.122915 0.128395 0.124293 0.13317
6 0.123263 0.123828 0.122947 0.120213 0.122945 0.128456 0.130038 0.12831
7 0.121814 0.122241 0.121739 0.120293 0.123986 0.127673 0.130103 0.132151
6 0.124414 0.125085 0.125445 0.125773 0.12 0.125997 0.129052 0.124234
7 0.128465 0.128233 0.126326 0.127028 0.113946 0.123 0.121785 0.131217
7 0.126506 0.126725 0.12409 0.125019 0.112419 0.127678 0.124108 0.133454
7 0.12597 0.126444 0.125887 0.125661 0.114488 0.126547 0.1227 0.132303
7 0.124461 0.124986 0.124402 0.12556 0.128264 0.11378 0.125218 0.133328
1 0.1309 0.131143 0.127721 0.129143 0.121496 0.107805 0.127829 0.123961
7 0.127302 0.12781 0.123552 0.129208 0.122538 0.109448 0.122563 0.137578
3 0.127141 0.128102 0.129288 0.129441 0.125219 0.112987 0.124769 0.123053
7 0.128447 0.128291 0.12964 0.132339 0.121414 0.120332 0.103684 0.135854
1 0.13453 0.134764 0.132333 0.127564 0.120418 0.12402 0.0995385 0.126832
0 0.132386 0.130608 0.130778 0.127844 0.126043 0.130172 0.0956114 0.126558
0 0.130818 0.130488 0.128261 0.130137 0.127045 0.128027 0.0974482 0.127776
1 0.134602 0.134637 0.133307 0.129923 0.126559 0.130638 0.121036 0.0892972
1 0.133131 0.133855 0.128253 0.133344 0.123019 0.13126 0.125912 0.0912262
1 0.1313 0.132089 0.128422 0.12811 0.130371 0.121307 0.12734 0.101062
1 0.135028 0.137022 0.134388 0.125775 0.127182 0.125489 0.131498 0.0836176
2 changes: 1 addition & 1 deletion tests/svm_sparse_class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ mod svm_sparse_class {
// CSVM
test_model!(m_csvm_linear_prob, "m_csvm_linear_prob.libsvm", true, [0, 7], [1, 6]);
test_model!(m_csvm_poly_prob, "m_csvm_poly_prob.libsvm", true, [0, 7], [0, 6]);
test_model!(m_csvm_rbf_prob, "m_csvm_rbf_prob.libsvm", true, [7, 7], [1, 0]);
test_model!(m_csvm_rbf_prob, "m_csvm_rbf_prob.libsvm", true, [0, 7], [7, 1]);
test_model!(m_csvm_sigmoid_prob, "m_csvm_sigmoid_prob.libsvm", true, [0, 7], [7, 1]);

// Temporarily disabled as they trigger ICE in Rust Nightly
Expand Down

0 comments on commit aab20f3

Please sign in to comment.