forked from kubeflow/examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_test.py
82 lines (65 loc) · 2.49 KB
/
train_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
import shutil
import tempfile
import unittest

import train
class TrainTest(unittest.TestCase):
    """Smoke tests that invoke ``train.main`` on the bundled sample data.

    Each test points every ``--output_*`` flag at a fresh temporary directory,
    runs ``train.main`` and asserts that every requested output path exists
    afterwards.
    """

    def _run_training(self, model_name, extra_args=()):
        """Run ``train.main`` once and verify that all outputs were created.

        Args:
          model_name: Name, inside the temporary output directory, of the path
            passed as ``--output_model``.
          extra_args: Extra command-line flags appended after the common ones.

        Raises:
          AssertionError: If any of the requested output paths does not exist
            after ``train.main`` returns.
        """
        output_dir = tempfile.mkdtemp()
        # Remove the scratch directory whether the test passes or fails so
        # repeated runs do not accumulate model artifacts on disk.
        self.addCleanup(shutil.rmtree, output_dir, ignore_errors=True)

        # (flag, path) pairs in the order they are passed to train.main.
        # The title and body preprocessors get distinct file names so that
        # each output is checked independently.
        output_flags = [
            ("--output_model", os.path.join(output_dir, model_name)),
            ("--output_body_preprocessor_dpkl",
             os.path.join(output_dir, "body_pp.dpkl")),
            ("--output_title_preprocessor_dpkl",
             os.path.join(output_dir, "title_pp.dpkl")),
            ("--output_train_title_vecs_npy",
             os.path.join(output_dir, "title.npy")),
            ("--output_train_body_vecs_npy",
             os.path.join(output_dir, "body.npy")),
        ]

        this_dir = os.path.dirname(__file__)
        input_data = os.path.join(this_dir, "test_data",
                                  "github_issues_sample.csv")

        args = [
            "--sample_size=100",
            "--num_epochs=1",
            "--input_data=" + input_data,
        ]
        args.extend(flag + "=" + path for flag, path in output_flags)
        args.extend(extra_args)

        train.main(args)

        for flag, path in output_flags:
            self.assertTrue(
                os.path.exists(path),
                "train.main did not create {} ({})".format(path, flag))

    def test_keras(self):
        """Test training in the default mode, saving the model as model.h5."""
        self._run_training("model.h5")

    # TODO(https://github.com/kubeflow/examples/issues/280)
    # TODO(https://github.com/kubeflow/examples/issues/196)
    # This test won't work until we fix the code to work using the estimator
    # API.
    @unittest.skip("skip estimator test")
    def test_estimator(self):
        """Test training with --mode=estimator (currently skipped)."""
        self._run_training("model", extra_args=["--mode=estimator"])
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()