diff --git a/app/app.py b/app/app.py index 8dd5148d..817d12ab 100644 --- a/app/app.py +++ b/app/app.py @@ -10,6 +10,7 @@ from seafevents.seafevent_server.seafevent_server import SeafEventServer from seafevents.app.config import ENABLE_METADATA_MANAGEMENT from seafevents.seasearch.index_task.filename_index_updater import RepoFilenameIndexUpdater +from seafevents.repo_metadata.face_recognition_updater import RepoFaceClusterUpdater class App(object): @@ -41,6 +42,7 @@ def __init__(self, config, ccnet_config, seafile_config, self._index_master = RepoMetadataIndexMaster(config) self._index_worker = RepoMetadataIndexWorker(config) self._slow_task_handler = SlowTaskHandler(config) + self._repo_face_cluster_updater = RepoFaceClusterUpdater(config) self._repo_filename_index_updater = RepoFilenameIndexUpdater(config) def serve_forever(self): @@ -65,4 +67,5 @@ def serve_forever(self): self._index_master.start() self._index_worker.start() self._slow_task_handler.start() + self._repo_face_cluster_updater.start() self._repo_filename_index_updater.start() diff --git a/repo_data/__init__.py b/repo_data/__init__.py index ae09e1e3..15d6e3c5 100644 --- a/repo_data/__init__.py +++ b/repo_data/__init__.py @@ -39,6 +39,20 @@ def _get_repo_id_commit_id(self, start, count): finally: session.close() + def _get_mtime_by_repo_ids(self, repo_ids): + session = self.db_session() + try: + if len(repo_ids) == 1: + cmd = """SELECT repo_id, update_time FROM RepoInfo WHERE repo_id = '%s'""" % repo_ids[0] + else: + cmd = """SELECT repo_id, update_time FROM RepoInfo WHERE repo_id IN {}""".format(tuple(repo_ids)) + res = session.execute(text(cmd)).fetchall() + return res + except Exception as e: + raise e + finally: + session.close() + def _get_all_trash_repo_list(self): session = self.db_session() try: @@ -114,6 +128,13 @@ def get_all_repo_list(self): logger.error(e) return self._get_all_repo_list() + def get_mtime_by_repo_ids(self, repo_ids): + try: + return self._get_mtime_by_repo_ids(repo_ids) + 
except Exception as e: + logger.error(e) + return self._get_mtime_by_repo_ids(repo_ids) + def get_all_trash_repo_list(self): try: return self._get_all_trash_repo_list() diff --git a/repo_metadata/face_recognition_manager.py b/repo_metadata/face_recognition_manager.py index 8ae9b13d..de767e86 100644 --- a/repo_metadata/face_recognition_manager.py +++ b/repo_metadata/face_recognition_manager.py @@ -1,11 +1,14 @@ import json import logging +from datetime import datetime +import numpy as np from seafevents.utils import get_opt_from_conf_or_env from seafevents.db import init_db_session_class from seafevents.repo_metadata.metadata_server_api import MetadataServerAPI from seafevents.repo_metadata.image_embedding_api import ImageEmbeddingAPI -from seafevents.repo_metadata.utils import METADATA_TABLE, FACES_TABLE, query_metadata_rows, get_face_embeddings, face_compare +from seafevents.repo_metadata.utils import METADATA_TABLE, FACES_TABLE, query_metadata_rows, get_face_embeddings, get_faces_rows, get_cluster_by_center, update_face_cluster_time +from seafevents.repo_metadata.constants import METADATA_OP_LIMIT logger = logging.getLogger(__name__) @@ -33,12 +36,6 @@ def init_face_recognition(self, repo_id): if not query_result: return - metadata = self.metadata_server_api.get_metadata(repo_id) - tables = metadata.get('tables', []) - if not tables: - return - faces_table_id = [table['id'] for table in tables if table['name'] == FACES_TABLE.name][0] - obj_id_to_rows = {} for item in query_result: obj_id = item[METADATA_TABLE.columns.obj_id.name] @@ -47,42 +44,94 @@ def init_face_recognition(self, repo_id): obj_id_to_rows[obj_id].append(item) obj_ids = list(obj_id_to_rows.keys()) - known_faces = [] - for obj_id in obj_ids: - records = obj_id_to_rows.get(obj_id, []) - known_faces = self.face_recognition(obj_id, records, repo_id, faces_table_id, known_faces) - - def face_recognition(self, obj_id, records, repo_id, faces_table_id, used_faces): - embeddings = 
def face_cluster(self, repo_id):
    """Re-cluster all stored face vectors of a repo and sync the faces table.

    Reads every metadata row that has face vectors, clusters the vectors with
    HDBSCAN, then for each resulting cluster either updates the nearest
    previously stored cluster (matched by its centre) or inserts a new faces
    row, and links the cluster to the metadata rows it contains. Previously
    stored clusters that were not matched are deleted at the end.

    Args:
        repo_id: id of the repo whose faces are clustered.

    Returns:
        None. The method returns early when scikit-learn is missing, when
        there is nothing to cluster, or when the faces table cannot be found.
    """
    try:
        from sklearn.cluster import HDBSCAN
    except ImportError:
        logger.warning('Package scikit-learn is not installed. ')
        return

    # The run is stamped before any work is done (ordering kept from the
    # original implementation).
    current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    update_face_cluster_time(self._db_session_class, repo_id, current_time)

    id_column = METADATA_TABLE.columns.id.name
    vectors_column = METADATA_TABLE.columns.face_vectors.name
    sql = f'SELECT `{id_column}`, `{vectors_column}` FROM `{METADATA_TABLE.name}` WHERE `{vectors_column}` IS NOT NULL'
    query_result = query_metadata_rows(repo_id, self.metadata_server_api, sql)
    if not query_result:
        return

    metadata = self.metadata_server_api.get_metadata(repo_id)
    tables = metadata.get('tables', [])
    if not tables:
        return
    # Indexing ``[...][0]`` raised IndexError when the faces table was missing.
    faces_table_id = next(
        (table['id'] for table in tables if table['name'] == FACES_TABLE.name),
        None,
    )
    if faces_table_id is None:
        logger.warning('repo: %s, faces table not found, skip face cluster', repo_id)
        return

    # Flatten to one entry per detected face; row_ids[i] is the metadata row
    # that vectors[i] came from.
    vectors = []
    row_ids = []
    for item in query_result:
        row_id = item[id_column]
        face_vectors = json.loads(item[vectors_column]) or []
        for face_vector in face_vectors:
            vectors.append(face_vector)
            row_ids.append(row_id)

    min_cluster_size = 5
    if len(vectors) < min_cluster_size:
        # Rows whose stored value is an empty list are NOT NULL but contribute
        # no samples, and HDBSCAN raises on empty / too-small inputs.
        # NOTE(review): existing clusters are left untouched in this case.
        logger.info('repo: %s, not enough face vectors to cluster', repo_id)
        return

    old_cluster = get_faces_rows(repo_id, self.metadata_server_api)
    clt = HDBSCAN(min_cluster_size=min_cluster_size)
    clt.fit(vectors)
    labels = clt.labels_

    faces_id_key = FACES_TABLE.columns.id.name
    faces_vector_key = FACES_TABLE.columns.vector.name
    for label_id in np.unique(labels):
        idxs = np.where(labels == label_id)[0]
        # A photo with several faces in the same cluster shows up more than
        # once; link it only once while keeping first-seen order.
        related_row_ids = list(dict.fromkeys(row_ids[i] for i in idxs))

        # Label -1 is HDBSCAN's noise bucket: it gets a faces row without a
        # vector and is never matched against a stored cluster.
        face_row = dict()
        if label_id != -1:
            cluster_center = np.mean([vectors[i] for i in idxs], axis=0)
            face_row[faces_vector_key] = json.dumps(cluster_center.tolist())
            cluster = get_cluster_by_center(cluster_center, old_cluster)
            if cluster:
                cluster_id = cluster[faces_id_key]
                # A stored cluster can be reused by only one new cluster.
                old_cluster = [item for item in old_cluster if item[faces_id_key] != cluster_id]
                face_row[faces_id_key] = cluster_id
                self.metadata_server_api.update_rows(repo_id, faces_table_id, [face_row])
                row_id_map = {
                    cluster_id: related_row_ids
                }
                self.metadata_server_api.update_link(repo_id, FACES_TABLE.link_id, faces_table_id, row_id_map)
                continue

        result = self.metadata_server_api.insert_rows(repo_id, faces_table_id, [face_row])
        row_id = result.get('row_ids')[0]
        row_id_map = {
            row_id: related_row_ids
        }
        self.metadata_server_api.insert_link(repo_id, FACES_TABLE.link_id, faces_table_id, row_id_map)

    # Whatever is still in old_cluster was not matched by any new cluster.
    need_delete_row_ids = [item[faces_id_key] for item in old_cluster]
    if need_delete_row_ids:
        self.metadata_server_api.delete_rows(repo_id, faces_table_id, need_delete_row_ids)
class FaceClusterUpdaterTimer(Thread):
    """Thread that registers the periodic face-cluster job and starts its scheduler."""

    def __init__(self, face_recognition_manager, session):
        """Store the collaborators handed to ``update_face_cluster`` on each run.

        Args:
            face_recognition_manager: object passed as first job argument.
            session: object passed as second job argument.
        """
        super(FaceClusterUpdaterTimer, self).__init__()
        self.face_recognition_manager = face_recognition_manager
        self.session = session

    def run(self):
        """Schedule ``update_face_cluster`` on a cron trigger and start the scheduler."""
        sched = GeventScheduler()
        # Use the module logger (like the rest of this file) rather than the
        # root ``logging`` functions, so records carry this module's name.
        logger.info('Start to update face cluster...')
        try:
            # NOTE(review): with only day_of_week given, APScheduler should
            # default the finer fields to their minimum, i.e. run once a day
            # at 00:00:00 - confirm against the CronTrigger documentation.
            sched.add_job(update_face_cluster, CronTrigger(day_of_week='*'),
                          args=(self.face_recognition_manager, self.session))
        except Exception as e:
            logger.exception('periodical update face cluster error: %s', e)
            # Nothing was scheduled, so there is no point starting the scheduler.
            return

        sched.start()
def feature_distance(feature1, feature2):
    """Return the squared Euclidean distance between two feature vectors.

    The element-wise differences are squared and summed along axis 0; no
    square root is taken.
    """
    diff = np.subtract(feature1, feature2)
    return np.sum(np.square(diff), 0)


def get_cluster_by_center(center, clusters, threshold=1):
    """Return the stored cluster whose vector is nearest to ``center``.

    Args:
        center: centre vector of a freshly computed cluster.
        clusters: iterable of face rows (dicts); each may hold a JSON-encoded
            vector under the faces table's vector column name.
        threshold: a cluster only matches when its squared distance to
            ``center`` is strictly below this value. Defaults to 1, the
            value that used to be hard-coded.

    Returns:
        The nearest matching cluster row, or ``None`` when no row is close
        enough (or ``clusters`` is empty).
    """
    center_shape = np.shape(center)
    nearest_cluster = None
    # Starting the running minimum at the threshold is equivalent to checking
    # ``distance < threshold and distance < min_distance`` on every row.
    min_distance = threshold
    for cluster in clusters:
        vector = cluster.get(FACES_TABLE.columns.vector.name)
        if not vector:
            continue

        vector = json.loads(vector)
        if np.shape(vector) != center_shape:
            # A stored vector of another shape cannot be compared; subtracting
            # it would either raise or silently broadcast to a wrong distance.
            continue

        distance = feature_distance(center, vector)
        if distance < min_distance:
            min_distance = distance
            nearest_cluster = cluster
    return nearest_cluster


def get_faces_rows(repo_id, metadata_server_api):
    """Return every row of the faces table, or an empty list when the query yields nothing."""
    sql = f'SELECT * FROM `{FACES_TABLE.name}`'
    return query_metadata_rows(repo_id, metadata_server_api, sql) or []
def get_repo_face_recognition_status(repo_id, session):
    """Return the ``face_recognition_enabled`` value stored for a repo.

    Args:
        repo_id: id of the repo to look up in ``repo_metadata``.
        session: callable returning a context-managed database session.

    Returns:
        The first column of the matching row, or ``None`` when the repo has
        no ``repo_metadata`` record.
    """
    # Bind the id instead of formatting it into the SQL string, matching the
    # other queries in this module and closing an injection hole.
    sql = "SELECT face_recognition_enabled FROM repo_metadata WHERE repo_id = :repo_id"
    with session() as db_session:
        record = db_session.execute(text(sql), {'repo_id': repo_id}).fetchone()

    return record[0] if record else None


def get_face_recognition_enabled_repo_list(session, start, count):
    """Return one page of repos that have face recognition enabled.

    Args:
        session: callable returning a context-managed database session.
        start: row offset of the page.
        count: maximum number of rows in the page.

    Returns:
        Rows of ``(repo_id, last_face_cluster_time)``.
    """
    # Without ORDER BY the row order of a LIMIT query is unspecified, so
    # consecutive pages could repeat or skip repos; order by repo_id to make
    # paging stable.
    cmd = """SELECT repo_id, last_face_cluster_time FROM repo_metadata WHERE face_recognition_enabled = True ORDER BY repo_id limit :start, :count"""
    with session() as db_session:
        res = db_session.execute(text(cmd), {'start': start, 'count': count}).fetchall()

    return res
def update_face_cluster_time(session, repo_id, update_time):
    """Store ``update_time`` as the ``last_face_cluster_time`` of a repo.

    Args:
        session: callable returning a context-managed database session.
        repo_id: id of the ``repo_metadata`` row to update.
        update_time: value written to ``last_face_cluster_time``.
    """
    statement = text("""UPDATE repo_metadata SET last_face_cluster_time = :update_time WHERE repo_id = :repo_id""")
    bind_values = {'update_time': update_time, 'repo_id': repo_id}
    with session() as db_session:
        db_session.execute(statement, bind_values)
        # The UPDATE only becomes permanent once committed.
        db_session.commit()
metadata_server_api, face_recognition_task_manager, embedding_faces=False) + details = add_file_details(repo_id, obj_ids, metadata_server_api) return {'details': details}, 200