-
Notifications
You must be signed in to change notification settings - Fork 195
/
SVM_setup.m
103 lines (94 loc) · 3.84 KB
/
SVM_setup.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
function [train_data,train_class,test_data,test_class,fig_handle] = SVM_setup(design_opt,train_N,test_N,rand_seed)
% function [train_data,train_class,test_data,test_class,fig_handle] = SVM_setup(design_opt,train_N,test_N,rand_seed)
% This function sets up the SVM examples
% Inputs: design_opt - 1 for linearly separable, 2 for nonseparable
%         train_N - number of training points per population
%         test_N - number of test points per population
%         rand_seed - <optional> random seed (default=0)
% Outputs: train_data - the training set
%          train_class - the classifications at the training points
%          test_data - the test set
%          test_class - the classifications at the test points
%          fig_handle - <optional> return the handle of the figure
%                       of a scatter plot of the data
%
% design_opt = 1 has class=1 centered at (1,0)
%                  class=-1 centered at (0,1)
% design_opt = 2 has class=1 centered at (1,0), (0,1), (2,1)
%                  class=-1 centered at (0,0), (1,1), (2,0)
global GAUSSQR_PARAMETERS
if ~isstruct(GAUSSQR_PARAMETERS)
    error('GAUSSQR_PARAMETERS does not exist ... did you forget to call rbfsetup?')
end
statsOn = GAUSSQR_PARAMETERS.STATISTICS_TOOLBOX_AVAILABLE;
if not(statsOn)
    error('Sorry, but for the moment you need the stats package to run this')
end
if nargin==3
    rand_seed = 0;
elseif nargin~=4
    error('Unacceptable arguments, nargin=%d',nargin)
end
% Only draw the scatter plot when the caller asks for the figure handle
plot_on = 0;
if nargout==5
    plot_on = 1;
end
% This allows us to control the results
% rng ships as a toolbox function (not a builtin) on versions that have
% it, so check 'file' as well as 'builtin'; older versions of Matlab fall
% back to the legacy seeding interface
if exist('rng','file') || exist('rng','builtin')
    rng(rand_seed);
else
    rand('state',rand_seed);
    randn('state',rand_seed);
end
% Uniform integer sampler built on rand.  The original code shadowed the
% builtin randi with this handle unconditionally (assigning randi anywhere
% in the function makes Matlab resolve the name as a variable), so this
% rand-based sampler is what was always used; giving it its own name keeps
% that behavior without masking the builtin.
isample = @(n,r,c) ceil(n*rand(r,c));
% Set up the data points based on what the user requests
switch design_opt
    case 1
        grnmean = [1,0];
        redmean = [0,1];
        covmat = eye(2);
    case 2
        grnmean = [1,0;0,1;2,1];
        redmean = [0,0;1,1;2,0];
        covmat = .5*eye(2);
    otherwise
        error('design_opt must be either 1 or 2')
end
grnmean_N = size(grnmean,1);
redmean_N = size(redmean,1);
% How much randomness do we want in our training set
buffer = .2;
% Generate some manufactured data and attempt to classify it
% The data will be generated by normal distributions with different means,
% each test point drawn around a randomly chosen population center
grnpop = mvnrnd(grnmean(isample(grnmean_N,test_N,1),:),repmat(covmat,[1,1,test_N]),test_N);
redpop = mvnrnd(redmean(isample(redmean_N,test_N,1),:),repmat(covmat,[1,1,test_N]),test_N);
% Generate a training set from which to learn the classifier
% Training points are drawn (with reduced variance) near random test points
grnpts = mvnrnd(grnpop(isample(test_N,train_N,1),:),repmat(covmat*buffer,[1,1,train_N]),train_N);
redpts = mvnrnd(redpop(isample(test_N,train_N,1),:),repmat(covmat*buffer,[1,1,train_N]),train_N);
% Create a vector of data and associated classifications
% Green label 1, red label -1
train_data = [grnpts;redpts];
train_class = ones(2*train_N,1);
train_class(train_N+1:2*train_N) = -1;
test_data = [grnpop;redpop];
test_class = ones(2*test_N,1);
test_class(test_N+1:2*test_N) = -1;
% Scatter plot of the input data
if plot_on
    fig_handle = figure;
    hold on
    % Test points: class-colored markers with a blue outline drawn on top
    plot(grnpop(:,1),grnpop(:,2),'g+','markersize',12)
    plot(redpop(:,1),redpop(:,2),'rx','markersize',12)
    plot(grnpop(:,1),grnpop(:,2),'bs','markersize',12)
    plot(redpop(:,1),redpop(:,2),'bo','markersize',12)
    % Population centers, drawn filled
    plot(grnmean(:,1),grnmean(:,2),'gs','markersize',12,'MarkerFaceColor','g')
    plot(redmean(:,1),redmean(:,2),'ro','markersize',12,'MarkerFaceColor','r')
    % Training points, smaller markers
    plot(grnpts(:,1),grnpts(:,2),'g+','markersize',7)
    plot(redpts(:,1),redpts(:,2),'rx','markersize',7)
    hold off
end