From cb84801bb86bf1d2dbe723e60ef5b18ead59420a Mon Sep 17 00:00:00 2001 From: Rob Savoye Date: Mon, 12 Feb 2024 16:46:52 -0700 Subject: [PATCH] fix: Refactor merging in interests to use asyncpg --- tm_admin/users/users.py | 391 +++++++++++++++++++++------------------- 1 file changed, 209 insertions(+), 182 deletions(-) diff --git a/tm_admin/users/users.py b/tm_admin/users/users.py index 29a3cbb0..ea607d22 100755 --- a/tm_admin/users/users.py +++ b/tm_admin/users/users.py @@ -33,9 +33,11 @@ from atpbar import atpbar from tm_admin.dbsupport import DBSupport from tm_admin.users.users_class import UsersTable -from osm_rawdata.postgres import uriParser, PostgresClient +from osm_rawdata.pgasync import PostgresClient from tm_admin.types_tm import Userrole from tqdm import tqdm +import tqdm.asyncio +import asyncio from codetiming import Timer import threading @@ -46,37 +48,46 @@ info = get_cpu_info() cores = info["count"] * 2 -def licensesThread( - data: list, - db: PostgresClient, -): - """Thread to handle importing - - Args: - data (list): The list of records to import - db (PostgresClient): A database connection - """ - array = "licenses" - column = "license" - - pbar = tqdm(data) - for record in pbar: - uid = record[0]['user'] - licenses = record[0]['license'] - # FIXME: current TM has this as an int, but it seems a user might agree to - # more than one license. The database expects an array already, it'll just - # have a single entry. - # sql = f" UPDATE users SET licenses = licenses||{licenses} WHERE id={uid}" - sql = f" UPDATE users SET licenses = ARRAY[{licenses}] WHERE id={uid}" - #print(sql) - try: - result = db.dbcursor.execute(f"{sql};") - except: - log.error(f"Couldn't execute query {sql}") - - return True - -def interestsThread( +# async def licensesThread( +# data: list, +# db: PostgresClient, +# ): +# """Thread to handle importing + +# Args: +# data (list): The list of records to import +# db (PostgresClient): A database connection +# """ +# array = "licenses" +# column = "license" + +# # sql = f"CREATE SERVER IF NOT EXISTS pg_rep_db FOREIGN DATA WRAPPER dblink_fdw OPTIONS (dbname 'tm4');" +# # data = await db.pg.execute(sql) +# # sql = f"CREATE USER MAPPING IF NOT EXISTS FOR rob SERVER pg_rep_db OPTIONS ( user 'rob' ,password 'fu=br');" + +# # sql = "SELECT users.*,ARRAY[user_licenses.license] AS license INTO tmp FROM users JOIN dblink('pg_rep_db','SELECT * FROM user_licenses') AS user_licenses(user_id bigint, license int) ON users.id=user_licenses.user_id;" +# # print(sql) +# # result = await db.pg.execute(sql) + +# # #sql = f"ALTER TABLE users RENAME TO users_bak; ALTER TABLE tmp RENAME TO users;" +# # #result = await db.pg.execute(sql) + +# pbar = tqdm.tqdm(data) +# for record in pbar: +# entry = eval(record[0]) +# uid = entry['user'] +# licenses = entry['license'] +# # FIXME: current TM has this as an int, but it seems a user might agree to +# # more than one license. The database expects an array already, it'll just +# # have a single entry. +# sql = f" UPDATE users SET licenses = licenses||{licenses} WHERE id={uid}" +# sql = f"SELECT users.*,ARRAY[user_licenses.license] AS license INTO tmp FROM users JOIN user_licenses ON users.id=user_licenses.user;" +# print(sql) +# result = await db.pg.execute(f"{sql};") + +# return True + +async def interestsThread( interests: list, db: PostgresClient, ): @@ -87,19 +98,16 @@ def interestsThread( db (PostgresClient): A database connection """ data = dict() - pbar = tqdm(interests) + pbar = tqdm.tqdm(interests) for record in pbar: - uid = record[0]['user_id'] - interests = record[0]['interest_id'] - sql = f" UPDATE users SET interests = interests||{interests} WHERE id={uid}" - # print(sql) - try: - result = db.dbcursor.execute(f"{sql};") - except: - log.error(f"Couldn't execute query {sql}") + for id, array in record.items(): + sql = f" UPDATE users SET interests = interests||ARRAY{array} WHERE id={id}" + # print(sql) + result = await db.execute(sql) + return True -def favoritesThread( +async def favoritesThread( favorites: list, db: PostgresClient, ): @@ -112,14 +120,10 @@ def favoritesThread( data = dict() pbar = tqdm(favorites) for record in pbar: - uid = record[0] - projects = record[1][0] - sql = f" UPDATE users SET favorite_projects = ARRAY{projects} WHERE id={uid}" - # print(sql) - try: - result = db.dbcursor.execute(f"{sql};") - except: - log.error(f"Couldn't execute query {sql}") + for id, array in record.items(): + sql = f" UPDATE users SET favorite_projects = ARRAY{projects} WHERE id={uid}" + # print(sql) + result = await db.execute(f"{sql};") return True @@ -139,130 +143,137 @@ def __init__(self, self.pg = None self.profile = UsersTable() self.types = dir(tm_admin.types_tm) - super().__init__('users', dburi) + # super().__init__('users', dburi) + super().__init__('users') - def mergeInterests(self): + async def mergeInterests(self, + inpg: PostgresClient, + ): + """ + Merge the user_interests table from the Tasking Manager + + Args: + inuri (str): The input database + """ table = 'user_interests' + timer = Timer(initial_text="Merging user_interests table...", + text="merging liceneses table took {seconds:.0f}s", + logger=log.debug, + ) log.info(f"Merging interests table...") - # One database connection per thread - tmpg = list() - for i in range(0, cores + 1): - # FIXME: this shouldn't be hardcoded - tmpg.append(PostgresClient('localhost/tm_admin')) - pg = PostgresClient('localhost/tm4') - sql = f"SELECT row_to_json({table}) as row FROM {table}" + timer.start() + sql = "SELECT * FROM user_interests ORDER BY user_id" + result = await inpg.execute(sql) + + # FIXME: this SQL turns out to be painfully slow, and we can create the array + # in python faster. + # await self.copyTable(table, remote) + # await self.renameTable(table) + # sql = f"SELECT row_to_json({table}) as row FROM {table}" + # Not all records in this table have data # sql = f"SELECT u.user_id,ARRAY(SELECT ARRAY(SELECT c.interest_id FROM {table} c WHERE c.user_id = u.user_id)) AS user_id FROM {table} u;" - print(sql) - try: - result = pg.dbcursor.execute(sql) - except: - log.error(f"Couldn't execute query! {sql}") - return False - - result = pg.dbcursor.fetchall() + data = list() + entry = dict() + prev = None + # Restructure the data into a list, so we can more easily chop the data + # into multiple smaller pieces, one for each thread. + for record in result: + if prev == record['user_id']: + entry[record['user_id']].append(record['interest_id']) + else: + if len(entry) != 0: + data.append(entry) + prev = record['user_id'] + entry = {record['user_id']: [record['interest_id']]} + timer.stop() - entries = len(result) - log.debug(f"There are {entries} entries in {table}") + # await remote.pg.close() + #pg = PostgresClient() + entries = len(data) chunk = round(entries / cores) - - index = 0 - with concurrent.futures.ThreadPoolExecutor(max_workers=cores) as executor: - # futures = list() - block = 0 - while block <= entries: - #log.debug(f"Dispatching Block %d:%d" % (block, block + chunk)) - #interestsThread(result, tmpg[0]) - executor.submit(interestsThread, result[block : block + chunk], tmpg[index]) - # futures.append(result) - block += chunk - index += 1 - # tqdm(futures, desc=f"Dispatching Block {block}:{block + chunk}", total=chunk): - # future.result() - executor.shutdown() - - # timer.stop + start = 0 + async with asyncio.TaskGroup() as tg: + for block in range(0, entries, chunk): + # for index in range(0, cores): + outpg = PostgresClient() + await outpg.connect('localhost/tm_admin') + log.debug(f"Dispatching thread {block}:{block + chunk}") + # await interestsThread(data, outpg) + task = tg.create_task(interestsThread(data[block:block + chunk], outpg)) + start += chunk + return True - def getPage(self, - offset: int, - count: int, - pg: PostgresClient, - table: str, - ): + async def mergeLicenses(self): """ - Return a page of data from the table. - - Args: - offset (int): The starting record - count (int): The number of records - pg (PostgresClient): Database connection for the input data - table (str): The table to query for data - - Returns: - (list): The results of the query + Merge data from the TM user_licenses table into TM Admin. The + fastest way to do a bulk update of a table is by copying the + remote database table into the local database, and then merging + into a new temporary table and then renaming it. """ - # It turns out to be much faster to use the columns specified in the - # SELECT statement, and construct our own dictionary than using row_to_json(). - #columns = "user_licenses.user, license" - - sql = f"SELECT row_to_json({table}) as row FROM {table} ORDER BY user LIMIT {count} OFFSET {offset}" - # sql = f"SELECT {columns} FROM {table} ORDER BY user LIMIT {count} OFFSET {offset}" - print(sql) - pg.dbcursor.execute(sql) - result = pg.dbcursor.fetchall() - # data = list() - # Since we're not using row_to_json(), build a data structure - # for record in result: - # table = dict(zip(columns, record)) - # data.append(table) - - return result - - def mergeLicenses(self): - """Merge data from the TM user_licenses table into TM Admin.""" table = 'user_licenses' # log.info(f"Merging licenses table...") - timer = Timer(initial_text="Merging licenses table...", - text="merging liceneses table took {seconds:.0f}s", + timer = Timer(initial_text="Merging user_licenses table...", + text="merging user_liceneses table took {seconds:.0f}s", logger=log.debug, ) - timer.start() - - sql = f"SELECT row_to_json({table}) as row FROM {table}" - # One database connection per thread - adminpg = list() - for i in range(0, cores + 1): - adminpg.append(PostgresClient('localhost/tm_admin')) + pg = PostgresClient() + await pg.connect('localhost/tm4') + self.columns = await pg.getColumns(table) + print(f"COLUMNS: {self.columns}") + await pg.pg.close() + await pg.connect('localhost/tm_admin') + + # await self.copyTable(table, pg) + # log.warning(f"Merging tables can take considerable time...") + + # cleanup old temporary tables + drop = ["DROP TABLE IF EXISTS users_bak", + "DROP TABLE IF EXISTS foo"] + # result = await pg.pg.executemany(drop) + sql = f"DROP TABLE IF EXISTS users_bak" + result = await pg.execute(sql) + sql = f"DROP TABLE IF EXISTS user_licenses" + result = await pg.execute(sql) + sql = f"DROP TABLE IF EXISTS new_users" + result = await pg.execute(sql) + + # We need to use DBLINK + sql = f"CREATE EXTENSION IF NOT EXISTS dblink;" + data = await pg.execute(sql) - # just one thread to read the data - tmpg = PostgresClient('localhost/tm4') - try: - result = tmpg.dbcursor.execute(sql) - except: - log.error(f"Couldn't execute query! {sql}") - - data = tmpg.dbcursor.fetchall() - entries = len(data) - log.debug(f"There are {entries} entries in {table}") - chunk = round(entries / cores) - - index = 0 - with concurrent.futures.ThreadPoolExecutor( - thread_name_prefix="users", - max_workers=cores) as executor: - index = 0 - for block in range(0, entries, chunk): - data = self.getPage(block, chunk, tmpg, table) - #result = licensesThread(data, adminpg[0]) - result = executor.submit(licensesThread, data, adminpg[index]) - index += 1 - executor.shutdown() - - timer.stop - return True + timer.start() + dbuser = pg.dburi["dbuser"] + dbpass = pg.dburi["dbpass"] + sql = f"CREATE SERVER IF NOT EXISTS pg_rep_db FOREIGN DATA WRAPPER dblink_fdw OPTIONS (dbname 'tm4');" + data = await pg.execute(sql) + + sql = f"CREATE USER MAPPING IF NOT EXISTS FOR {dbuser} SERVER pg_rep_db OPTIONS ( user '{dbuser}', password '{dbpass}');" + result = await pg.execute(sql) + + # Copy table from remote database so JOIN is faster when it's in the + # same database + log.warning(f"Copying a remote table is slow, but faster than remote access......") + sql = f"SELECT * INTO user_licenses FROM dblink('pg_rep_db','SELECT * FROM user_licenses') AS user_licenses(user_id bigint, license int)" + print(pg.dburi) + # print(sql) + result = await pg.execute(sql) + + # JOINing into a new table is much faster than doing an UPDATE + sql = f"SELECT users.*,ARRAY[user_licenses.license] INTO new_users FROM users JOIN user_licenses ON users.id=user_licenses.user_id" + result = await pg.execute(sql) + + # self.renameTable(table, pg) + # sql = f"ALTER TABLE users RENAME TO users_bak;" + # result = await pg.execute(sql) + # sql = f"ALTER TABLE new_users RENAME TO users;" + # result = await pg.execute(sql) + # sql = f"DROP TABLE IF EXISTS user_licenses" + # result = await pg.execute(sql) + timer.stop() - def mergeTeam(self): + async def mergeTeam(self): table = 'team_members' # FIXME: this shouldn't be hardcoded! log.info(f"Merging team members table...") @@ -272,11 +283,10 @@ def mergeTeam(self): sql = f"SELECT row_to_json({table}) as row FROM {table}" #print(sql) try: - result = pg.dbcursor.execute(sql) + result = await pg.execute(sql) except: log.error(f"Couldn't execute query! {sql}") - result = pg.dbcursor.fetchall() pbar = tqdm(result) for record in pbar: func = record[0]['function'] @@ -284,30 +294,29 @@ def mergeTeam(self): sql = f"UPDATE {self.table} SET team_members.team={record[0]['team_id']}, team_members.active={record[0]['active']}, team_members.function='{tmfunc.name}' WHERE id={record[0]['user_id']}" # print(f"{sql};") try: - result = self.pg.dbcursor.execute(sql) + result = await self.pg.execute(sql) except: log.error(f"Couldn't execute query! '{sql}'") timer.stop() return True - def mergeFavorites(self): + async def mergeFavorites(self): table = 'project_favorites' log.info(f"Merging favorites table...") # FIXME: this shouldn't be hardcoded! timer = Timer(text="merging favorites table took {seconds:.0f}s") timer.start() - pg = PostgresClient('localhost/tm4') + pg = PostgresClient() + await pg.connect('localhost/tm4') sql = f"SELECT u.user_id,ARRAY(SELECT ARRAY(SELECT c.project_id FROM {table} c WHERE c.user_id = u.user_id)) AS user_id FROM {table} u;" #sql = f"SELECT row_to_json({table}) as row FROM {table} ORDER BY user_id" # print(sql) try: - result = pg.dbcursor.execute(sql) + result = pg.execute(sql) except: log.error(f"Couldn't execute query! {sql}") - result = pg.dbcursor.fetchall() - entries = len(result) log.debug(f"There are {entries} entries in {table}") chunk = round(entries / cores) @@ -334,7 +343,7 @@ def mergeFavorites(self): return True # These are just convience wrappers to support the REST API. - def updateRole(self, + async def updateRole(self, id: int, role: Userrole, ): @@ -348,7 +357,7 @@ def updateRole(self, role = Userrole(role) return self.updateColumn(id, {'role': role.name}) - def updateMappingLevel(self, + async def updateMappingLevel(self, id: int, level: Mappinglevel, ): @@ -360,9 +369,10 @@ def updateMappingLevel(self, level (Mappinglevel): The new level. """ mlevel = Mappinglevel(level) - return self.updateColumn(id, {'mapping_level': mlevel.name}) + result = await self.updateColumn(id, {'mapping_level': mlevel.name}) + return result - def updateExpert(self, + async def updateExpert(self, id: int, mode: bool, ): @@ -373,9 +383,10 @@ def updateExpert(self, id (int): The users ID. mode (bool): The new mode.. """ - return self.updateColumn(id, {'expert_mode': mode}) + result = await self.updateColumn(id, {'expert_mode': mode}) + return result - def getRegistered(self, + async def getRegistered(self, start: datetime, end: datetime, ): @@ -391,12 +402,24 @@ def getRegistered(self, """ where = f" date_registered > '{start}' and date_registered < '{end}'" - return self.getByWhere(where) + result = self.getByWhere(where) + return result - def mergeAuxTables(self): + async def mergeAuxTables(self, + inuri: str, + outuri: str, + ): """ Merge more tables from TM into the unified users table. + + Args: + inturi (str): The input database + outuri (str): The output database """ + await self.connect(outuri) + inpg = PostgresClient() + await inpg.connect(inuri) + # if self.mergeTeam(): # log.info("UserDB.mergeTeams worked!") @@ -404,17 +427,18 @@ def mergeAuxTables(self): # log.info("UserDB.mergeFavorites worked!") # These may take a long time to complete - #if self.mergeInterests(): - # log.info("UserDB.mergeInterests worked!") + await self.mergeInterests(inpg) + log.info("UserDB.mergeInterests worked!") - if self.mergeLicenses(): - log.info("UserDB.mergeLicenses worked!") + # result = await self.mergeLicenses() + # log.info("UserDB.mergeLicenses worked!") -def main(): +async def main(): """This main function lets this class be run standalone by a bash script.""" parser = argparse.ArgumentParser() parser.add_argument("-v", "--verbose", nargs="?", const="0", help="verbose output") - parser.add_argument("-u", "--uri", default='localhost/tm_admin', help="Database URI") + parser.add_argument("-i", "--inuri", default='localhost/tm4', help="Database URI") + parser.add_argument("-o", "--outuri", default='localhost/tm_admin', help="Database URI") # parser.add_argument("-r", "--reset", help="Reset Sequences") args = parser.parse_args() @@ -434,8 +458,9 @@ def main(): stream=sys.stdout, ) - user = UsersDB(args.uri) - user.mergeAuxTables() + user = UsersDB() + # user.connect(args.uri) + await user.mergeAuxTables(args.inuri, args.outuri) # user.resetSequence() # all = user.getAll() @@ -457,4 +482,6 @@ def main(): if __name__ == "__main__": """This is just a hook so this file can be run standalone during development.""" - main() + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(main())