-
Notifications
You must be signed in to change notification settings - Fork 3.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ca0222a
commit 3a07bc7
Showing
7 changed files
with
619 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
Ridge and Kernel Ridge | ||
====================== | ||
|
||
## Ridge | ||
|
||
TODO | ||
|
||
## Kernel Ridge | ||
|
||
See [run_kernel_ridge.py](./run_kernel_ridge.py) for an example for Kernel Ridge | ||
|
||
## Contributor | ||
|
||
- [junlulocky](https://github.com/junlulocky) | ||
|
||
## Reference | ||
|
||
- Welling, Max. "Kernel ridge Regression." Max Welling's Classnotes in Machine Learning (2013): 1-3. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
import scipy as sp | ||
from numpy.linalg import inv | ||
import numpy as np | ||
from scipy import linalg | ||
|
||
|
||
class KernelRidge(): | ||
""" | ||
Simple implementation of a Kernel Ridge Regression using the | ||
closed form for training. | ||
Doc: https://www.ics.uci.edu/~welling/classnotes/papers_class/Kernel-Ridge.pdf | ||
""" | ||
|
||
def __init__(self, kernel_type='linear', C=1.0, gamma=5.0): | ||
""" | ||
:param kernel_type: Kernel type to use in training. | ||
'linear' use linear kernel function. | ||
'quadratic' use quadratic kernel function. | ||
'gaussian' use gaussian kernel function | ||
:param C: Value of regularization parameter C | ||
:param gamma: parameter for gaussian kernel or Polynomial kernel | ||
""" | ||
self.kernels = { | ||
'linear': self.kernel_linear, | ||
'quadratic': self.kernel_quadratic, | ||
'gaussian': self.kernel_gaussian | ||
} | ||
self.kernel_type = kernel_type | ||
self.kernel = self.kernels[self.kernel_type] | ||
self.C = C | ||
self.gamma = gamma | ||
|
||
# Define kernels | ||
def kernel_linear(self, x1, x2): | ||
return np.dot(x1, x2.T) | ||
|
||
def kernel_quadratic(self, x1, x2): | ||
return (np.dot(x1, x2.T) ** 2) | ||
|
||
def kernel_gaussian(self, x1, x2, gamma=5.0): | ||
gamma = self.gamma | ||
return np.exp(-linalg.norm(x1 - x2) ** 2 / (2 * (gamma ** 2))) | ||
|
||
def compute_kernel_matrix(self, X1, X2): | ||
""" | ||
compute kernel matrix (gram matrix) give two input matrix | ||
""" | ||
|
||
# sample size | ||
n1 = X1.shape[0] | ||
n2 = X2.shape[0] | ||
|
||
# Gram matrix | ||
K = np.zeros((n1, n2)) | ||
for i in range(n1): | ||
for j in range(n2): | ||
K[i, j] = self.kernel(X1[i], X2[j]) | ||
|
||
return K | ||
|
||
|
||
def fit(self, X, y): | ||
""" | ||
training KRR | ||
:param X: training X | ||
:param y: training y | ||
:return: alpha vector, see document TODO | ||
""" | ||
K = self.compute_kernel_matrix(X, X) | ||
|
||
self.alphas = sp.dot(inv(K + self.C * np.eye(np.shape(K)[0])), | ||
y.transpose()) | ||
|
||
return self.alphas | ||
|
||
def predict(self, x_train, x_test): | ||
""" | ||
:param x_train: DxNtr array of Ntr train data points | ||
with D features | ||
:param x_test: DxNte array of Nte test data points | ||
with D features | ||
:return: y_test, D2xNte array | ||
""" | ||
|
||
k = self.compute_kernel_matrix(x_test, x_train) | ||
|
||
y_test = sp.dot(k, self.alphas) | ||
return y_test.transpose() | ||
|
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import csv, os, sys | ||
import numpy as np | ||
from kernel_ridge import KernelRidge | ||
filepath = os.path.dirname(os.path.abspath(__file__)) | ||
|
||
def readData(filename, header=True): | ||
data, header = [], None | ||
with open(filename, 'rb') as csvfile: | ||
spamreader = csv.reader(csvfile, delimiter=',') | ||
if header: | ||
header = spamreader.next() | ||
for row in spamreader: | ||
data.append(row) | ||
return (np.array(data), np.array(header)) | ||
|
||
def calc_mse(y, y_hat): | ||
return np.nanmean(((y - y_hat) ** 2)) | ||
|
||
def test_main(filename='small_data/iris-virginica.txt', C=1.0, kernel_type='linear'): | ||
# Load data | ||
(data, _) = readData('%s/%s' % (filepath, filename), header=False) | ||
data = data.astype(float) | ||
|
||
# Split data | ||
X, y = data[:,0:-1], data[:,-1].astype(int) | ||
y = y[np.newaxis,:] | ||
print X.shape | ||
print y.shape | ||
|
||
|
||
|
||
# fit our model | ||
model = KernelRidge(kernel_type='gaussian', C=0.1, gamma=5.0) | ||
model.fit(X, y) | ||
y_hat = model.predict(x_train=X, x_test=X) | ||
mse = calc_mse(y, y_hat) # Calculate accuracy | ||
print("mse of KRR:\t%.3f" % (mse)) | ||
|
||
# fit linear model for test | ||
from sklearn import linear_model | ||
ls = linear_model.LinearRegression() | ||
ls.fit(X, y[0,:]) | ||
y_ls = ls.predict(X) | ||
mse = calc_mse(y, y_ls) | ||
print("mse of LS (from sklearn):\t%.3f" % (mse)) | ||
|
||
# fit KRR from sklearn for test | ||
from sklearn.kernel_ridge import KernelRidge as KR2 | ||
kr2 = KR2(kernel='rbf', gamma=5, alpha=10) | ||
kr2.fit(X, y[0, :]) | ||
y_krr = kr2.predict(X) | ||
mse = calc_mse(y, y_krr) | ||
print("mse of KRR (from sklearn):\t%.3f" % (mse)) | ||
|
||
|
||
|
||
|
||
|
||
if __name__ == '__main__': | ||
|
||
test_main(filename='./small_data/iris-slwc.txt') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
5.9,3,1 | ||
6.9,3.1,1 | ||
6.6,2.9,1 | ||
4.6,3.2,-1 | ||
6,2.2,1 | ||
4.7,3.2,-1 | ||
6.5,3,1 | ||
5.8,2.7,1 | ||
6.7,3.1,1 | ||
6.7,2.5,1 | ||
5.1,3.7,-1 | ||
5.1,3.8,-1 | ||
5.7,3,1 | ||
6.1,3,1 | ||
4.9,3.1,-1 | ||
5,3.4,-1 | ||
5,3.4,-1 | ||
5.7,2.8,1 | ||
5,3.3,-1 | ||
7.2,3.2,1 | ||
5.9,3,1 | ||
6.5,3,1 | ||
5.7,4.4,-1 | ||
5.5,2.5,1 | ||
4.9,2.5,1 | ||
5,3.5,-1 | ||
5.5,2.3,1 | ||
4.6,3.1,-1 | ||
7.2,3,1 | ||
6.8,3.2,1 | ||
5.4,3.9,-1 | ||
5,3.2,-1 | ||
5.7,2.5,1 | ||
5.8,2.6,1 | ||
5.1,2.5,1 | ||
5.6,2.5,1 | ||
5.8,2.7,1 | ||
5.1,3.8,-1 | ||
6.3,2.3,1 | ||
6.3,2.5,1 | ||
5.6,3,1 | ||
6.1,3,1 | ||
6.8,3,1 | ||
7.3,2.9,1 | ||
5.6,2.7,1 | ||
4.8,3,-1 | ||
7.1,3,1 | ||
5.7,2.6,1 | ||
5.3,3.7,-1 | ||
5.7,3.8,-1 | ||
5.7,2.9,1 | ||
5.6,2.8,1 | ||
4.4,3,-1 | ||
6.3,3.3,1 | ||
5.4,3.4,-1 | ||
6.3,3.4,1 | ||
6.9,3.1,1 | ||
7.7,3,1 | ||
6.1,2.8,1 | ||
5.6,2.9,1 | ||
6.1,2.6,1 | ||
6.4,2.7,1 | ||
5,3.5,-1 | ||
5.1,3.3,-1 | ||
5.6,3,1 | ||
5.4,3,1 | ||
5.8,2.8,1 | ||
4.9,3.1,-1 | ||
4.6,3.6,-1 | ||
5.2,3.4,-1 | ||
7.9,3.8,1 | ||
7.7,2.6,1 | ||
6.1,2.8,1 | ||
5.5,3.5,-1 | ||
4.6,3.4,-1 | ||
4.7,3.2,-1 | ||
4.4,2.9,-1 | ||
6.2,2.8,1 | ||
4.8,3,-1 | ||
6,2.9,1 | ||
6.2,3.4,1 | ||
5,2.3,1 | ||
6.4,3.2,1 | ||
6.3,2.9,1 | ||
6.7,3,1 | ||
5,2,1 | ||
5.9,3.2,1 | ||
6.7,3.3,1 | ||
5.4,3.9,-1 | ||
6.3,2.7,1 | ||
4.8,3.4,-1 | ||
4.4,3.2,-1 | ||
6.4,3.2,1 | ||
6.2,2.2,1 | ||
6,2.2,1 | ||
7.4,2.8,1 | ||
4.9,2.4,1 | ||
7,3.2,1 | ||
5.5,2.4,1 | ||
6.3,3.3,1 | ||
6.8,2.8,1 | ||
6.1,2.9,1 | ||
6.5,3.2,1 | ||
6.7,3.3,1 | ||
6.7,3.1,1 | ||
4.8,3.4,-1 | ||
4.9,3,-1 | ||
6.9,3.2,1 | ||
4.5,2.3,-1 | ||
4.3,3,-1 | ||
5.2,2.7,1 | ||
5,3.6,-1 | ||
6.4,2.9,1 | ||
5.2,3.5,-1 | ||
5.8,2.7,1 | ||
5.5,4.2,-1 | ||
7.6,3,1 | ||
6.3,2.8,1 | ||
6.4,3.1,1 | ||
6.3,2.5,1 | ||
5.8,2.7,1 | ||
5,3,-1 | ||
6.7,3.1,1 | ||
6,2.7,1 | ||
5.1,3.5,-1 | ||
4.8,3.1,-1 | ||
5.7,2.8,1 | ||
5.1,3.8,-1 | ||
6.6,3,1 | ||
6.4,2.8,1 | ||
5.2,4.1,-1 | ||
6.4,2.8,1 | ||
7.7,2.8,1 | ||
5.8,4,-1 | ||
4.9,3.1,-1 | ||
5.4,3.7,-1 | ||
5.1,3.5,-1 | ||
6,3.4,1 | ||
6.5,3,1 | ||
5.5,2.4,1 | ||
7.2,3.6,1 | ||
6.9,3.1,1 | ||
6.2,2.9,1 | ||
6.5,2.8,1 | ||
6,3,1 | ||
5.4,3.4,-1 | ||
5.5,2.6,1 | ||
6.7,3,1 | ||
7.7,3.8,1 | ||
5.1,3.4,-1 |
Oops, something went wrong.