# app.py (forked from radanalyticsio/jiminy-modeler)
#!/usr/bin/env python
"""Main app file for modeler."""
import argparse
import os
import os.path
import sys
import time

import psycopg2
import pyspark

import logger
import modeller
import storage


def get_arg(env, default):
    """Return the value of environment variable env, or default if unset or empty."""
    return os.getenv(env) if os.getenv(env, '') != '' else default
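# Hypothetical example: with DB_HOST=db.example.com exported in the
# environment, get_arg('DB_HOST', '127.0.0.1') returns 'db.example.com';
# if DB_HOST is unset or empty it falls back to '127.0.0.1'.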
def make_connection(host='127.0.0.1', port=5432, user='postgres',
password='postgres', dbname='postgres'):
"""Connect to a postgresql db."""
return psycopg2.connect(host=host, port=port, user=user,
password=password, dbname=dbname)
def build_connection(args):
"""Make the db connection with an args object."""
conn = make_connection(host=args.host,
port=args.port,
user=args.user,
password=args.password,
dbname=args.dbname)
return conn
def parse_args(parser):
"""Parsing command line args."""
args = parser.parse_args()
args.host = get_arg('DB_HOST', args.host)
args.port = get_arg('DB_PORT', args.port)
args.user = get_arg('DB_USER', args.user)
args.password = get_arg('DB_PASSWORD', args.password)
args.dbname = get_arg('DB_DBNAME', args.dbname)
args.mongoURI = get_arg('MONGO_URI', args.mongoURI)
args.rankval = check_positive_integer(get_arg('RANK_VAL', args.rankval))
args.itsval = check_iterations_value(get_arg('ITS_VAL', args.itsval))
args.lambdaval = check_lambda_value(get_arg('LAMBDA_VAL', args.lambdaval))
return args
def check_positive_integer(string):
"""Checking that values are poitive integers."""
fval = float(string)
ival = int(fval)
if ival != fval or ival <= 0:
msg = "%r is not a positive integer" % string
raise argparse.ArgumentTypeError(msg)
return ival
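# e.g. check_positive_integer('4') returns 4, while '0', '-3' and '2.5'
# all raise argparse.ArgumentTypeError.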
def check_iterations_value(string):
"""Checking that the value is a positive integer. Warn if too large."""
pval = check_positive_integer(string)
if pval > 10:
logger.get_logger().warning("Large iterations causes slow model build")
return pval
def check_lambda_value(string):
"""Check that lambda is positive. Warn if too large."""
fval = float(string)
if fval <= 0:
msg = "%r is not positive" % string
raise argparse.ArgumentTypeError(msg)
if fval > 1:
logger.get_logger().warning("Optimal lambda is commonly in (0,1)")
return fval
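# e.g. check_lambda_value('0.05') returns 0.05; '0' and '-0.1' raise
# argparse.ArgumentTypeError, and values above 1 are accepted with a warning.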
def main(arguments):
"""Begin running the the modeller."""
loggers = logger.get_logger()
# set up the spark configuration
loggers.debug("Connecting to Spark")
conf = (pyspark.SparkConf().setAppName("JiminyModeler")
.set('spark.executor.memory', '1G')
.set('spark.driver.memory', '1G')
.set('spark.driver.maxResultSize', '1G'))
# get the spark context
spark = pyspark.sql.SparkSession.builder.config(conf=conf).getOrCreate()
sc = spark.sparkContext
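    # Note: modeller presumably builds its ALS models through the RDD-based
    # MLlib API, which is why the SparkContext is extracted from the session.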
# set up SQL connection
try:
con = build_connection(arguments)
except IOError:
loggers.error("Could not connect to data store")
sys.exit(1)
# fetch the data from the db
cursor = con.cursor()
cursor.execute("SELECT * FROM ratings")
ratings = cursor.fetchall()
loggers.info("Fetched data from table")
# create an RDD of the ratings data
ratingsRDD = sc.parallelize(ratings)
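    # Assumed ratings row layout: (id, user_id, item_id, rating, timestamp);
    # x[4] below is therefore the timestamp and x[1:4] the (user, item,
    # rating) triple used for training.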
# getting the largest timestamp. We use this to determine new entries later
max_timestamp = ratingsRDD.map(lambda x: x[4]).max()
    # keep only the (user, item, rating) columns; the other columns are dropped
ratingsRDD = ratingsRDD.map(lambda x: (x[1], x[2], x[3]))
# split the RDD into 3 sections: training, validation and testing
estimator = modeller.Estimator(ratingsRDD)
    # DISABLE_FAST_TRAIN may be a bool (argparse default) or a string (env var).
    disable_fast = str(get_arg('DISABLE_FAST_TRAIN', arguments.slowtrain)).lower()
    if disable_fast in ('true', '1'):
        loggers.warning("Any ALS parameters given on the command line will "
                        "not be used when fast train is disabled.")
# basic parameter selection
loggers.info('Using slow training method')
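        # Estimator.run presumably trains a model for every rank/lambda
        # combination below (8 candidates) and returns the best-scoring
        # parameter dict; see modeller.py for the selection criterion.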
parameters = estimator.run(ranks=[2, 4, 6, 8],
lambdas=[0.01, 0.05, 0.09, 0.13],
iterations=[2])
else:
# override basic parameters for faster testing
loggers.info('Using fast training method')
parameters = {'rank': arguments.rankval,
'lambda': arguments.lambdaval,
'iteration': arguments.itsval}
# train the model
model = modeller.Trainer(data=ratingsRDD,
rank=parameters['rank'],
iterations=parameters['iteration'],
lambda_=parameters['lambda'],
seed=42).train()
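    # The fixed seed keeps ALS training reproducible for a given data set.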
loggers.info('Model has been trained')
# write the model to model store
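    # (MongoDBModelWriter is assumed to serialize the model's ALS factors to
    # the MongoDB instance at mongoURI, keyed by the version number.)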
model_version = 1
writer = storage.MongoDBModelWriter(sc=sc, uri=arguments.mongoURI)
    writer.write(model=model, version=model_version)
    loggers.info(
        'Model version {} written to model store'.format(model_version))
while True:
        # This loop is the heart of the application; it runs until the
        # process is killed by the orchestration engine. On each pass it:
        # 1. checks whether a new model should be created
        # 2. if yes, creates a new model; if no, continues looping
        #    (perhaps with a delay)
        # 3. stores the new model
        # check whether a new model should be created:
# select the maximum time stamp from the ratings database
cursor.execute(
"SELECT timestamp FROM ratings ORDER BY timestamp DESC LIMIT 1;"
)
checking_max_timestamp = cursor.fetchone()[0]
        loggers.info(
            "The latest timestamp = {}".format(checking_max_timestamp))
if checking_max_timestamp > max_timestamp:
# build a new model
# first, fetch all new ratings
cursor.execute(
"SELECT * FROM ratings WHERE (timestamp > %s);",
(max_timestamp,))
new_ratings = cursor.fetchall()
max_timestamp = checking_max_timestamp
new_ratingsRDD = sc.parallelize(new_ratings)
new_ratingsRDD = new_ratingsRDD.map(lambda x: (x[1], x[2], x[3]))
ratingsRDD = ratingsRDD.union(new_ratingsRDD)
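            # The model is retrained from scratch on the full, unioned
            # ratings RDD rather than being updated incrementally.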
model_version += 1
loggers.info("Training model, version={}".format(model_version))
# train the model
model = modeller.Trainer(data=ratingsRDD,
rank=parameters['rank'],
iterations=parameters['iteration'],
lambda_=parameters['lambda'],
seed=42).train()
loggers.info("Model has been trained.")
writer.write(model=model, version=model_version)
            loggers.info(
                "Model version %d written to model store." % model_version)
else:
# sleep for 2 minutes
loggers.info("sleeping for 120 seconds")
time.sleep(120)
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='load data from postgresql db')
    parser.add_argument(
        '--host', default='127.0.0.1',
        help='the postgresql host address (default: 127.0.0.1). '
             'env variable: DB_HOST')
parser.add_argument(
'--dbname', default='postgres',
help='the database name to load with data. env variable: DB_DBNAME')
parser.add_argument(
'--port', default=5432, help='the postgresql port (default: 5432). '
'env variable: DB_PORT')
parser.add_argument(
'--user', default='postgres',
help='the user for the postgresql database (default: postgres). '
'env variable: DB_USER')
parser.add_argument(
'--password', default='postgres',
help='the password for the postgresql user (default: postgres). '
'env variable: DB_PASSWORD')
    parser.add_argument(
        '--mongo-uri', default='mongodb://localhost:27017', dest='mongoURI',
        help='the mongodb URI (default: mongodb://localhost:27017). '
             'env variable: MONGO_URI')
    parser.add_argument(
        '--disable-fast-train', dest='slowtrain', action='store_true',
        help='disable the faster training method; warning: the first run '
             'may be considerably slower.')
    parser.add_argument(
        '--rankval', default=6, type=check_positive_integer,
        help='fix the rank parameter of ALS (default: 6). '
             'env variable: RANK_VAL')
    parser.add_argument(
        '--itsval', default=2, type=check_iterations_value,
        help='fix the ALS iterations parameter (default: 2). '
             'env variable: ITS_VAL')
    parser.add_argument(
        '--lambdaval', default=0.01, type=check_lambda_value,
        help='fix the ALS lambda parameter (default: 0.01). '
             'env variable: LAMBDA_VAL')
args = parse_args(parser)
main(arguments=args)
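# Example invocation (hypothetical values; assumes a local PostgreSQL and
# MongoDB are reachable):
#   DB_HOST=127.0.0.1 DB_PASSWORD=postgres MONGO_URI=mongodb://localhost:27017 \
#       python app.py --rankval 6 --itsval 2 --lambdaval 0.01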