Skip to content

Commit

Permalink
timing face cluster (#406)
Browse files Browse the repository at this point in the history
* timing face cluster

* update

* update

* update

* update

* update

* update

* update

---------

Co-authored-by: zheng.shen <[email protected]>
  • Loading branch information
shenzheng-1 and zheng.shen authored Oct 29, 2024
1 parent 2802c3a commit f530357
Show file tree
Hide file tree
Showing 9 changed files with 254 additions and 90 deletions.
3 changes: 3 additions & 0 deletions app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from seafevents.seafevent_server.seafevent_server import SeafEventServer
from seafevents.app.config import ENABLE_METADATA_MANAGEMENT
from seafevents.seasearch.index_task.filename_index_updater import RepoFilenameIndexUpdater
from seafevents.repo_metadata.face_recognition_updater import RepoFaceClusterUpdater


class App(object):
Expand Down Expand Up @@ -41,6 +42,7 @@ def __init__(self, config, ccnet_config, seafile_config,
self._index_master = RepoMetadataIndexMaster(config)
self._index_worker = RepoMetadataIndexWorker(config)
self._slow_task_handler = SlowTaskHandler(config)
self._repo_face_cluster_updater = RepoFaceClusterUpdater(config)
self._repo_filename_index_updater = RepoFilenameIndexUpdater(config)

def serve_forever(self):
Expand All @@ -65,4 +67,5 @@ def serve_forever(self):
self._index_master.start()
self._index_worker.start()
self._slow_task_handler.start()
self._repo_face_cluster_updater.start()
self._repo_filename_index_updater.start()
21 changes: 21 additions & 0 deletions repo_data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,20 @@ def _get_repo_id_commit_id(self, start, count):
finally:
session.close()

def _get_mtime_by_repo_ids(self, repo_ids):
    """Fetch (repo_id, update_time) rows from RepoInfo for the given ids.

    Args:
        repo_ids: list of repo-id strings to look up.
    Returns:
        List of (repo_id, update_time) row tuples; empty list for empty input.
    """
    # Guard: an empty id list would otherwise produce invalid SQL (`IN ()`).
    if not repo_ids:
        return []
    session = self.db_session()
    try:
        # One named bind parameter per id. Bound parameters prevent SQL
        # injection and also remove the old single-element special case
        # (tuple(['x']) rendered as ('x',), which is not valid SQL).
        placeholders = ', '.join(':id_%d' % i for i in range(len(repo_ids)))
        cmd = 'SELECT repo_id, update_time FROM RepoInfo WHERE repo_id IN (%s)' % placeholders
        params = {'id_%d' % i: repo_id for i, repo_id in enumerate(repo_ids)}
        return session.execute(text(cmd), params).fetchall()
    finally:
        session.close()

def _get_all_trash_repo_list(self):
session = self.db_session()
try:
Expand Down Expand Up @@ -114,6 +128,13 @@ def get_all_repo_list(self):
logger.error(e)
return self._get_all_repo_list()

def get_mtime_by_repo_ids(self, repo_ids):
    # Retry once on failure (e.g. a dropped DB connection). The second
    # attempt is allowed to propagate its exception to the caller.
    try:
        result = self._get_mtime_by_repo_ids(repo_ids)
    except Exception as e:
        logger.error(e)
        result = self._get_mtime_by_repo_ids(repo_ids)
    return result

def get_all_trash_repo_list(self):
try:
return self._get_all_trash_repo_list()
Expand Down
133 changes: 91 additions & 42 deletions repo_metadata/face_recognition_manager.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import json
import logging
from datetime import datetime
import numpy as np

from seafevents.utils import get_opt_from_conf_or_env
from seafevents.db import init_db_session_class
from seafevents.repo_metadata.metadata_server_api import MetadataServerAPI
from seafevents.repo_metadata.image_embedding_api import ImageEmbeddingAPI
from seafevents.repo_metadata.utils import METADATA_TABLE, FACES_TABLE, query_metadata_rows, get_face_embeddings, face_compare
from seafevents.repo_metadata.utils import METADATA_TABLE, FACES_TABLE, query_metadata_rows, get_face_embeddings, get_faces_rows, get_cluster_by_center, update_face_cluster_time
from seafevents.repo_metadata.constants import METADATA_OP_LIMIT

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -33,12 +36,6 @@ def init_face_recognition(self, repo_id):
if not query_result:
return

metadata = self.metadata_server_api.get_metadata(repo_id)
tables = metadata.get('tables', [])
if not tables:
return
faces_table_id = [table['id'] for table in tables if table['name'] == FACES_TABLE.name][0]

obj_id_to_rows = {}
for item in query_result:
obj_id = item[METADATA_TABLE.columns.obj_id.name]
Expand All @@ -47,42 +44,94 @@ def init_face_recognition(self, repo_id):
obj_id_to_rows[obj_id].append(item)

obj_ids = list(obj_id_to_rows.keys())
known_faces = []
for obj_id in obj_ids:
records = obj_id_to_rows.get(obj_id, [])
known_faces = self.face_recognition(obj_id, records, repo_id, faces_table_id, known_faces)

def face_recognition(self, obj_id, records, repo_id, faces_table_id, used_faces):
embeddings = self.image_embedding_api.face_embeddings(repo_id, [obj_id]).get('data', [])
if not embeddings:
return used_faces
embedding = embeddings[0]
face_embeddings = embedding['embeddings']
recognized_faces = []
for face_embedding in face_embeddings:
face = face_compare(face_embedding, used_faces, 1.24)
if not face:
row = {
FACES_TABLE.columns.vector.name: json.dumps(face_embedding),
}
result = self.metadata_server_api.insert_rows(repo_id, faces_table_id, [row])
row_id = result.get('row_ids')[0]
used_faces.append({
FACES_TABLE.columns.id.name: row_id,
FACES_TABLE.columns.vector.name: json.dumps(face_embedding),
})
row_id_map = {
row_id: [item.get(METADATA_TABLE.columns.id.name) for item in records]
updated_rows = []
for i in range(0, len(obj_ids), 50):
obj_ids_batch = obj_ids[i: i + 50]
result = self.image_embedding_api.face_embeddings(repo_id, obj_ids_batch).get('data', [])
if not result:
continue

for item in result:
obj_id = item['obj_id']
face_embeddings = item['embeddings']
for row in obj_id_to_rows.get(obj_id, []):
row_id = row[METADATA_TABLE.columns.id.name]
updated_rows.append({
METADATA_TABLE.columns.id.name: row_id,
METADATA_TABLE.columns.face_vectors.name: json.dumps(face_embeddings),
})
if len(updated_rows) >= METADATA_OP_LIMIT:
self.metadata_server_api.update_rows(repo_id, METADATA_TABLE.id, updated_rows)
updated_rows = []

if updated_rows:
self.metadata_server_api.update_rows(repo_id, METADATA_TABLE.id, updated_rows)

self.face_cluster(repo_id)

def face_cluster(self, repo_id):
try:
from sklearn.cluster import HDBSCAN
except ImportError:
logger.warning('Package scikit-learn is not installed. ')
return

current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
update_face_cluster_time(self._db_session_class, repo_id, current_time)

sql = f'SELECT `{METADATA_TABLE.columns.id.name}`, `{METADATA_TABLE.columns.face_vectors.name}` FROM `{METADATA_TABLE.name}` WHERE `{METADATA_TABLE.columns.face_vectors.name}` IS NOT NULL'
query_result = query_metadata_rows(repo_id, self.metadata_server_api, sql)
if not query_result:
return

metadata = self.metadata_server_api.get_metadata(repo_id)
tables = metadata.get('tables', [])
if not tables:
return
faces_table_id = [table['id'] for table in tables if table['name'] == FACES_TABLE.name][0]

vectors = []
row_ids = []
for item in query_result:
row_id = item[METADATA_TABLE.columns.id.name]
face_vectors = json.loads(item[METADATA_TABLE.columns.face_vectors.name])
for face_vector in face_vectors:
vectors.append(face_vector)
row_ids.append(row_id)

old_cluster = get_faces_rows(repo_id, self.metadata_server_api)
clt = HDBSCAN(min_cluster_size=5)
clt.fit(vectors)

label_ids = np.unique(clt.labels_)
for label_id in label_ids:
idxs = np.where(clt.labels_ == label_id)[0]
related_row_ids = [row_ids[i] for i in idxs]
if label_id != -1:
cluster_center = np.mean([vectors[i] for i in idxs], axis=0)
face_row = {
FACES_TABLE.columns.vector.name: json.dumps(cluster_center.tolist()),
}
self.metadata_server_api.insert_link(repo_id, FACES_TABLE.link_id, faces_table_id, row_id_map)
cluster = get_cluster_by_center(cluster_center, old_cluster)
if cluster:
cluster_id = cluster[FACES_TABLE.columns.id.name]
old_cluster = [item for item in old_cluster if item[FACES_TABLE.columns.id.name] != cluster_id]
face_row[FACES_TABLE.columns.id.name] = cluster_id
self.metadata_server_api.update_rows(repo_id, faces_table_id, [face_row])
row_id_map = {
cluster_id: related_row_ids
}
self.metadata_server_api.update_link(repo_id, FACES_TABLE.link_id, faces_table_id, row_id_map)
continue
else:
recognized_faces.append(face)
face_row = dict()

if recognized_faces:
row_ids = [item[FACES_TABLE.columns.id.name] for item in recognized_faces]
row_id_map = dict()
for row in records:
row_id_map[row[METADATA_TABLE.columns.id.name]] = row_ids
self.metadata_server_api.insert_link(repo_id, FACES_TABLE.link_id, METADATA_TABLE.id, row_id_map)
result = self.metadata_server_api.insert_rows(repo_id, faces_table_id, [face_row])
row_id = result.get('row_ids')[0]
row_id_map = {
row_id: related_row_ids
}
self.metadata_server_api.insert_link(repo_id, FACES_TABLE.link_id, faces_table_id, row_id_map)

return used_faces
need_delete_row_ids = [item[FACES_TABLE.columns.id.name] for item in old_cluster]
self.metadata_server_api.delete_rows(repo_id, faces_table_id, need_delete_row_ids)
73 changes: 73 additions & 0 deletions repo_metadata/face_recognition_updater.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import logging
from threading import Thread

from apscheduler.triggers.cron import CronTrigger
from apscheduler.schedulers.gevent import GeventScheduler
from seafevents.db import init_db_session_class
from seafevents.repo_metadata.face_recognition_manager import FaceRecognitionManager
from seafevents.repo_metadata.utils import get_face_recognition_enabled_repo_list, update_face_cluster_time
from seafevents.repo_data import repo_data

logger = logging.getLogger(__name__)


class RepoFaceClusterUpdater(object):
    """Starts the background timer that periodically re-clusters faces."""

    def __init__(self, config):
        self._face_recognition_manager = FaceRecognitionManager(config)
        self._session = init_db_session_class(config)

    def start(self):
        logging.info('Start to update face cluster')
        timer = FaceClusterUpdaterTimer(self._face_recognition_manager, self._session)
        timer.start()


def update_face_cluster(face_recognition_manager, session):
    """Re-cluster faces for every repo with face recognition enabled.

    Pages through the enabled-repo list in batches and re-runs clustering
    only for repos whose library mtime is newer than the last cluster run.

    Args:
        face_recognition_manager: FaceRecognitionManager used to run clustering.
        session: DB session class for the face-recognition-enabled repo query.
    """
    start, count = 0, 1000
    while True:
        try:
            repos = get_face_recognition_enabled_repo_list(session, start, count)
        except Exception as e:
            # Lazy %-style logging args instead of eager string formatting.
            logger.error("Error: %s", e)
            return
        # Advance by the requested page size (was hard-coded to 1000, which
        # would silently skip or repeat repos if `count` were ever changed).
        start += count

        if not repos:
            break

        repo_ids = [repo[0] for repo in repos]
        repos_mtime = repo_data.get_mtime_by_repo_ids(repo_ids)
        repo_id_to_mtime = {repo[0]: repo[1] for repo in repos_mtime}

        for repo in repos:
            repo_id = repo[0]
            last_face_cluster_time = repo[1]
            mtime = repo_id_to_mtime.get(repo_id)
            if not mtime:
                continue
            # Skip repos not modified since the last clustering run.
            # NOTE(review): assumes mtime is epoch seconds and
            # last_face_cluster_time is a datetime — matches how callers
            # store them, but confirm if either source changes.
            if last_face_cluster_time and int(mtime) <= int(last_face_cluster_time.timestamp()):
                continue
            try:
                face_recognition_manager.face_cluster(repo_id)
            except Exception:
                # One failing repo must not abort the whole nightly batch.
                logger.exception('Failed to update face cluster for repo %s', repo_id)

    logger.info("Finish update face cluster")


class FaceClusterUpdaterTimer(Thread):
    """Thread that runs a gevent scheduler firing update_face_cluster daily."""

    def __init__(self, face_recognition_manager, session):
        super(FaceClusterUpdaterTimer, self).__init__()
        self.face_recognition_manager = face_recognition_manager
        self.session = session

    def run(self):
        scheduler = GeventScheduler()
        logging.info('Start to update face cluster...')
        job_args = (self.face_recognition_manager, self.session)
        try:
            scheduler.add_job(update_face_cluster, CronTrigger(day_of_week='*'), args=job_args)
        except Exception as e:
            logging.exception('periodical update face cluster error: %s', e)
        scheduler.start()
13 changes: 9 additions & 4 deletions repo_metadata/slow_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from seafevents.utils import get_opt_from_conf_or_env
from seafevents.repo_metadata.metadata_server_api import MetadataServerAPI
from seafevents.repo_metadata.image_embedding_api import ImageEmbeddingAPI
from seafevents.repo_metadata.utils import add_file_details
from seafevents.seafevent_server.face_recognition_task_manager import face_recognition_task_manager
from seafevents.repo_metadata.utils import add_file_details, get_repo_face_recognition_status
from seafevents.db import init_db_session_class

logger = logging.getLogger(__name__)

Expand All @@ -28,6 +28,7 @@ def __init__(self, config):
self.mq_port = 6379
self.mq_password = ''
self.worker_num = 3
self.session = init_db_session_class(config)
self._parse_config(config)

self.mq = get_mq(self.mq_server, self.mq_port, self.mq_password)
Expand Down Expand Up @@ -98,6 +99,10 @@ def extract_file_info(self, repo_id, data):

try:
obj_ids = data.get('obj_ids')
add_file_details(repo_id, obj_ids, self.metadata_server_api, face_recognition_task_manager)
face_recognition_status = get_repo_face_recognition_status(repo_id, self.session)
image_embedding_api = self.image_embedding_api if face_recognition_status else None
add_file_details(repo_id, obj_ids, self.metadata_server_api, image_embedding_api)
except Exception as e:
logger.exception('repo: %s, update metadata image info error: %s', repo_id, e)
logger.exception('repo: %s, update metadata file info error: %s', repo_id, e)

logger.info('%s finish extract file info repo %s' % (threading.currentThread().getName(), repo_id))
Loading

0 comments on commit f530357

Please sign in to comment.