diff --git a/src/svm/kernel/rbf.rs b/src/svm/kernel/rbf.rs index 7ecef75..023ca96 100644 --- a/src/svm/kernel/rbf.rs +++ b/src/svm/kernel/rbf.rs @@ -98,8 +98,14 @@ impl KernelSparse for Rbf { a = a_iter.next(); b = b_iter.next(); } - (Some((i_a, _)), Some((i_b, _))) if i_a < i_b => a = a_iter.next(), - (Some((i_a, _)), Some((i_b, _))) if i_a > i_b => b = b_iter.next(), + (Some((i_a, x)), Some((i_b, _))) if i_a < i_b => { + sum += x*x; + a = a_iter.next(); + }, + (Some((i_a, _)), Some((i_b, y))) if i_a > i_b => { + sum += y*y; + b = b_iter.next(); + }, _ => break f64::from((-self.gamma * sum).exp()), } } diff --git a/tests/data_sparse/m_csvm_rbf_prob_out b/tests/data_sparse/m_csvm_rbf_prob_out new file mode 100644 index 0000000..24985c8 --- /dev/null +++ b/tests/data_sparse/m_csvm_rbf_prob_out @@ -0,0 +1,32 @@ +0 +0 +0 +0 +1 +0 +1 +1 +2 +2 +0 +2 +3 +3 +3 +0 +1 +4 +4 +4 +5 +5 +5 +5 +6 +6 +6 +6 +7 +7 +7 +7 diff --git a/tests/data_sparse/m_csvm_rbf_prob_prob_out b/tests/data_sparse/m_csvm_rbf_prob_prob_out new file mode 100644 index 0000000..1f3db5e --- /dev/null +++ b/tests/data_sparse/m_csvm_rbf_prob_prob_out @@ -0,0 +1,33 @@ +labels 0 1 2 3 4 5 6 7 +7 0.119484 0.120108 0.12088 0.123288 0.12371 0.126253 0.130415 0.135863 +7 0.119484 0.120107 0.120879 0.123288 0.12371 0.126252 0.130415 0.135865 +7 0.119484 0.120107 0.12088 0.123288 0.12371 0.126252 0.130415 0.135864 +7 0.119484 0.120108 0.12088 0.123288 0.123709 0.126252 0.130415 0.135864 +7 0.120216 0.120267 0.121451 0.123233 0.123533 0.126403 0.129041 0.135855 +7 0.120037 0.120292 0.120978 0.1234 0.123886 0.126525 0.129511 0.135371 +7 0.120164 0.120175 0.121315 0.123701 0.123774 0.126652 0.128601 0.135617 +7 0.120054 0.119921 0.120982 0.123738 0.123607 0.126231 0.130299 0.135169 +7 0.123219 0.123662 0.119398 0.125328 0.12257 0.126503 0.129271 0.130051 +7 0.121667 0.121901 0.119337 0.123205 0.124482 0.124412 0.129198 0.135798 +7 0.120861 0.121267 0.120471 0.123197 0.123393 0.126492 0.130892 0.133427 +7 0.122582 0.1229 0.1193 0.12324 0.12462 0.125775 0.128051 0.133532 +7 0.125739 0.125779 0.12413 0.115941 0.125509 0.125653 0.128085 0.129163 +7 0.124854 0.125285 0.124085 0.117005 0.122915 0.128395 0.124293 0.13317 +6 0.123263 0.123828 0.122947 0.120213 0.122945 0.128456 0.130038 0.12831 +7 0.121814 0.122241 0.121739 0.120293 0.123986 0.127673 0.130103 0.132151 +6 0.124414 0.125085 0.125445 0.125773 0.12 0.125997 0.129052 0.124234 +7 0.128465 0.128233 0.126326 0.127028 0.113946 0.123 0.121785 0.131217 +7 0.126506 0.126725 0.12409 0.125019 0.112419 0.127678 0.124108 0.133454 +7 0.12597 0.126444 0.125887 0.125661 0.114488 0.126547 0.1227 0.132303 +7 0.124461 0.124986 0.124402 0.12556 0.128264 0.11378 0.125218 0.133328 +1 0.1309 0.131143 0.127721 0.129143 0.121496 0.107805 0.127829 0.123961 +7 0.127302 0.12781 0.123552 0.129208 0.122538 0.109448 0.122563 0.137578 +3 0.127141 0.128102 0.129288 0.129441 0.125219 0.112987 0.124769 0.123053 +7 0.128447 0.128291 0.12964 0.132339 0.121414 0.120332 0.103684 0.135854 +1 0.13453 0.134764 0.132333 0.127564 0.120418 0.12402 0.0995385 0.126832 +0 0.132386 0.130608 0.130778 0.127844 0.126043 0.130172 0.0956114 0.126558 +0 0.130818 0.130488 0.128261 0.130137 0.127045 0.128027 0.0974482 0.127776 +1 0.134602 0.134637 0.133307 0.129923 0.126559 0.130638 0.121036 0.0892972 +1 0.133131 0.133855 0.128253 0.133344 0.123019 0.13126 0.125912 0.0912262 +1 0.1313 0.132089 0.128422 0.12811 0.130371 0.121307 0.12734 0.101062 +1 0.135028 0.137022 0.134388 0.125775 0.127182 0.125489 0.131498 0.0836176 diff --git a/tests/svm_sparse_class.rs b/tests/svm_sparse_class.rs index 257118c..e453f78 100644 --- a/tests/svm_sparse_class.rs +++ b/tests/svm_sparse_class.rs @@ -65,7 +65,7 @@ mod svm_sparse_class { // CSVM test_model!(m_csvm_linear_prob, "m_csvm_linear_prob.libsvm", true, [0, 7], [1, 6]); test_model!(m_csvm_poly_prob, "m_csvm_poly_prob.libsvm", true, [0, 7], [0, 6]); - test_model!(m_csvm_rbf_prob, "m_csvm_rbf_prob.libsvm", true, [7, 7], [1, 0]); + test_model!(m_csvm_rbf_prob, "m_csvm_rbf_prob.libsvm", true, [0, 7], [7, 1]); test_model!(m_csvm_sigmoid_prob, "m_csvm_sigmoid_prob.libsvm", true, [0, 7], [7, 1]); // Temporarily disabled as they trigger ICE in Rust Nightly