-
Notifications
You must be signed in to change notification settings - Fork 9
/
ensemble.py
60 lines (45 loc) · 1.9 KB
/
ensemble.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import pickle

import numpy as np
class AdaBoostClassifier:
    '''A simple (discrete) AdaBoost classifier for binary labels encoded as -1 / +1.

    The ensemble is a weighted vote of weak classifiers. Each weak classifier
    is trained on the same samples but with per-sample weights that are
    increased for the samples the previous rounds got wrong.

    Attributes:
        weak_classifier: Zero-argument callable that builds a fresh weak classifier.
        n_weakers_limit: Maximum number of boosting rounds.
        classifiers: Fitted weak classifiers, in training order.
        alphas: Vote weight of each entry in ``classifiers``.
        train_error: Training error rate of the ensemble after the last
            completed round, or None when no round has completed.
    '''

    # Floor for the weighted error so that a perfect weak classifier gets a
    # large but finite vote instead of a division by zero.
    _MIN_ERROR = 1e-10

    def __init__(self, weak_classifier, n_weakers_limit):
        '''Initialize AdaBoostClassifier.

        Args:
            weak_classifier: The class of the weak classifier (or any
                zero-argument factory, e.g. a ``functools.partial``). It is
                called once per round and the result must provide
                ``fit(X, y, sample_weight=...)`` and ``predict(X)``;
                sklearn.tree.DecisionTreeClassifier is the recommended choice.
            n_weakers_limit: The maximum number of weak classifiers the model can use.
        '''
        self.weak_classifier = weak_classifier
        self.n_weakers_limit = n_weakers_limit
        self.classifiers = []
        self.alphas = []
        self.train_error = None

    def is_good_enough(self):
        '''Tell whether boosting can stop early.

        ``fit`` calls this hook after every round; subclasses may override it
        with a different stopping rule.

        Returns:
            True when the ensemble already classifies every training sample
            correctly, False otherwise (including before any fit).
        '''
        return self.train_error is not None and self.train_error == 0

    def fit(self, X, y):
        '''Build a boosted classifier from the training set (X, y).

        Any previously fitted ensemble is discarded. Training stops early when
        ``is_good_enough()`` returns True or when a weak classifier is no
        better than chance on the weighted samples.

        Args:
            X: An ndarray of training samples with shape (n_samples, n_features).
            y: An ndarray of ground-truth labels for X with shape (n_samples, 1)
                or (n_samples,). Every label must be -1 or +1.

        Returns:
            The fitted classifier itself, so calls can be chained.

        Raises:
            ValueError: If the training set is empty, X and y disagree on the
                number of samples, or a label is not -1 / +1.
        '''
        samples = np.asarray(X)
        labels = np.asarray(y).reshape(-1)
        n_samples = labels.shape[0]
        if n_samples == 0:
            raise ValueError('Cannot fit on an empty training set.')
        if samples.shape[0] != n_samples:
            raise ValueError('X and y must contain the same number of samples.')
        if not np.isin(labels, (-1, 1)).all():
            raise ValueError('Labels must be encoded as -1 or +1.')

        self.classifiers = []
        self.alphas = []
        self.train_error = None

        weights = np.full(n_samples, 1.0 / n_samples)
        train_scores = np.zeros(n_samples)

        for _ in range(self.n_weakers_limit):
            learner = self.weak_classifier()
            learner.fit(samples, labels, sample_weight=weights)
            votes = np.asarray(learner.predict(samples)).reshape(-1)

            # The weights always sum to 1, so this is already the weighted error rate.
            error = float(weights[votes != labels].sum())
            if error >= 0.5:
                # A learner at or below chance level would get a non-positive
                # vote; adding it cannot help, so stop boosting here.
                break
            error = max(error, self._MIN_ERROR)
            alpha = 0.5 * np.log((1.0 - error) / error)

            self.classifiers.append(learner)
            self.alphas.append(alpha)

            # Shrink the weight of correctly classified samples and grow the
            # weight of the mistakes, then renormalise to a distribution.
            weights = weights * np.exp(-alpha * labels * votes)
            weights = weights / weights.sum()

            # Track the ensemble's own training error with the same decision
            # rule that predict() uses at its default threshold.
            train_scores += alpha * votes
            ensemble_votes = np.where(train_scores > 0, 1, -1)
            self.train_error = float(np.mean(ensemble_votes != labels))
            if self.is_good_enough():
                break

        return self

    def predict_scores(self, X):
        '''Calculate the weighted sum score of all base classifiers for given samples.

        Args:
            X: An ndarray of samples to be scored with shape (n_samples, n_features).

        Returns:
            An ndarray of scores with shape (n_samples, 1). All scores are 0
            when the ensemble holds no weak classifier (e.g. before ``fit``).
        '''
        samples = np.asarray(X)
        scores = np.zeros(samples.shape[0])
        for alpha, learner in zip(self.alphas, self.classifiers):
            scores += alpha * np.asarray(learner.predict(samples)).reshape(-1)
        return scores.reshape(-1, 1)

    def predict(self, X, threshold=0):
        '''Predict the categories for given samples.

        Args:
            X: An ndarray of samples to be predicted with shape (n_samples, n_features).
            threshold: The score boundary dividing the samples into two
                classes; scores strictly above it are labelled +1, all others -1.

        Returns:
            An ndarray of predicted labels (-1 or +1) with shape (n_samples, 1).
        '''
        return np.where(self.predict_scores(X) > threshold, 1, -1)

    @staticmethod
    def save(model, filename):
        '''Serialize ``model`` to ``filename`` with pickle.

        Args:
            model: The object to store (typically a fitted AdaBoostClassifier).
            filename: Path of the file to write; an existing file is overwritten.
        '''
        with open(filename, "wb") as f:
            pickle.dump(model, f)

    @staticmethod
    def load(filename):
        '''Load an object previously written by ``save``.

        SECURITY: unpickling can execute arbitrary code, so only load files
        that come from a trusted source.

        Args:
            filename: Path of the pickle file to read.

        Returns:
            The unpickled object.
        '''
        with open(filename, "rb") as f:
            return pickle.load(f)