forked from amauboussin/arxiv-twitterbot
-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
29 lines (22 loc) · 1011 Bytes
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from keras.models import load_model
import numpy as np
from constants import DICT_TRANSFORM_PATH, MODEL_PATH
from conv_net import load_doc_to_word_indices
from preprocessing import arxiv_df_to_list_of_dicts, parse_content_serial
def add_conv_predictions_to_date(paper_df, index):
    """Score the selected papers with the conv net and store the scores in ``paper_df``.

    Despite the name, ``index`` may be any selector accepted by
    ``paper_df.loc`` (a label list, a boolean mask, ...), e.g. the rows for
    the papers published on one day.

    Args:
        paper_df: DataFrame of papers. It is modified in place: a
            ``'prediction'`` column is (re)created and filled with NaN,
            discarding any previous values, and then the rows selected by
            ``index`` receive the model's score.
        index: Row selector passed to ``paper_df.loc``.

    Returns:
        The same ``paper_df`` object, now carrying the ``'prediction'`` column.
    """
    # Select before the 'prediction' column is (re)created so the rows handed
    # to the preprocessing step look exactly as they did on input.
    papers_to_predict = paper_df.loc[index]
    if papers_to_predict.empty:
        # Nothing to score: still expose the NaN column callers expect, but skip
        # loading the model from disk and calling predict on an empty batch
        # (which Keras rejects).
        paper_df['prediction'] = np.nan
        return paper_df
    model = load_conv_net()
    records = arxiv_df_to_list_of_dicts(papers_to_predict)
    records = parse_content_serial(records)
    # Rows outside ``index`` are left as NaN.
    paper_df['prediction'] = np.nan
    # Column 1 of the model output is used as the score -- presumably the
    # positive ("will be tweeted") class probability; confirm against training.
    # NOTE(review): assumes parse_content_serial keeps one record per selected
    # row, in order, otherwise the assignment below misaligns or raises.
    paper_df.loc[index, 'prediction'] = model([r['content'] for r in records])[:, 1]
    return paper_df
def load_conv_net():
    """Load the saved conv net and wrap it in a prediction function.

    The Keras model is read from ``MODEL_PATH`` first, then the
    document-to-word-index transform from ``DICT_TRANSFORM_PATH``.

    Returns:
        A callable ``predict_proba(docs)``. It converts ``docs`` with the
        loaded transform and returns the raw output of ``model.predict``.
        Per the original author this is the probability that a paper would
        be tweeted. The in-file caller passes a list of each record's
        ``'content'`` value as ``docs``.
    """
    net = load_model(MODEL_PATH)
    to_word_indices = load_doc_to_word_indices(DICT_TRANSFORM_PATH)

    def predict_proba(docs):
        """Vectorize ``docs`` and return the network's predictions for them."""
        return net.predict(to_word_indices.transform(docs))

    return predict_proba