-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path t_main.py
45 lines (36 loc) · 1.71 KB
/
t_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import saltclass
import numpy as np
# Demo / smoke script for the saltclass.SALT classifier.
#
# Part 1 builds a SALT object from in-memory NumPy arrays and enriches it.
# Part 2 builds a SALT object from a training directory, then enriches,
# trains, predicts, and prints its vocabulary/newdata and the prediction.
#
# NOTE(review): the directory arguments below are hard-coded absolute
# Windows-style paths ('D:/...'); they presumably must exist on the machine
# running this script -- consider making them configurable.

# --- Part 1: SALT object built from in-memory arrays -----------------------
# Toy training matrix: 3 rows x 3 columns. The column count matches `vocab`,
# so presumably column j holds the count of vocab[j] and each row is one
# document -- confirm against the saltclass documentation.
train_X = np.array([[10, 0, 0], [0, 20, 0], [4, 13, 5]])
# One entry per row of train_X (presumably the class label of each row).
train_y = np.array([0, 1, 1])
vocab = ['statistics', 'medicine', 'crime']
object_from_df = saltclass.SALT(train_X, train_y, vocabulary=vocab, language='en')
# Extra unlabeled rows with the same 3 columns, passed to enrich() as a matrix.
X = np.array([[10, 12, 0], [14, 3, 52]])
object_from_df.enrich(method='kmeans', include_unlabeled=True, unlabeled_matrix=X)
# NOTE(review): enrich() is called a second time on the same object, this time
# reading unlabeled data from a directory instead of a matrix. Confirm whether
# repeated enrich() calls accumulate or overwrite. object_from_df is not used
# anywhere after this line, so this part only exercises the enrich() calls.
object_from_df.enrich(method='kmeans', include_unlabeled=True, unlabeled_dir='D:/Data/unlabeled/')

# --- Part 2: SALT object built from a training directory -------------------
object_from_file = saltclass.SALT.data_from_dir(train_dir='D:/data/train2/', language='en')
object_from_file.enrich(method='lda')
object_from_file.train(classifier='svm')
object_from_file.print_info()
prediction = object_from_file.predict(data_file='second_test.txt')
print(object_from_file.vocabulary)
print(object_from_file.newdata)
# Print the vocabulary terms whose column is nonzero in row 0 of newdata.
# `vocabulary` is used here as a mapping of term -> column index.
# NOTE(review): presumably newdata holds the vectorized contents of
# second_test.txt produced by predict(). The [0][v] indexing assumes a dense
# 2-D array; a scipy sparse matrix would need [0, v] -- confirm.
print([k for (k, v) in object_from_file.vocabulary.items() if object_from_file.newdata[0][v] != 0])
print(prediction)
# stc_object = STClassifier(train_X, train_y, vocabulary=['statistics', 'medicine', 'crime'], language='en')
# stc_object.kmeans_enrich(num_clusters=2)
# stc_object.train(classifier='SVM')
# stc_object.print_info()
# prediction = stc_object.predict(data_file='first_test.txt')
# print(stc_object.newdata)
# print(prediction)
# object_from_file = stclassifier.STClassifier.from_data_dir(train_dir='D:/train/', language='en')
# object_from_file.print_info()
# print(object_from_file.vocabulary)
# object_from_file.kmeans_enrich(num_clusters=2)
# print(object_from_file.X)
# object_from_file.train(classifier='SVM', gamma=3)
# prediction = object_from_file.predict(data_file='first_test.txt')
# print(object_from_file.newdata)
# print(prediction)
# print(STClassifier.__init__.__doc__)
# print(help(STClassifier.__init__))