diff --git a/tm_admin/users/users.py b/tm_admin/users/users.py index 3884aafa..9fad6043 100755 --- a/tm_admin/users/users.py +++ b/tm_admin/users/users.py @@ -28,6 +28,8 @@ from dateutil.parser import parse import tm_admin.types_tm from tm_admin.types_tm import Userrole, Mappinglevel +import concurrent.futures +from cpuinfo import get_cpu_info from tm_admin.dbsupport import DBSupport from tm_admin.users.users_class import UsersTable @@ -37,6 +39,29 @@ # Instantiate logger log = logging.getLogger(__name__) +# The number of threads is based on the CPU cores +info = get_cpu_info() +cores = info["count"] + +def importThread( + data: list, + db: PostgresClient, +): + """Thread to handle importing + + Args: + data (list): The list of records to import + db (PostgresClient): A database connection + """ + for record in data: + sql = f" UPDATE users SET licenses = ARRAY[{record[0]['license']}] WHERE id={record[0]['user']}" + # print(sql) + try: + result = db.dbcursor.execute(f"{sql};") + except: + return False + + return True class UsersDB(DBSupport): def __init__(self, @@ -77,43 +102,56 @@ def mergeInterests(self): data[entry['user_id']] = list() data[entry['user_id']].append(entry['interest_id']) - for uid, value in data.items(): - sql = f" UPDATE users SET interests = ARRAY{str(value)} WHERE id={uid}" - print(sql) - try: - result = self.pg.dbcursor.execute(f"{sql};") - except: - return False + for uid, value in data.items(): + sql = f" UPDATE users SET interests = ARRAY{str(value)} WHERE id={uid}" + print(sql) + try: + result = self.pg.dbcursor.execute(f"{sql};") + except: + return False return True def mergeLicenses(self): + """Merge data from the TM user_licenses table into TM Admin.""" table = 'user_licenses' - # FIXME: this shouldn't be hardcoded! - pg = PostgresClient('localhost/tm4') sql = f"SELECT row_to_json({table}) as row FROM {table}" - # print(sql) + # One database connection per thread + tmpg = list() + for i in range(0, cores + 1): + tmpg.append(PostgresClient('localhost/tm_admin')) + + # just one thread to read the data + pg = PostgresClient('localhost/tm4') try: result = pg.dbcursor.execute(sql) except: log.error(f"Couldn't execute query! {sql}") return False - result = pg.dbcursor.fetchall() - - for record in result: - sql = f" UPDATE users SET licenses = ARRAY[{record[0]['license']}] WHERE id={record[0]['user']}" - print(sql) - try: - result = self.pg.dbcursor.execute(f"{sql};") - except: - return False + data = pg.dbcursor.fetchall() + entries = len(data) + log.debug(f"There are {entries} entries in {table}") + chunk = round(entries / cores) + + # if True: + # importThread(data, tmpg[0]) + index = 0 + with concurrent.futures.ThreadPoolExecutor(max_workers=cores) as executor: + block = 0 + while block <= entries: + log.debug("Dispatching Block %d:%d" % (block, block + chunk)) + result = executor.submit(importThread, data[block : block + chunk], tmpg[index]) + block += chunk + index += 1 + executor.shutdown() return True def mergeFavorites(self): table = 'project_favorites' - pg = PostgresClient(inuri) + # FIXME: this shouldn't be hardcoded! + pg = PostgresClient('localhost/tm4') sql = f"SELECT row_to_json({table}) as row FROM {table}" # print(sql) try: @@ -122,6 +160,24 @@ def mergeFavorites(self): log.error(f"Couldn't execute query! {sql}") return False + result = pg.dbcursor.fetchall() + data = dict() + for record in result: + entry = record[0] # there's only one item in the input data + if entry['user_id'] not in data: + data[entry['user_id']] = list() + data[entry['user_id']].append(entry['project_id']) + + for uid, value in data.items(): + sql = f" UPDATE users SET favorite_projects = ARRAY{str(value)} WHERE id={uid}" + print(sql) + try: + result = self.pg.dbcursor.execute(f"{sql};") + except: + return False + + return True + # These are just convience wrappers to support the REST API. def updateRole(self, id: int, @@ -208,12 +264,16 @@ def main(): user = UsersDB(args.uri) - if user.mergeInterests(): - log.info("UserDB.mergeInterests worked!") + # These may take a long time to complete + # if user.mergeInterests(): + # log.info("UserDB.mergeInterests worked!") if user.mergeLicenses(): log.info("UserDB.mergeLicenses worked!") + # if user.mergeFavorites(): + # log.info("UserDB.mergeFavorites worked!") + # user.resetSequence() # all = user.getAll() # # Don't pass id, let postgres auto increment