forked from kubeflow/examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
IssueSummarization.py
35 lines (26 loc) · 1.21 KB
/
IssueSummarization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
"""Generates predictions using a stored model.
Uses trained model files to generate a prediction.
"""
from __future__ import print_function
import os
import numpy as np
import dill as dpickle
from keras.models import load_model
from seq2seq_utils import Seq2Seq_Inference
class IssueSummarization(object):
    """Wrap a trained seq2seq model that generates an issue title from an issue body.

    The body/title preprocessors (dill pickles) and the Keras model are loaded
    once at construction time. ``predict`` then runs inference on every input row.
    """

    def __init__(self, body_pp_file=None, title_pp_file=None, model_file=None):
        """Load the preprocessors and the trained model.

        Each path argument defaults to ``None``. In that case the path is read
        from an environment variable, or falls back to a file in the current
        working directory:

        * ``body_pp_file``  -> ``$BODY_PP_FILE``  or ``'body_pp.dpkl'``
        * ``title_pp_file`` -> ``$TITLE_PP_FILE`` or ``'title_pp.dpkl'``
        * ``model_file``    -> ``$MODEL_FILE``    or ``'seq2seq_model_tutorial.h5'``

        Errors from ``open``, ``dill.load`` or ``keras.models.load_model``
        (e.g. a missing file) propagate unchanged.
        """
        if body_pp_file is None:
            body_pp_file = os.getenv('BODY_PP_FILE', 'body_pp.dpkl')
        print('body_pp file {0}'.format(body_pp_file))
        body_pp = self._load_dill(body_pp_file)

        if title_pp_file is None:
            title_pp_file = os.getenv('TITLE_PP_FILE', 'title_pp.dpkl')
        print('title_pp file {0}'.format(title_pp_file))
        title_pp = self._load_dill(title_pp_file)

        if model_file is None:
            model_file = os.getenv('MODEL_FILE', 'seq2seq_model_tutorial.h5')
        print('model file {0}'.format(model_file))
        self.model = Seq2Seq_Inference(encoder_preprocessor=body_pp,
                                       decoder_preprocessor=title_pp,
                                       seq2seq_model=load_model(model_file))

    @staticmethod
    def _load_dill(path):
        """Deserialize and return the dill-pickled object stored at *path*."""
        # SECURITY: dill, like pickle, can execute arbitrary code while loading.
        # Only point these paths at trusted artifacts from the training job.
        with open(path, 'rb') as handle:
            return dpickle.load(handle)

    def predict(self, input_text, feature_names):  # pylint: disable=unused-argument
        """Generate one title per input row.

        Args:
          input_text: iterable of rows. The first element of each row is passed
            to the model as the issue body. Presumably this is a 2-D array of
            shape (n, 1) supplied by the serving layer (TODO confirm).
          feature_names: unused. It is presumably accepted only because the
            serving framework's interface requires it.

        Returns:
          A numpy array holding one single-element row per input row. Each row
          contains element ``[1]`` of the value returned by
          ``self.model.generate_issue_title``, which is the generated title
          judging by usage (confirm against seq2seq_utils).
        """
        return np.asarray([[self.model.generate_issue_title(body[0])[1]]
                           for body in input_text])