From ab7e0eace9bf7bb271da9431faf31975bb5e8cff Mon Sep 17 00:00:00 2001 From: cir9no <44470218+cir9no@users.noreply.github.com> Date: Mon, 8 Jul 2024 00:06:59 +0800 Subject: [PATCH 1/4] feat: integrate-seafile-ai-search --- app/app.py | 5 +- app/config.py | 1 + mysql.sql | 9 + repo_data/db.py | 2 +- requirements.txt | 3 + seafevent_server/request_handler.py | 187 ++++++++++ seafevent_server/seafevent_server.py | 14 + semantic_search/config.py | 71 ++++ semantic_search/db.py | 81 ++++ semantic_search/index_store/extract.py | 207 +++++++++++ semantic_search/index_store/index_manager.py | 166 +++++++++ semantic_search/index_store/models.py | 25 ++ .../index_store/repo_file_index.py | 216 +++++++++++ .../index_store/repo_file_name_index.py | 351 ++++++++++++++++++ .../index_store/repo_status_index.py | 158 ++++++++ semantic_search/index_store/utils.py | 134 +++++++ .../index_task/filename_index_updater.py | 88 +++++ .../index_task/index_task_manager.py | 207 +++++++++++ .../script/portalocker/__init__.py | 3 + .../script/portalocker/portalocker.py | 143 +++++++ semantic_search/script/portalocker/utils.py | 143 +++++++ .../script/repo_file_index_local.py | 275 ++++++++++++++ .../script/repo_filename_index_local.py | 271 ++++++++++++++ semantic_search/semantic_search.py | 50 +++ semantic_search/semantic_search_settings.py | 22 ++ semantic_search/utils/__init__.py | 88 +++++ semantic_search/utils/commit_differ.py | 105 ++++++ semantic_search/utils/constants.py | 5 + semantic_search/utils/sea_embedding_api.py | 44 +++ semantic_search/utils/seasearch_api.py | 136 +++++++ semantic_search/utils/text_splitter.py | 282 ++++++++++++++ 31 files changed, 3490 insertions(+), 2 deletions(-) create mode 100644 semantic_search/config.py create mode 100644 semantic_search/db.py create mode 100644 semantic_search/index_store/extract.py create mode 100644 semantic_search/index_store/index_manager.py create mode 100644 semantic_search/index_store/models.py create mode 100644 semantic_search/index_store/repo_file_index.py create mode 100644 semantic_search/index_store/repo_file_name_index.py create mode 100644 semantic_search/index_store/repo_status_index.py create mode 100644 semantic_search/index_store/utils.py create mode 100644 semantic_search/index_task/filename_index_updater.py create mode 100644 semantic_search/index_task/index_task_manager.py create mode 100644 semantic_search/script/portalocker/__init__.py create mode 100644 semantic_search/script/portalocker/portalocker.py create mode 100644 semantic_search/script/portalocker/utils.py create mode 100644 semantic_search/script/repo_file_index_local.py create mode 100644 semantic_search/script/repo_filename_index_local.py create mode 100644 semantic_search/semantic_search.py create mode 100644 semantic_search/semantic_search_settings.py create mode 100644 semantic_search/utils/__init__.py create mode 100644 semantic_search/utils/commit_differ.py create mode 100644 semantic_search/utils/constants.py create mode 100644 semantic_search/utils/sea_embedding_api.py create mode 100644 semantic_search/utils/seasearch_api.py create mode 100644 semantic_search/utils/text_splitter.py diff --git a/app/app.py b/app/app.py index 7e45b3f6..ec1b1898 100644 --- a/app/app.py +++ b/app/app.py @@ -3,11 +3,12 @@ VirusScanner, Statistics, CountUserActivity, CountTrafficInfo, ContentScanner,\ WorkWinxinNoticeSender, FileUpdatesSender, RepoOldFileAutoDelScanner,\ DeletedFilesCountCleaner +from seafevents.semantic_search.semantic_search import SemanticSearch from 
seafevents.repo_metadata.index_master import RepoMetadataIndexMaster from seafevents.repo_metadata.index_worker import RepoMetadataIndexWorker from seafevents.seafevent_server.seafevent_server import SeafEventServer -from seafevents.app.config import ENABLE_METADATA_MANAGEMENT +from seafevents.app.config import ENABLE_METADATA_MANAGEMENT, ENABLE_SEAFILE_AI class App(object): @@ -38,6 +39,8 @@ def __init__(self, config, ccnet_config, seafile_config, if ENABLE_METADATA_MANAGEMENT: self._index_master = RepoMetadataIndexMaster(config) self._index_worker = RepoMetadataIndexWorker(config) + if ENABLE_SEAFILE_AI: + self._sem_app = SemanticSearch() def serve_forever(self): if self._fg_tasks_enabled: diff --git a/app/config.py b/app/config.py index b7d48a97..4fccf1fe 100644 --- a/app/config.py +++ b/app/config.py @@ -24,6 +24,7 @@ METADATA_SERVER_URL = getattr(seahub_settings, 'METADATA_SERVER_URL', '') ENABLE_METADATA_MANAGEMENT = getattr(seahub_settings, 'ENABLE_METADATA_MANAGEMENT', False) METADATA_FILE_TYPES = getattr(seahub_settings, 'METADATA_FILE_TYPES', {}) + ENABLE_SEAFILE_AI = getattr(seahub_settings, 'ENABLE_SEAFILE_AI', False) except ImportError: logger.critical("Can not import seahub settings.") raise RuntimeError("Can not import seahub settings.") diff --git a/mysql.sql b/mysql.sql index 4bbfdb00..bd17479b 100644 --- a/mysql.sql +++ b/mysql.sql @@ -213,3 +213,12 @@ CREATE TABLE IF NOT EXISTS `GroupIdLDAPUuidPair` ( UNIQUE KEY `group_id` (`group_id`), UNIQUE KEY `group_uuid` (`group_uuid`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8; + +CREATE TABLE IF NOT EXISTS `index_repo` ( + `id` int(11) NOT NULL AUTO_INCREMENT, + `repo_id` varchar(36) NOT NULL, + `created_at` datetime(6) NOT NULL, + `updated` datetime(6), + PRIMARY KEY (`id`) USING BTREE, + UNIQUE KEY `repo_id`(`repo_id`) +) ENGINE = InnoDB DEFAULT CHARACTER SET = utf8; diff --git a/repo_data/db.py b/repo_data/db.py index 2c766c4d..de19cecc 100644 --- a/repo_data/db.py +++ b/repo_data/db.py @@ -29,7 +29,7 @@ def create_engine_from_conf(config_file): if seaf_conf.has_option('database', 'host'): db_server = seaf_conf.get('database', 'host') if seaf_conf.has_option('database', 'port'): - db_port =seaf_conf.getint('database', 'port') + db_port = seaf_conf.getint('database', 'port') db_username = seaf_conf.get('database', 'user') db_passwd = seaf_conf.get('database', 'password') db_name = seaf_conf.get('database', 'db_name') diff --git a/requirements.txt b/requirements.txt index c3427434..f3588ba9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,6 @@ pymysql gevent==24.2.* Flask==2.2.* apscheduler +unstructured[docx,pptx] +transformers +ndjson diff --git a/seafevent_server/request_handler.py b/seafevent_server/request_handler.py index 344a6d47..15a81df7 100644 --- a/seafevent_server/request_handler.py +++ b/seafevent_server/request_handler.py @@ -1,11 +1,14 @@ import jwt import logging +import json from flask import Flask, request, make_response from seafevents.app.config import SEAHUB_SECRET_KEY from seafevents.seafevent_server.task_manager import task_manager from seafevents.seafevent_server.export_task_manager import event_export_task_manager +from seafevents.semantic_search.index_task.index_task_manager import index_task_manager +from seafevents.semantic_search.semantic_search import sem_app app = Flask(__name__) logger = logging.getLogger(__name__) @@ -96,3 +99,187 @@ def query_status(): return make_response((error, 500)) return make_response(({'is_finished': is_finished}, 200)) + +@app.route('/library-sdoc-indexes', 
methods=['POST']) +def library_sdoc_indexes(): + is_valid = check_auth_token(request) + if not is_valid: + return {'error_msg': 'Permission denied'}, 403 + + try: + data = json.loads(request.data) + except Exception as e: + logger.exception(e) + return {'error_msg': 'Bad request.'}, 400 + + repo_id = data.get('repo_id') + + if not repo_id: + return {'error_msg': 'repo_id invalid.'}, 400 + + commit_id = sem_app.repo_data.get_repo_head_commit(repo_id) + + if not commit_id: + return {'error_msg': 'repo invalid.'}, 400 + + try: + is_exist = sem_app.repo_file_index.check_index(repo_id) + except Exception as e: + logger.exception(e) + return {'error_msg': 'Internet server error.'}, 500 + + if is_exist: + return {'error_msg': 'Index exists.'}, 400 + + task = index_task_manager.get_pending_or_running_task(repo_id) + + if task: + return {'task_id': task.id}, 200 + + try: + sem_app.index_manager.create_index_repo_db(repo_id) + except Exception as e: + logger.exception(e) + return {'error_msg': 'Internet server error.'}, 500 + + task_id = index_task_manager.add_library_sdoc_index_task(repo_id, commit_id) + + return {'task_id': task_id}, 200 + + +@app.route('/search', methods=['POST']) +def search(): + is_valid = check_auth_token(request) + if not is_valid: + return {'error_msg': 'Permission denied'}, 403 + + try: + data = json.loads(request.data) + except Exception as e: + logger.exception(e) + return {'error_msg': 'Bad request.'}, 400 + + query = data.get('query').strip() + repos = data.get('repos') + suffixes = data.get('suffixes') + search_filename_only = data.get('search_filename_only') + + if not query: + return {'error_msg': 'query invalid.'}, 400 + + if not repos: + return {'error_msg': 'repos invalid.'}, 400 + + try: + count = int(data.get('count')) + except: + count = 20 + + if search_filename_only: + results = index_task_manager.keyword_search(query, repos, count, suffixes) + else: + results = index_task_manager.hybrid_search(query, repos[0], count) + + return {'results': results}, 200 + + +@app.route('/library-sdoc-index', methods=['PUT', 'DELETE']) +def library_sdoc_index(): + is_valid = check_auth_token(request) + if not is_valid: + return {'error_msg': 'Permission denied'}, 403 + + try: + data = json.loads(request.data) + except Exception as e: + logger.exception(e) + return {'error_msg': 'Bad request.'}, 400 + + repo_id = data.get('repo_id') + + if not repo_id: + return {'error_msg': 'repo_id invalid'}, 400 + + try: + index_repo = sem_app.index_manager.get_index_repo_by_repo_id(repo_id) + except Exception as e: + logger.exception(e) + return {'error_msg': 'Internet server error.'}, 500 + + if request.method == 'DELETE': + if not index_repo: + return {'success': True}, 200 + + task = index_task_manager.get_pending_or_running_task(repo_id) + + if task: + return {'error_msg': 'library sdoc index is running'}, 400 + + try: + sem_app.index_manager.delete_library_sdoc_index_by_repo_id(repo_id, sem_app.repo_file_index, sem_app.repo_status_index) + except Exception as e: + logger.exception(e) + return {'error_msg': 'Internet server error.'}, 500 + + return {'success': True}, 200 + + elif request.method == 'PUT': + commit_id = sem_app.repo_data.get_repo_head_commit(repo_id) + + if not commit_id: + return {'error_msg': 'repo invalid.'}, 400 + + task = index_task_manager.get_pending_or_running_task(repo_id) + + if task: + return {'task_id': task.id}, 200 + + + try: + task_id = index_task_manager.add_update_a_library_sdoc_index_task(repo_id, commit_id) + except Exception as e: + 
logger.exception(e) + return {'error_msg': 'Internet server error.'}, 500 + + return {'task_id': task_id}, 200 + + +@app.route('/task-status', methods=['GET']) +def query_task_status(): + is_valid = check_auth_token(request) + if not is_valid: + return {'error_msg': 'Permission denied'}, 403 + + task_id = request.args.get('task_id') + if not task_id: + return {'error_msg': 'task_id invalid'}, 400 + + task = index_task_manager.query_task(task_id) + if not task: + return {'error_msg': 'Task not found'}, 404 + + return {'is_finished': task.is_finished()} + + +@app.route('/library-index-state', methods=['GET']) +def query_library_index_state(): + is_valid = check_auth_token(request) + if not is_valid: + return {'error_msg': 'Permission denied'}, 403 + + repo_id = request.args.get('repo_id') + if not repo_id: + return {'error_msg': 'repo_id invalid'}, 400 + + try: + is_exist = sem_app.index_manager.get_index_repo_by_repo_id(repo_id) + except Exception as e: + logger.exception(e) + return {'error_msg': 'Internet server error.'}, 500 + + if not is_exist: + return {'state': 'uncreated', 'task_id': ''} + + task = index_task_manager.get_pending_or_running_task(repo_id) + + return task and {'state': 'running', 'task_id': task.id} or {'state': 'finished', 'task_id': ''} diff --git a/seafevent_server/seafevent_server.py b/seafevent_server/seafevent_server.py index 2531bb3a..d176a3f7 100644 --- a/seafevent_server/seafevent_server.py +++ b/seafevent_server/seafevent_server.py @@ -4,6 +4,10 @@ from seafevents.seafevent_server.request_handler import app as application from seafevents.seafevent_server.task_manager import task_manager from seafevents.seafevent_server.export_task_manager import event_export_task_manager +from seafevents.semantic_search.semantic_search import sem_app +from seafevents.semantic_search.index_task.index_task_manager import index_task_manager +from seafevents.semantic_search.index_task.filename_index_updater import repo_filename_index_updater +from seafevents.app.config import ENABLE_SEAFILE_AI class SeafEventServer(Thread): @@ -17,6 +21,16 @@ def __init__(self, app, config): task_manager.run() event_export_task_manager.run() + + if ENABLE_SEAFILE_AI: + # semantic search index task + sem_app.init() + index_task_manager.init(sem_app) + repo_filename_index_updater.init(sem_app) + + index_task_manager.start() + repo_filename_index_updater.start() + self._server = WSGIServer((self._host, int(self._port)), application) def _parse_config(self, config): diff --git a/semantic_search/config.py b/semantic_search/config.py new file mode 100644 index 00000000..ad3a33d2 --- /dev/null +++ b/semantic_search/config.py @@ -0,0 +1,71 @@ +import os +import logging + +from seafevents.app.config import get_config + +logger = logging.getLogger(__name__) + + +APP_NAME = 'semantic-search' + +# sections +## indexManager worker count +INDEX_MANAGER_WORKERS = 2 +INDEX_TASK_EXPIRE_TIME = 30 * 60 + +RETRIEVAL_NUM = 20 + +# embedding dimension +DIMENSION = 768 + +MODEL_VOCAB_PATH = '' +FILE_SENTENCE_LIMIT = 1000 + +THRESHOLD = 0.01 + +## seasearch +SEASEARCH_SERVER = 'http://127.0.0.1:4080' +SEASEARCH_TOKEN = '' +VECTOR_M = 256 +SHARD_NUM = 1 + +## sea-embedding +SEA_EMBEDDING_SERVER = '' +SEA_EMBEDDING_KEY = '' + + +# repo file index support file types +SUPPORT_INDEX_FILE_TYPES = [ + '.sdoc', + '.md', + '.markdown', + '.doc', + '.docx', + '.ppt', + '.pptx', + '.pdf', +] + + +CONF_DIR = '/opt/seafile/conf/' + +try: + import seahub.settings as seahub_settings + SEA_EMBEDDING_SERVER = getattr(seahub_settings, 
'SEA_EMBEDDING_SERVER', '') + SEA_EMBEDDING_KEY = getattr(seahub_settings, 'SEA_EMBEDDING_KEY', '') + SEASEARCH_SERVER = getattr(seahub_settings, 'SEASEARCH_SERVER', '') + SEASEARCH_TOKEN = getattr(seahub_settings, 'SEASEARCH_TOKEN', '') + MODEL_VOCAB_PATH = getattr(seahub_settings, 'MODEL_VOCAB_PATH', '') + MODEL_CACHE_DIR = getattr(seahub_settings, 'MODEL_CACHE_DIR', '') + INDEX_STORAGE_PATH = getattr(seahub_settings, 'INDEX_STORAGE_PATH', '') +except ImportError: + logger.critical("Can not import seahub settings.") + raise RuntimeError("Can not import seahub settings.") + + +try: + + if os.path.exists('/data/dev/seafevents/semantic_search/semantic_search_settings.py'): + from seafevents.semantic_search.semantic_search_settings import * +except: + pass diff --git a/semantic_search/db.py b/semantic_search/db.py new file mode 100644 index 00000000..01b39f6f --- /dev/null +++ b/semantic_search/db.py @@ -0,0 +1,81 @@ +import logging +import configparser + +from urllib.parse import quote_plus + +from sqlalchemy import create_engine +from sqlalchemy.event import contains as has_event_listener, listen as add_event_listener +from sqlalchemy.exc import DisconnectionError +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import Pool + + +# base class of model classes in events.models and stats.models +class Base(DeclarativeBase): + pass + + +logger = logging.getLogger('seafevents') + + +def create_engine_from_conf(config_file): + seaf_conf = configparser.ConfigParser() + seaf_conf.read(config_file) + backend = seaf_conf.get('DATABASE', 'type') + + if backend == 'mysql': + db_server = 'localhost' + db_port = 3306 + + if seaf_conf.has_option('DATABASE', 'host'): + db_server = seaf_conf.get('DATABASE', 'host') + if seaf_conf.has_option('DATABASE', 'port'): + db_port = seaf_conf.getint('DATABASE', 'port') + db_username = seaf_conf.get('DATABASE', 'username') + db_passwd = seaf_conf.get('DATABASE', 'password') + db_name = seaf_conf.get('DATABASE', 'name') + db_url = "mysql+pymysql://%s:%s@%s:%s/%s?charset=utf8" % \ + (db_username, quote_plus(db_passwd), + db_server, db_port, db_name) + else: + logger.critical("Unknown Database backend: %s" % backend) + raise RuntimeError("Unknown Database backend: %s" % backend) + + kwargs = dict(pool_recycle=300, echo=False, echo_pool=False) + + engine = create_engine(db_url, **kwargs) + if not has_event_listener(Pool, 'checkout', ping_connection): + # We use has_event_listener to double check in case we call create_engine + # multipe times in the same process. + add_event_listener(Pool, 'checkout', ping_connection) + + return engine + +def init_db_session_class(config_file): + """Configure Session class for mysql according to the config file.""" + try: + engine = create_engine_from_conf(config_file) + except (configparser.NoOptionError, configparser.NoSectionError) as e: + logger.error(e) + raise RuntimeError("invalid config file %s", config_file) + + Session = sessionmaker(bind=engine) + return Session + +# This is used to fix the problem of "MySQL has gone away" that happens when +# mysql server is restarted or the pooled connections are closed by the mysql +# server beacause being idle for too long. 
+# +# See http://stackoverflow.com/a/17791117/1467959 +def ping_connection(dbapi_connection, connection_record, connection_proxy): # pylint: disable=unused-argument + cursor = dbapi_connection.cursor() + try: + cursor.execute("SELECT 1") + cursor.close() + except: + logger.info('fail to ping database server, disposing all cached connections') + connection_proxy._pool.dispose() # pylint: disable=protected-access + + # Raise DisconnectionError so the pool would create a new connection + raise DisconnectionError() diff --git a/semantic_search/index_store/extract.py b/semantic_search/index_store/extract.py new file mode 100644 index 00000000..f023ea71 --- /dev/null +++ b/semantic_search/index_store/extract.py @@ -0,0 +1,207 @@ +# coding: UTF-8 +import os +import json +import logging +from io import BytesIO + +from unstructured.partition.pptx import partition_pptx +from unstructured.partition.doc import partition_doc +from unstructured.partition.docx import partition_docx +from unstructured.partition.ppt import partition_ppt +from unstructured.staging.base import convert_to_text + +from seafevents.semantic_search.utils.constants import ZERO_OBJ_ID +from seafevents.semantic_search.utils.text_splitter import \ + MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter, tokenizer_length + +from seafobj import fs_mgr + +logger = logging.getLogger(__name__) + +OFFICE_FILE_SIZE_LIMIT = 1024 * 1024 * 10 +TEXT_FILE_SIZE_LIMIT = 1024 * 1024 + +text_suffixes = [ + 'sdoc', + 'md', + 'markdown' +] + +office_suffixes = [ + 'doc', + 'docx', + 'ppt', + 'pptx', + 'pdf', +] + + +def parse_sdoc_to_spilt_sentences(content): + content = content.decode() + content = json.loads(content) + sentences = [] + for children in content.get('children', []): + if children.get('type') == 'code_block': + continue + + combined_text_list = parse_children_text(children, []) + + if not combined_text_list: + continue + + sentence = '。'.join(combined_text_list) + sentences.append(sentence) + return sentences + + +def parse_children_text(children, text_list=[]): + text = children.get('text', '') + if text and text.strip(): + text_list.append(text.strip()) + + children_list = children.get('children') + if children_list: + for children in children_list: + parse_children_text(children, text_list) + + return text_list + + +def parse_md_to_spilt_sentences(content): + content = content.decode() + chunk_size = 100 + headers_to_split_on = ["#", "##", "###", "####", "#####", "######"] + text_splitter = MarkdownHeaderTextSplitter(headers_to_split_on) + md_header_splits = text_splitter.split_text(content) + + recursive_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0, length_function=tokenizer_length) + + sentences = [] + for data in md_header_splits: + text = data.get('content', '') + if len(text) > chunk_size: + split_texts = recursive_splitter.split_text(text) + sentences.extend(split_texts) + else: + sentences.append(text) + + return sentences + + +def recursive_split_text_to_sentences(text, chunk_size=100): + recursive_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0, length_function=tokenizer_length) + if len(text) > chunk_size: + sentences = recursive_splitter.split_text(text) + else: + sentences = [text] + + return sentences + + +def parse_doc_to_split_sentences(content): + doc_elements = partition_doc(file=BytesIO(content)) + doc_text = convert_to_text(doc_elements) + return recursive_split_text_to_sentences(doc_text) + + +def parse_docx_to_split_sentences(content): + 
doc_elements = partition_docx(file=BytesIO(content)) + doc_text = convert_to_text(doc_elements) + return recursive_split_text_to_sentences(doc_text) + + +def parse_ppt_to_split_sentences(content): + doc_elements = partition_ppt(file=BytesIO(content)) + doc_text = convert_to_text(doc_elements) + return recursive_split_text_to_sentences(doc_text) + + +def parse_pptx_to_split_sentences(content): + doc_elements = partition_pptx(file=BytesIO(content)) + doc_text = convert_to_text(doc_elements) + return recursive_split_text_to_sentences(doc_text) + + +def parse_pdf_to_split_sentences(content): + import PyPDF2 + + sentences = [] + pdf_reader = PyPDF2.PdfReader(BytesIO(content)) + for page in pdf_reader.pages: + text = page.extract_text() + split_sentences = recursive_split_text_to_sentences(text) + sentences.extend(split_sentences) + + return sentences + + +EXTRACT_TEXT_FUNCS = { + 'sdoc': parse_sdoc_to_spilt_sentences, + 'md': parse_md_to_spilt_sentences, + 'doc': parse_doc_to_split_sentences, + 'docx': parse_docx_to_split_sentences, + 'ppt': parse_ppt_to_split_sentences, + 'pptx': parse_pptx_to_split_sentences, + 'pdf': parse_pdf_to_split_sentences, +} + + +def get_file_suffix(path): + try: + name = os.path.basename(path) + suffix = os.path.splitext(name)[1][1:] + if suffix: + return suffix.lower() + return None + except: + return None + + +class Extractor(object): + def __init__(self, func, file_size_limit=-1): + self.func = func + self.file_size_limit = file_size_limit + + def extract(self, repo_id, version, obj_id, path): + if obj_id == ZERO_OBJ_ID: + return None + + f = fs_mgr.load_seafile(repo_id, version, obj_id) + if self.file_size_limit < f.size: + logger.warning("file %s size exceeds limit", path) + return None + content = f.get_content(limit=self.file_size_limit) + if not content: + # An empty file + return None + + try: + logger.info('extracting %s %s...', repo_id, path) + sentences = self.func(content) + logger.info('successfully extracted %s', path) + except Exception as e: + logger.warning('failed to extract %s: %s', path, e) + return None + + return sentences + + +class ExtractorFactory(object): + @classmethod + def get_extractor(cls, filename): + + suffix = get_file_suffix(filename) + func = EXTRACT_TEXT_FUNCS.get(suffix, None) + if not func: + return None + return Extractor(func, cls.get_file_size_limit(filename)) + + @classmethod + def get_file_size_limit(cls, filename): + suffix = get_file_suffix(filename) + + if suffix in text_suffixes: + return TEXT_FILE_SIZE_LIMIT + elif suffix in office_suffixes: + return OFFICE_FILE_SIZE_LIMIT + return -1 diff --git a/semantic_search/index_store/index_manager.py b/semantic_search/index_store/index_manager.py new file mode 100644 index 00000000..f9350459 --- /dev/null +++ b/semantic_search/index_store/index_manager.py @@ -0,0 +1,166 @@ +import logging +import time +import os +from datetime import datetime + +from sqlalchemy.sql import text + +from seafevents.semantic_search import config +from seafevents.semantic_search.utils.constants import ZERO_OBJ_ID, REPO_FILENAME_INDEX_PREFIX +from seafevents.semantic_search.db import init_db_session_class +from seafevents.semantic_search.index_store.models import IndexRepo +from seafevents.semantic_search.index_store.utils import rank_fusion, filter_hybrid_searched_files + +logger = logging.getLogger(__name__) + + +class IndexManager(): + def __init__(self): + self.evtconf = os.environ['EVENTS_CONFIG_FILE'] + self._db_session_class = init_db_session_class(self.evtconf) + + def 
create_index_repo_db(self, repo_id): + with self._db_session_class() as db_session: + index_repo = IndexRepo(repo_id, datetime.now(), datetime.now()) + db_session.add(index_repo) + db_session.commit() + + def delete_index_repo_db(self, repo_id): + with self._db_session_class() as db_session: + db_session.query(IndexRepo).filter(IndexRepo.repo_id == repo_id).delete() + db_session.commit() + + def update_index_repo_db(self, repo_id): + with self._db_session_class() as db_session: + index_repo = db_session.query(IndexRepo). \ + filter(IndexRepo.repo_id == repo_id) + index_repo.update({"updated": datetime.now()}) + db_session.commit() + + def get_index_repo_by_repo_id(self, repo_id): + with self._db_session_class() as db_session: + return db_session.query(IndexRepo).filter(IndexRepo.repo_id == repo_id).first() + + def list_index_repos(self): + with self._db_session_class() as db_session: + sql = """ + SELECT `repo_id` FROM index_repo + """ + + index_repos = db_session.execute(text(sql)) + return index_repos + + def get_index_repos_by_size(self, start, size): + with self._db_session_class() as db_session: + sql = """ + SELECT `repo_id` + FROM index_repo LIMIT :start, :size + """ + + index_repos = db_session.execute(text(sql), { + 'start': start, + 'size': size, + }) + return index_repos + + def create_library_sdoc_index(self, repo_id, embedding_api, repo_file_index, repo_status_index, commit_id): + repo_status_index.begin_update_repo(repo_id, ZERO_OBJ_ID, commit_id) + repo_file_index.create_index(repo_id) + repo_file_index.add(repo_id, ZERO_OBJ_ID, commit_id, embedding_api) + repo_status_index.finish_update_repo(repo_id, commit_id) + + logger.info('library: %s, save library file to SeaSearch success', repo_id) + + def search_children_in_library(self, query, repo, embedding_api, repo_file_index, count=20): + return repo_file_index.search_files(repo, config.RETRIEVAL_NUM, embedding_api, query)[:count] + + def update_library_sdoc_index(self, repo_id, embedding_api, repo_file_index, repo_status_index, new_commit_id): + try: + repo_status = repo_status_index.get_repo_status_by_id(repo_id) + + from_commit = repo_status.from_commit + to_commit = repo_status.to_commit + + if new_commit_id == from_commit: + return + + commit_id = from_commit + if repo_status.need_recovery(): + logger.warning('%s: repo file index inrecovery', repo_id) + + is_exist = repo_file_index.check_index(repo_id) + if not is_exist: + repo_file_index.create_index(repo_id) + + repo_file_index.update(repo_id, from_commit, to_commit, embedding_api) + + # time sleep for SeaSearch save data + time.sleep(1) + + commit_id = to_commit + repo_status_index.begin_update_repo(repo_id, commit_id, new_commit_id) + repo_file_index.update(repo_id, commit_id, new_commit_id, embedding_api) + repo_status_index.finish_update_repo(repo_id, new_commit_id) + + self.update_index_repo_db(repo_id) + + logger.info('repo: %s, update repo file index success', repo_id) + + except Exception as e: + logger.exception('repo_id: %s, update repo file index error: %s.', repo_id, e) + + def delete_library_sdoc_index_by_repo_id(self, repo_id, repo_file_index, repo_status_index): + # first delete repo_file_index + repo_file_index.delete_index_by_index_name(repo_id) + repo_status_index.delete_documents_by_repo(repo_id) + self.delete_index_repo_db(repo_id) + + def keyword_search(self, query, repos, repo_filename_index, count, suffixes=None): + return repo_filename_index.search_files(repos, query, 0, count, suffixes) + + def update_library_filename_index(self, repo_id, 
commit_id, repo_filename_index, repo_status_filename_index):
+        try:
+            new_commit_id = commit_id
+            index_name = REPO_FILENAME_INDEX_PREFIX + repo_id
+
+            repo_filename_index.create_index_if_missing(index_name)
+
+            repo_status = repo_status_filename_index.get_repo_status_by_id(repo_id)
+            from_commit = repo_status.from_commit
+            to_commit = repo_status.to_commit
+
+            if new_commit_id == from_commit:
+                return
+
+            if not from_commit:
+                commit_id = ZERO_OBJ_ID
+            else:
+                commit_id = from_commit
+
+            if repo_status.need_recovery():
+                logger.warning('%s: repo filename index in recovery', repo_id)
+                repo_filename_index.update(index_name, repo_id, commit_id, to_commit)
+                commit_id = to_commit
+                time.sleep(1)
+
+            repo_status_filename_index.begin_update_repo(repo_id, commit_id, new_commit_id)
+            repo_filename_index.update(index_name, repo_id, commit_id, new_commit_id)
+            repo_status_filename_index.finish_update_repo(repo_id, new_commit_id)
+
+            logger.info('repo: %s, update repo filename index success', repo_id)
+
+        except Exception as e:
+            logger.exception('repo_id: %s, update repo filename index error: %s.', repo_id, e)
+
+    def delete_repo_filename_index(self, repo_id, repo_filename_index, repo_status_filename_index):
+        # first delete the repo filename index, then its status record
+        repo_filename_index_name = REPO_FILENAME_INDEX_PREFIX + repo_id
+        repo_filename_index.delete_index_by_index_name(repo_filename_index_name)
+        repo_status_filename_index.delete_documents_by_repo(repo_id)
+
+    def hybrid_search(self, query, repo, repo_filename_index, embedding_api, repo_file_index, count):
+        keyword_files = self.keyword_search(query, [repo], repo_filename_index, count)
+        similar_files = self.search_children_in_library(query, repo, embedding_api, repo_file_index, count)
+        fused_files = rank_fusion([keyword_files, similar_files])
+
+        return filter_hybrid_searched_files(fused_files)[:count]
diff --git a/semantic_search/index_store/models.py b/semantic_search/index_store/models.py
new file mode 100644
index 00000000..4142e59f
--- /dev/null
+++ b/semantic_search/index_store/models.py
@@ -0,0 +1,25 @@
+from sqlalchemy import Column, Integer, String, DateTime
+from seafevents.semantic_search.db import Base
+
+
+class IndexRepo(Base):
+    __tablename__ = 'index_repo'
+
+    id = Column(Integer, primary_key=True, autoincrement=True)
+    repo_id = Column(String(length=36), nullable=False)
+    created_at = Column(DateTime, nullable=False)
+    updated = Column(DateTime, nullable=False)
+
+    def __init__(self, repo_id, created_at, updated=None):
+        self.repo_id = repo_id
+        self.created_at = created_at
+        self.updated = updated
+
+    def to_dict(self):
+        res = {
+            'id': self.id,
+            'repo_id': self.repo_id,
+            'created_at': self.created_at.isoformat(),
+            'updated': self.updated.isoformat() if self.updated else None,
+        }
+        return res
diff --git a/semantic_search/index_store/repo_file_index.py b/semantic_search/index_store/repo_file_index.py
new file mode 100644
index 00000000..30d15541
--- /dev/null
+++ b/semantic_search/index_store/repo_file_index.py
@@ -0,0 +1,216 @@
+import os
+import logging
+
+from seafevents.semantic_search import config
+from seafevents.semantic_search.index_store.utils import parse_file_to_sentences, bulk_add_sentences_to_index
+from seafevents.semantic_search.utils import get_library_diff_files, is_sys_dir_or_file
+
+logger = logging.getLogger(__name__)
+
+
+SEASEARCH_BULK_OPETATE_LIMIT = 1000
+SEASEARCH_QUERY_PATH_DOC_STEP = 20
+
+
+class RepoFileIndex(object):
+    """
+    The index name is the repo id.
+    """
+    mapping = {
+        "properties": {
+            "vec": {
+                "type": 
"vector", + "dims": config.DIMENSION, + "vec_index_type": "ivf_pq", + "nbits": 4, + "m": config.VECTOR_M + }, + "path": { + "type": "keyword" + }, + 'content': { + 'type': 'text' + } + } + } + + shard_num = config.SHARD_NUM + + def __init__(self, seasearch_api): + self.seasearch_api = seasearch_api + + def create_index(self, index_name): + data = { + 'shard_num': self.shard_num, + 'mappings': self.mapping, + } + self.seasearch_api.create_index(index_name, data) + + def check_index(self, index_name): + return self.seasearch_api.check_index_mapping(index_name).get('is_exist') + + def search_files(self, repo, k, embedding_api, query): + repo_id = repo[0] + origin_repo_id = repo[1] + origin_path = repo[2] + + if origin_repo_id: + repo_id = origin_repo_id + vector = embedding_api.embeddings(query)['data'][0]['embedding'] + data = { + "query_field": "vec", + "k": k, + "return_fields": ["path", "content"], + "_source": False, + "vector": vector + } + + result = self.seasearch_api.vector_search(repo_id, data) + total = result.get('hits', {}).get('total', {}).get('value', 0) + if result.get('error'): + logger.info('search in repo_file_index error: %s .', result.get('error')) + return [] + + hits = result['hits']['hits'] + if not hits: + return [] + searched_result = {} + for hit in hits: + score = hit['_score'] + _id = hit['_id'] + path = hit['fields']['path'][0] + content = hit['fields']['content'][0] + + if origin_path and not path.startswith(origin_path): + continue + + if score < config.THRESHOLD: + continue + + if searched_result.get(path): + pre_score = searched_result[path]['max_score'] + searched_result[path]['score'] = score + pre_score + continue + filename = os.path.basename(path) + searched_result[path] = {'repo_id': repo_id, + 'fullpath': path, + 'name': filename, + 'is_dir': False, + 'score': score, + 'max_score': score, + 'content': content, + '_id': _id + } + + return list(searched_result.values()) + + def delete_index_by_index_name(self, index_name): + self.seasearch_api.delete_index_by_name(index_name) + + def add(self, index_name, old_commit_id, new_commit_id, embedding_api): + self.update(index_name, old_commit_id, new_commit_id, embedding_api) + + def update(self, index_name, old_commit_id, new_commit_id, embedding_api): + """ + old_commit_id is ZERO_OBJ_ID that means create repo file index + """ + added_files, deleted_files, modified_files, _, deleted_dirs = get_library_diff_files(index_name, old_commit_id, new_commit_id) + + need_deleted_files = deleted_files + modified_files + self.delete_files(index_name, need_deleted_files) + + self.delete_files_by_deleted_dirs(index_name, deleted_dirs) + + need_added_files = added_files + modified_files + # deleting files is to prevent duplicate insertions + self.delete_files(index_name, added_files) + self.add_files(index_name, need_added_files, embedding_api, new_commit_id) + + def query_data_by_paths(self, index_name, path_list, start, size): + dsl = { + "query": { + "terms": { + "path": path_list + } + }, + "from": start, + "size": size, + "_source": False, + "sort": ["-@timestamp"], # sort is for getting data ordered + } + hits, total = self.normal_search(index_name, dsl) + return hits, total + + def query_data_by_dir(self, index_name, directory, start, size): + dsl = { + "query": { + "bool": { + "must": [ + {"prefix": {"path": directory}} + ] + } + }, + "from": start, + "size": size, + "_source": False, + "sort": ["-@timestamp"], # sort is for getting data ordered + } + + hits, total = self.normal_search(index_name, dsl) + return 
hits, total + + def normal_search(self, index_name, dsl): + doc_item = self.seasearch_api.normal_search(index_name, dsl) + total = doc_item['hits']['total']['value'] + + return doc_item['hits']['hits'], total + + def delete_files(self, index_name, files): + step = SEASEARCH_QUERY_PATH_DOC_STEP + for pos in range(0, len(files), step): + paths = [file[0] for file in files[pos: pos + step] if not is_sys_dir_or_file(file[0])] + per_size = SEASEARCH_BULK_OPETATE_LIMIT + start = 0 + delete_params = [] + while True: + hits, total = self.query_data_by_paths(index_name, paths, start, per_size) + for hit in hits: + _id = hit['_id'] + delete_params.append({'delete': {'_id': _id, '_index': index_name}}) + + if delete_params: + self.seasearch_api.bulk(index_name, delete_params) + if len(hits) < per_size: + break + + def delete_files_by_deleted_dirs(self, index_name, dirs): + for directory in dirs: + if is_sys_dir_or_file(directory): + continue + per_size = SEASEARCH_BULK_OPETATE_LIMIT + start = 0 + delete_params = [] + while True: + hits, total = self.query_data_by_dir(index_name, directory, start, per_size) + for hit in hits: + _id = hit['_id'] + delete_params.append({'delete': {'_id': _id, '_index': index_name}}) + + if delete_params: + self.seasearch_api.bulk(index_name, delete_params) + if len(hits) < per_size: + break + + def add_files(self, index_name, files, embedding_api, commit_id): + for file_info in files: + path = file_info[0] + if is_sys_dir_or_file(path): + continue + self.add_file(index_name, file_info, commit_id, embedding_api, path) + logger.info('add file: %s , to index: %s .', path, index_name) + + def add_file(self, index_name, file_info, commit_id, embedding_api, path): + sentences = parse_file_to_sentences(index_name, file_info, commit_id) + sentences = sentences[0: config.FILE_SENTENCE_LIMIT] + limit = int(SEASEARCH_BULK_OPETATE_LIMIT / 2) + bulk_add_sentences_to_index(self.seasearch_api, embedding_api, index_name, path, sentences, limit) diff --git a/semantic_search/index_store/repo_file_name_index.py b/semantic_search/index_store/repo_file_name_index.py new file mode 100644 index 00000000..cea0e4a7 --- /dev/null +++ b/semantic_search/index_store/repo_file_name_index.py @@ -0,0 +1,351 @@ +import json +import os +import logging + +from seafevents.semantic_search.utils import get_library_diff_files, md5, is_sys_dir_or_file +from seafevents.semantic_search import config +from seafevents.semantic_search.utils.constants import REPO_FILENAME_INDEX_PREFIX + +logger = logging.getLogger(__name__) + +SEASEARCH_BULK_OPETATE_LIMIT = 2000 + + +class RepoFileNameIndex(object): + mapping = { + 'properties': { + 'repo_id': { + 'type': 'keyword', + }, + 'path': { + 'type': 'keyword' + }, + 'filename': { + 'type': 'text', + 'fields': { + 'ngram': { + 'type': 'text', + 'index': True, + 'analyzer': 'seafile_file_name_ngram_analyzer', + }, + }, + }, + 'suffix': { + 'type': 'keyword' + }, + 'is_dir': { + 'type': 'boolean', + } + } + } + + index_settings = { + 'analysis': { + 'analyzer': { + 'seafile_file_name_ngram_analyzer': { + 'type': 'custom', + 'tokenizer': 'seafile_file_name_ngram_tokenizer', + 'filter': [ + 'lowercase', + ], + } + }, + 'tokenizer': { + 'seafile_file_name_ngram_tokenizer': { + 'type': 'ngram', + 'min_gram': 4, + 'max_gram': 4 + } + } + } + } + + shard_num = config.SHARD_NUM + + def __init__(self, seasearch_api, repo_data): + self.seasearch_api = seasearch_api + self.repo_data = repo_data + + def create_index_if_missing(self, index_name): + if not 
self.seasearch_api.check_index_mapping(index_name).get('is_exist'): + data = { + 'shard_num': self.shard_num, + 'mappings': self.mapping, + 'settings': self.index_settings + } + self.seasearch_api.create_index(index_name, data) + + def check_index(self, index_name): + return self.seasearch_api.check_index_mapping(index_name).get('is_exist') + + def _make_query_searches(self, keyword): + match_query_kwargs = {'minimum_should_match': '-25%'} + + def _make_match_query(field, key_word, **kw): + q = {'query': key_word} + q.update(kw) + return {'match': {field: q}} + + searches = [] + searches.append(_make_match_query('filename', keyword, **match_query_kwargs)) + searches.append({ + 'match': { + 'filename.ngram': { + 'query': keyword, + 'minimum_should_match': '80%', + } + } + }) + return searches + + def _add_path_filter(self, query_map, search_path): + if search_path is None: + return query_map + + if query_map['bool'].get('filter'): + query_map['bool']['filter'].append({'prefix': {'path': search_path}}) + else: + query_map['bool']['filter'] = [{'prefix': {'path': search_path}}] + return query_map + + def _add_suffix_filter(self, query_map, suffixes): + if suffixes: + if not query_map['bool'].get('filter'): + query_map['bool']['filter'] = [] + if isinstance(suffixes, list): + suffixes = [x.lower() for x in suffixes] + query_map['bool']['filter'].append({'terms': {'suffix': suffixes}}) + else: + query_map['bool']['filter'].append({'term': {'suffix': suffixes.lower()}}) + return query_map + + def search_files(self, repos, keyword, start=0, size=10, suffixes=None): + bulk_search_params = [] + for repo in repos: + repo_id = repo[0] + origin_repo_id = repo[1] + origin_path = repo[2] + query_map = {'bool': {'should': [], 'minimum_should_match': 1}} + searches = self._make_query_searches(keyword) + query_map['bool']['should'] = searches + + if origin_repo_id: + repo_id = origin_repo_id + query_map = self._add_path_filter(query_map, origin_path) + query_map = self._add_suffix_filter(query_map, suffixes) + + data = { + 'query': query_map, + 'from': start, + 'size': size, + '_source': ['path', 'repo_id', 'filename', 'is_dir'], + 'sort': ['_score'] + } + index_name = REPO_FILENAME_INDEX_PREFIX + repo_id + index_info = {"index": index_name} + bulk_search_params.append(index_info) + bulk_search_params.append(data) + + logger.debug('search in repo_filename_index params: %s', json.dumps(bulk_search_params)) + + results = self.seasearch_api.m_search(bulk_search_params) + files = [] + + for result in results.get('responses'): + hits = result.get('hits', {}).get('hits', []) + + if not hits: + continue + + for hit in hits: + source = hit.get('_source') + score = hit.get('_score') + _id = hit.get('_id') + r = { + 'repo_id': source['repo_id'], + 'fullpath': source['path'], + 'name': source['filename'], + 'is_dir': source['is_dir'], + 'score': score, + '_id': _id, + } + files.append(r) + files = sorted(files, key=lambda row: row['score'], reverse=True)[:size] + + return files + + @staticmethod + def get_file_suffix(path): + try: + name = os.path.basename(path) + suffix = os.path.splitext(name)[1][1:] + if suffix: + return suffix.lower() + return None + except: + return None + + def add_files(self, index_name, repo_id, files): + bulk_add_params = [] + for file_info in files: + path = file_info[0] + obj_id = file_info[1] + mtime = file_info[2] + size = file_info[3] + + if is_sys_dir_or_file(path): + continue + + suffix = self.get_file_suffix(path) + filename = os.path.basename(path) + if suffix: + filename = 
filename[:-len(suffix)-1] + + index_info = {'index': {'_index': index_name, '_id': md5(path)}} + doc_info = { + 'repo_id': repo_id, + 'path': path, + 'suffix': suffix, + 'filename': filename, + 'is_dir': False, + } + + bulk_add_params.append(index_info) + bulk_add_params.append(doc_info) + + # bulk add every 2000 params + if len(bulk_add_params) >= SEASEARCH_BULK_OPETATE_LIMIT: + self.seasearch_api.bulk(index_name, bulk_add_params) + bulk_add_params = [] + if bulk_add_params: + self.seasearch_api.bulk(index_name, bulk_add_params) + + def add_dirs(self, index_name, repo_id, dirs): + bulk_add_params = [] + for dir in dirs: + path = dir[0] + obj_id = dir[1] + mtime = dir[2] + size = dir[3] + + if is_sys_dir_or_file(path): + continue + + if path == '/': + repo = self.repo_data.get_repo_name_mtime_size(repo_id) + if not repo: + return + + filename = repo[0]['name'] + else: + filename = os.path.basename(path) + + path = path + '/' if path != '/' else path + index_info = {'index': {'_index': index_name, '_id': md5(path)}} + doc_info = { + 'repo_id': repo_id, + 'path': path, + 'suffix': None, + 'filename': filename, + 'is_dir': True, + } + bulk_add_params.append(index_info) + bulk_add_params.append(doc_info) + + # bulk add every 2000 params + if len(bulk_add_params) >= SEASEARCH_BULK_OPETATE_LIMIT: + self.seasearch_api.bulk(index_name, bulk_add_params) + bulk_add_params = [] + if bulk_add_params: + self.seasearch_api.bulk(index_name, bulk_add_params) + + def delete_files(self, index_name, files): + delete_params = [] + for file in files: + path = file[0] + if is_sys_dir_or_file(path): + continue + delete_params.append({'delete': {'_id': md5(path), '_index': index_name}}) + # bulk add every 2000 params + if len(delete_params) >= SEASEARCH_BULK_OPETATE_LIMIT: + self.seasearch_api.bulk(index_name, delete_params) + delete_params = [] + if delete_params: + self.seasearch_api.bulk(index_name, delete_params) + + def delete_dirs(self, index_name, dirs): + delete_params = [] + for dir in dirs: + path = dir + + if is_sys_dir_or_file(path): + continue + path = path + '/' if path != '/' else path + delete_params.append({'delete': {'_id': md5(path), '_index': index_name}}) + # bulk add every 2000 params + if len(delete_params) >= SEASEARCH_BULK_OPETATE_LIMIT: + self.seasearch_api.bulk(index_name, delete_params) + delete_params = [] + if delete_params: + self.seasearch_api.bulk(index_name, delete_params) + + def query_data_by_dir(self, index_name, directory, start, size): + dsl = { + "query": { + "bool": { + "must": [ + {"prefix": {"path": directory}} + ] + } + }, + "from": start, + "size": size, + "_source": False, + "sort": ["-@timestamp"], # sort is for getting data ordered + } + + hits, total = self.normal_search(index_name, dsl) + return hits, total + + def normal_search(self, index_name, dsl): + doc_item = self.seasearch_api.normal_search(index_name, dsl) + total = doc_item['hits']['total']['value'] + + return doc_item['hits']['hits'], total + + def delete_files_by_deleted_dirs(self, index_name, dirs): + for directory in dirs: + if is_sys_dir_or_file(directory): + continue + per_size = SEASEARCH_BULK_OPETATE_LIMIT + start = 0 + delete_params = [] + while True: + hits, total = self.query_data_by_dir(index_name, directory, start, per_size) + for hit in hits: + _id = hit['_id'] + delete_params.append({'delete': {'_id': _id, '_index': index_name}}) + + if delete_params: + self.seasearch_api.bulk(index_name, delete_params) + if len(hits) < per_size: + break + + def update(self, index_name, repo_id, 
old_commit_id, new_commit_id):
+        added_files, deleted_files, modified_files, added_dirs, deleted_dirs = \
+            get_library_diff_files(repo_id, old_commit_id, new_commit_id)
+
+        need_deleted_files = deleted_files
+        self.delete_files(index_name, need_deleted_files)
+
+        self.delete_dirs(index_name, deleted_dirs)
+
+        self.delete_files_by_deleted_dirs(index_name, deleted_dirs)
+
+        need_added_files = added_files + modified_files
+        self.add_files(index_name, repo_id, need_added_files)
+
+        self.add_dirs(index_name, repo_id, added_dirs)
+
+    def delete_index_by_index_name(self, index_name):
+        self.seasearch_api.delete_index_by_name(index_name)
diff --git a/semantic_search/index_store/repo_status_index.py b/semantic_search/index_store/repo_status_index.py
new file mode 100644
index 00000000..adfce0d8
--- /dev/null
+++ b/semantic_search/index_store/repo_status_index.py
@@ -0,0 +1,158 @@
+from seafevents.semantic_search import config
+
+
+class RepoStatus(object):
+    def __init__(self, repo_id, from_commit, to_commit):
+        self.repo_id = repo_id
+        self.from_commit = from_commit
+        self.to_commit = to_commit
+
+    def need_recovery(self):
+        return self.to_commit is not None
+
+
+class RepoStatusIndex():
+    """The repo-head index is used to store the status for each repo.
+
+    For each repo:
+    (1) before update: commit = <previously indexed commit>, updatingto = None
+    (2) during updating: commit = <previously indexed commit>, updatingto = <latest commit>
+    (3) after updating: commit = <latest commit>, updatingto = None
+
+    When an error occurs during updating, the status is left in case (2). So the
+    next time we update that repo, we can recover from the failed update.
+
+    The SeaSearch document id for each repo in the repo-head index is its repo
+    id.
+    """
+
+    mapping = {
+        'properties': {
+            'repo_id': {
+                'type': 'keyword'
+            },
+            'commit_id': {
+                'type': 'keyword'
+            },
+            'updatingto': {
+                'type': 'keyword'
+            }
+        },
+    }
+
+    shard_num = config.SHARD_NUM
+
+    def __init__(self, seasearch_api, index_name):
+        self.index_name = index_name
+        self.seasearch_api = seasearch_api
+        self.create_index_if_missing()
+
+    def create_index_if_missing(self):
+        if not self.seasearch_api.check_index_mapping(self.index_name).get('is_exist'):
+            data = {
+                'mappings': self.mapping,
+            }
+            self.seasearch_api.create_index(self.index_name, data)
+
+    def check_repo_status(self, repo_id):
+        return self.seasearch_api.check_document_by_id(self.index_name, repo_id).get('is_exist')
+
+    def add_repo_status(self, repo_id, commit_id, updatingto):
+        data = {
+            'repo_id': repo_id,
+            'commit_id': commit_id,
+            'updatingto': updatingto,
+        }
+        doc_id = repo_id
+        self.seasearch_api.create_document_by_id(self.index_name, doc_id, data)
+
+    def begin_update_repo(self, repo_id, old_commit_id, new_commit_id):
+        self.add_repo_status(repo_id, old_commit_id, new_commit_id)
+
+    def finish_update_repo(self, repo_id, commit_id):
+        self.add_repo_status(repo_id, commit_id, None)
+
+    def delete_documents_by_repo(self, repo_id):
+        return self.seasearch_api.delete_document_by_id(self.index_name, repo_id)
+
+    def get_repo_status_by_id(self, repo_id):
+        doc = self.seasearch_api.get_document_by_id(self.index_name, repo_id)
+        if doc.get('error'):
+            return RepoStatus(repo_id, None, None)
+        commit_id = doc['_source']['commit_id']
+        updatingto = doc['_source']['updatingto']
+        repo_id = doc['_source']['repo_id']
+
+        return RepoStatus(repo_id, commit_id, updatingto)
+
+    def update_repo_status_by_id(self, doc_id, data):
+        self.seasearch_api.update_document_by_id(self.index_name, doc_id, data)
+
+    def get_repo_status_by_time(self, check_time):
+        per_size = 2000
+ start = 0 + repo_head_list = [] + while True: + query_params = { + "query": { + "bool": { + "must": [ + {"range": + {"@timestamp": + { + "lt": check_time + } + } + } + ] + } + }, + "_source": ["commit_id", "updatingto"], + "from": start, + "size": per_size, + "sort": ["-@timestamp"], + } + + repo_heads, total = self._repo_head_search(query_params) + repo_head_list.extend(repo_heads) + start += per_size + if len(repo_heads) < per_size or start == total: + return repo_head_list + + def get_all_repos_from_index(self): + start = 0 + per_size = 2000 + repo_head_list = [] + while True: + repo_heads, total = self.get_repos_from_index_by_size(start, per_size) + repo_head_list.extend(repo_heads) + + start += per_size + if len(repo_heads) < per_size or start == total: + return repo_head_list + + def get_repos_from_index_by_size(self, start, per_size): + query_params = { + 'from': start, + 'size': per_size, + } + + repo_heads, total = self._repo_head_search(query_params) + return repo_heads, total + + def _repo_head_search(self, query_params): + result = self.seasearch_api.normal_search(self.index_name, query_params) + total = result['hits']['total']['value'] + hits = result['hits']['hits'] + repo_heads = [] + + for hit in hits: + repo_id = hit['_id'] + commit_id = hit.get('_source').get('commit_id') + updatingto = hit.get('_source').get('updatingto') + repo_heads.append({'repo_id': repo_id, 'commit_id': commit_id, 'updatingto': updatingto}) + return repo_heads, total + + def delete_index_by_index_name(self): + self.seasearch_api.delete_index_by_name(self.index_name) diff --git a/semantic_search/index_store/utils.py b/semantic_search/index_store/utils.py new file mode 100644 index 00000000..ab7d0d94 --- /dev/null +++ b/semantic_search/index_store/utils.py @@ -0,0 +1,134 @@ +# -*- coding: utf-8 -*- +import os +import logging +from seafevents.semantic_search.index_store.extract import ExtractorFactory +from seafevents.semantic_search.config import SUPPORT_INDEX_FILE_TYPES + +from seafobj import fs_mgr, commit_mgr + +logger = logging.getLogger(__name__) + + +REPO_FILE_INDEX_CONTENT_LIMIT = 200 + + +def get_document_add_params(embedding_api, sentences, index_name, path): + add_params = [] + embeddings = embedding_api.embeddings(sentences) + for item in embeddings['data']: + index_info = {"index": {"_index": index_name}} + vector_info = { + "path": path, + "vec": item['embedding'], + "content": sentences[item['index']][:REPO_FILE_INDEX_CONTENT_LIMIT] + } + add_params.append(index_info) + add_params.append(vector_info) + return add_params + + +def parse_file_to_sentences(index_name, file_info, commit_id): + path = file_info[0] + obj_id = file_info[1] + mtime = file_info[2] + size = file_info[3] + repo_id = index_name + + path_string, ext = os.path.splitext(path) + if ext.lower() not in SUPPORT_INDEX_FILE_TYPES: + return [] + + sentences = [path_string] + if size: + new_commit = commit_mgr.load_commit(repo_id, 0, commit_id) + version = new_commit.get_version() + + extractor = ExtractorFactory.get_extractor(os.path.basename(path)) + file_sentences = extractor.extract(repo_id, version, obj_id, path) if extractor else [] + if file_sentences: + sentences.extend(file_sentences) + + return sentences + + +def rank_fusion(doc_lists, weights=None, c=60): + """ + Args: + doc_lists: A list of rank lists, where each rank list contains unique items. + weights: A list of weights corresponding to the docs. Defaults to equal + weighting for all docs. 
+ c: A constant added to the rank, controlling the balance between the importance + of high-ranked items and the consideration given to lower-ranked items. + Default is 60. + + Returns: + list: The final aggregated list of items sorted by their weighted RRF + scores in descending order. + """ + + if weights is None: + weights = [0.6, 0.4] + if len(doc_lists) != len(weights): + raise ValueError( + "Number of rank lists must be equal to the number of weights." + ) + + # Create a union of all unique documents in the input doc_lists + all_documents = set() + for doc_list in doc_lists: + for doc in doc_list: + all_documents.add(doc.get('_id')) + + # Initialize the RRF score dictionary for each document + rrf_score_dic = {doc: 0.0 for doc in all_documents} + + # Calculate RRF scores for each document + for doc_list, weight in zip(doc_lists, weights): + for rank, doc in enumerate(doc_list, start=1): + rrf_score = weight * (1 / (rank + c)) + rrf_score_dic[doc.get('_id')] += rrf_score + + # Sort documents by their RRF scores in descending order + sorted_documents = sorted( + rrf_score_dic.keys(), key=lambda x: rrf_score_dic[x], reverse=True + ) + + # Map the sorted _id back to the original document + id_to_doc_map = { + doc.get('_id'): doc for doc_list in doc_lists for doc in doc_list + } + sorted_docs = [ + id_to_doc_map[_id] for _id in sorted_documents + ] + + return sorted_docs + + +def filter_hybrid_searched_files(files): + """ + filter duplicate files + """ + + path_set = set() + filtered_files = [] + for file in files: + fullpath = file.get('fullpath') + if fullpath in path_set: + continue + path_set.add(fullpath) + file.pop('_id', None) + file.pop('score', None) + file.pop('max_score', None) + filtered_files.append(file) + return filtered_files + + +def bulk_add_sentences_to_index(seasearch_api, embedding_api, index_name, path, sentences, limit=1000): + step = limit + start = 0 + while True: + if not sentences[start: start + step]: + break + params = get_document_add_params(embedding_api, sentences[start: start + step], index_name, path) + seasearch_api.bulk(index_name, params) + start += step diff --git a/semantic_search/index_task/filename_index_updater.py b/semantic_search/index_task/filename_index_updater.py new file mode 100644 index 00000000..8cb180b0 --- /dev/null +++ b/semantic_search/index_task/filename_index_updater.py @@ -0,0 +1,88 @@ +import logging +from threading import Thread + +from apscheduler.triggers.cron import CronTrigger +from apscheduler.schedulers.gevent import GeventScheduler + + +logger = logging.getLogger(__name__) + + +class RepoFilenameIndexUpdater(): + def __init__(self): + self._repo_status_filename_index = None + self._repo_filename_index = None + self._index_manager = None + self._repo_data = None + + def init(self, app): + self._repo_status_filename_index = app.repo_status_filename_index + self._repo_filename_index = app.repo_filename_index + self._index_manager = app.index_manager + self._repo_data = app.repo_data + + def start(self): + RepoFilenameIndexUpdaterTimer( + self._repo_status_filename_index, self._repo_filename_index, self._index_manager, self._repo_data + ).start() + + +def clear_deleted_repo(repo_status_filename_index, repo_filename_index, index_manager, repos): + logger.info("start to clear filename index deleted repo") + + repo_list = repo_status_filename_index.get_all_repos_from_index() + repo_all = [e.get('repo_id') for e in repo_list] + + repo_deleted = set(repo_all) - set(repos) + + logger.info("filename index %d repos need to be 
deleted." % len(repo_deleted))
+    for repo_id in repo_deleted:
+        index_manager.delete_repo_filename_index(repo_id, repo_filename_index, repo_status_filename_index)
+        logger.info('Repo %s has been deleted from filename index.' % repo_id)
+    logger.info("filename index deleted repo has been cleared")
+
+
+def update_repo_file_name_indexes(repo_status_filename_index, repo_filename_index, index_manager, repo_data):
+    start, count = 0, 1000
+    all_repos = []
+    while True:
+        try:
+            repo_commits = repo_data.get_repo_id_commit_id(start, count)
+        except Exception as e:
+            logger.error("Error: %s" % e)
+            return
+        start += 1000
+
+        if len(repo_commits) == 0:
+            break
+
+        for repo_id, commit_id in repo_commits:
+            all_repos.append(repo_id)
+
+            index_manager.update_library_filename_index(repo_id, commit_id, repo_filename_index, repo_status_filename_index)
+
+    logger.info("Finished updating filename index")
+
+    clear_deleted_repo(repo_status_filename_index, repo_filename_index, index_manager, all_repos)
+
+
+class RepoFilenameIndexUpdaterTimer(Thread):
+    def __init__(self, repo_status_filename_index, repo_filename_index, index_manager, repo_data):
+        super(RepoFilenameIndexUpdaterTimer, self).__init__()
+        self.repo_status_filename_index = repo_status_filename_index
+        self.repo_filename_index = repo_filename_index
+        self.index_manager = index_manager
+        self.repo_data = repo_data
+
+    def run(self):
+        sched = GeventScheduler()
+        logging.info('Start to update filename index...')
+        try:
+            sched.add_job(update_repo_file_name_indexes, CronTrigger(minute='*/5'),
+                          args=(self.repo_status_filename_index, self.repo_filename_index, self.index_manager, self.repo_data))
+        except Exception as e:
+            logging.exception('periodical update filename index error: %s', e)
+
+        sched.start()
+
+repo_filename_index_updater = RepoFilenameIndexUpdater()
diff --git a/semantic_search/index_task/index_task_manager.py b/semantic_search/index_task/index_task_manager.py
new file mode 100644
index 00000000..8c15df95
--- /dev/null
+++ b/semantic_search/index_task/index_task_manager.py
@@ -0,0 +1,207 @@
+import logging
+import queue
+import uuid
+from datetime import datetime
+from threading import Thread, Lock
+
+from apscheduler.triggers.cron import CronTrigger
+from apscheduler.schedulers.gevent import GeventScheduler
+
+from seafevents.semantic_search import config
+
+
+logger = logging.getLogger(__name__)
+
+
+class IndexTask:
+
+    def __init__(self, task_id, readable_id, func, args):
+        self.id = task_id
+        self.readable_id = readable_id
+        self.func = func
+        self.args = args
+
+        self.status = 'init'
+
+        self.started_at = None
+        self.finished_at = None
+
+        self.result = None
+        self.error = None
+
+    @staticmethod
+    def get_readable_id(readable_id):
+        return readable_id
+
+    def run(self):
+        self.status = 'running'
+        self.started_at = datetime.now()
+        return self.func(*self.args)
+
+    def set_result(self, result):
+        self.result = result
+        self.status = 'success'
+        self.finished_at = datetime.now()
+
+    def set_error(self, error):
+        self.error = error
+        self.status = 'error'
+        self.finished_at = datetime.now()
+
+    def is_finished(self):
+        return self.status in ['error', 'success']
+
+    def get_cost_time(self):
+        if self.started_at and self.finished_at:
+            return (self.finished_at - self.started_at).seconds
+        return None
+
+    def get_info(self):
+        return f'{self.id}--{self.readable_id}--{self.func}'
+
+    def __str__(self):
+        return f'<IndexTask {self.get_info()}>'
+
+
+class IndexTaskManager:
+
+    def __init__(self):
+        self.tasks_queue = queue.Queue()
+        self.tasks_map = {} # {task_id: 
task} all tasks + self.readable_id2task_map = {} # {task_readable_id: task} in queue or running + self.check_task_lock = Lock() # lock access to readable_id2task_map + self.sched = GeventScheduler() + self.app = None + self.conf = { + 'workers': config.INDEX_MANAGER_WORKERS, + 'expire_time': config.INDEX_TASK_EXPIRE_TIME + } + self.sched.add_job(self.clear_expired_tasks, CronTrigger(minute='*/10')) + self.sched.add_job(self.cron_update_library_sdoc_indexes, CronTrigger(hour='*')) + + def init(self, app): + self.app = app + + def get_pending_or_running_task(self, readable_id): + task = self.readable_id2task_map.get(readable_id) + return task + + def add_library_sdoc_index_task(self, repo_id, commit_id): + readable_id = repo_id + with self.check_task_lock: + task = self.get_pending_or_running_task(readable_id) + if task: + return task.id + + task_id = str(uuid.uuid4()) + task = IndexTask(task_id, readable_id, self.app.index_manager.create_library_sdoc_index, + (repo_id, self.app.embedding_api, self.app.repo_file_index, self.app.repo_status_index, commit_id) + ) + + self.tasks_map[task_id] = task + self.readable_id2task_map[task.readable_id] = task + + self.tasks_queue.put(task) + return task_id + + def keyword_search(self, query, repos, count, suffixes): + return self.app.index_manager.keyword_search(query, repos, self.app.repo_filename_index, count, suffixes) + + def hybrid_search(self, query, repo, count): + return self.app.index_manager.hybrid_search(query, repo, self.app.repo_filename_index, + self.app.embedding_api, self.app.repo_file_index, count) + + def add_update_a_library_sdoc_index_task(self, repo_id, commit_id): + readable_id = repo_id + with self.check_task_lock: + task = self.get_pending_or_running_task(readable_id) + if task: + return task.id + + task_id = str(uuid.uuid4()) + task = IndexTask(task_id, readable_id, self.app.index_manager.update_library_sdoc_index, + (repo_id, self.app.embedding_api, self.app.repo_file_index, self.app.repo_status_index, + commit_id) + ) + self.tasks_map[task_id] = task + self.readable_id2task_map[task.readable_id] = task + self.tasks_queue.put(task) + + return task_id + + def update_library_sdoc_indexes(self): + index_repos = self.app.index_manager.list_index_repos() + for repo in index_repos: + repo_id = repo[0] + commit_id = self.app.repo_data.get_repo_head_commit(repo_id) + self.add_update_a_library_sdoc_index_task(repo_id, commit_id) + + def cron_update_library_sdoc_indexes(self): + """ + update library sdoc indexes periodly + query tasks and add them to queue by calling self.add_update_a_library_sdoc_index_task + """ + + try: + self.update_library_sdoc_indexes() + except Exception as e: + logger.exception('periodical update library sdoc indexes error: %s', e) + + def query_task(self, task_id): + return self.tasks_map.get(task_id) + + def handle_task(self): + while True: + try: + task = self.tasks_queue.get(timeout=2) + except queue.Empty: + continue + except Exception as e: + logger.error(e) + continue + + try: + task_info = task.get_info() + logger.info('Run task: %s' % task_info) + + # run + task.run() + task.set_result('success') + + logger.info('Run task success: %s cost %ds \n' % (task_info, task.get_cost_time())) + except Exception as e: + task.set_error(e) + logger.exception('Failed to handle task %s, error: %s \n' % (task.get_info(), e)) + finally: + with self.check_task_lock: + self.readable_id2task_map.pop(task.readable_id, None) + + def start(self): + thread_num = self.conf['workers'] + for i in range(thread_num): + t_name = 
'IndexTaskManager Thread-' + str(i) + t = Thread(target=self.handle_task, name=t_name) + t.setDaemon(True) + t.start() + self.sched.start() + + def clear_expired_tasks(self): + """clear tasks finished for conf['expire_time'] in tasks_map + + when a task end, it will not be pop from tasks_map immediately, + because this task might be responsible for multi-http-requests(not only one), that might query task status + + but task will not restored forever, so need to clear + """ + expire_tasks = [] + for task in self.tasks_map.values(): + if not task.is_finished(): + continue + if (datetime.now() - task.finished_at).seconds >= self.conf['expire_time']: + expire_tasks.append(task) + logger.info('expired tasks: %s', len(expire_tasks)) + for task in expire_tasks: + self.tasks_map.pop(task.id, None) + + +index_task_manager = IndexTaskManager() diff --git a/semantic_search/script/portalocker/__init__.py b/semantic_search/script/portalocker/__init__.py new file mode 100644 index 00000000..5bae6b9f --- /dev/null +++ b/semantic_search/script/portalocker/__init__.py @@ -0,0 +1,3 @@ +from .portalocker import * +from .utils import * + diff --git a/semantic_search/script/portalocker/portalocker.py b/semantic_search/script/portalocker/portalocker.py new file mode 100644 index 00000000..b8dd3434 --- /dev/null +++ b/semantic_search/script/portalocker/portalocker.py @@ -0,0 +1,143 @@ +# portalocker.py - Cross-platform (posix/nt) API for flock-style file locking. +# Requires python 1.5.2 or better. +'''Cross-platform (posix/nt) API for flock-style file locking. + +Synopsis: + + import portalocker + file = open('somefile', 'r+') + portalocker.lock(file, portalocker.LOCK_EX) + file.seek(12) + file.write('foo') + file.close() + +If you know what you're doing, you may choose to + + portalocker.unlock(file) + +before closing the file, but why? + +Methods: + + lock( file, flags ) + unlock( file ) + +Constants: + + LOCK_EX + LOCK_SH + LOCK_NB + +Exceptions: + + LockException + +Notes: + +For the 'nt' platform, this module requires the Python Extensions for Windows. +Be aware that this may not work as expected on Windows 95/98/ME. + +History: + +I learned the win32 technique for locking files from sample code +provided by John Nielsen in the documentation +that accompanies the win32 modules. + +Author: Jonathan Feinberg , + Lowell Alleman +Version: $Id: portalocker.py 5474 2008-05-16 20:53:50Z lowell $ + +''' + + +__all__ = [ + 'lock', + 'unlock', + 'LOCK_EX', + 'LOCK_SH', + 'LOCK_NB', + 'LockException', +] + +import os +import logging + +logger = logging.getLogger('semantic_search') + + +class LockException(Exception): + # Error codes: + LOCK_FAILED = 1 + +if os.name == 'nt': + import win32con + import win32file + import pywintypes + LOCK_EX = win32con.LOCKFILE_EXCLUSIVE_LOCK + LOCK_SH = 0 # the default + LOCK_NB = win32con.LOCKFILE_FAIL_IMMEDIATELY + # is there any reason not to reuse the following structure? 
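+    # Note (added comment): LockFileEx/UnlockFileEx take an OVERLAPPED structure to
+    # specify the byte offset of the locked region; reusing a single module-level
+    # instance (offset 0) is enough here because both calls below lock the same range.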
+ __overlapped = pywintypes.OVERLAPPED() +elif os.name == 'posix': + import fcntl + LOCK_EX = fcntl.LOCK_EX + LOCK_SH = fcntl.LOCK_SH + LOCK_NB = fcntl.LOCK_NB +else: + logger.critical('PortaLocker only defined for nt and posix platforms') + raise RuntimeError('PortaLocker only defined for nt and posix platforms') + +if os.name == 'nt': + def lock(file, flags): + hfile = win32file._get_osfhandle(file.fileno()) + try: + win32file.LockFileEx(hfile, flags, 0, -0x10000, __overlapped) + except pywintypes.error as exc_value: + # error: (33, 'LockFileEx', 'The process cannot access the file because another process has locked a portion of the file.') + if exc_value[0] == 33: + raise LockException(LockException.LOCK_FAILED, exc_value[2]) + else: + # Q: Are there exceptions/codes we should be dealing with here? + raise + + def unlock(file): + hfile = win32file._get_osfhandle(file.fileno()) + try: + win32file.UnlockFileEx(hfile, 0, -0x10000, __overlapped) + except pywintypes.error as exc_value: + if exc_value[0] == 158: + # error: (158, 'UnlockFileEx', 'The segment is already unlocked.') + # To match the 'posix' implementation, silently ignore this error + pass + else: + # Q: Are there exceptions/codes we should be dealing with here? + raise + +elif os.name == 'posix': + def lock(file, flags): + try: + fcntl.flock(file.fileno(), flags) + except IOError as exc_value: + # The exception code varies on different systems so we'll catch + # every IO error + raise LockException(exc_value) + + def unlock(file): + fcntl.flock(file.fileno(), fcntl.LOCK_UN) + + +if __name__ == '__main__': + from time import time, strftime, localtime + import sys + import portalocker + + log = open('log.txt', 'a+') + portalocker.lock(log, portalocker.LOCK_EX) + + timestamp = strftime('%m/%d/%Y %H:%M:%S\n', localtime(time())) + log.write( timestamp ) + + print('Wrote lines. Hit enter to release lock.') + dummy = sys.stdin.readline() + + log.close() diff --git a/semantic_search/script/portalocker/utils.py b/semantic_search/script/portalocker/utils.py new file mode 100644 index 00000000..8f71eac0 --- /dev/null +++ b/semantic_search/script/portalocker/utils.py @@ -0,0 +1,143 @@ + +import time +from . import portalocker + +DEFAULT_TIMEOUT = 5 +DEFAULT_CHECK_INTERVAL = 0.25 +LOCK_METHOD = portalocker.LOCK_EX | portalocker.LOCK_NB + +__all__ = [ + 'Lock', + 'AlreadyLocked', +] + +class AlreadyLocked(Exception): + pass + +class Lock(object): + def __init__( + self, + filename, + mode='a', + truncate=0, + timeout=DEFAULT_TIMEOUT, + check_interval=DEFAULT_CHECK_INTERVAL, + fail_when_locked=True, + ): + '''Lock manager with build-in timeout + + filename -- filename + mode -- the open mode, 'a' or 'ab' should be used for writing + truncate -- use truncate to emulate 'w' mode, None is disabled, 0 is + truncate to 0 bytes + timeout -- timeout when trying to acquire a lock + check_interval -- check interval while waiting + fail_when_locked -- after the initial lock failed, return an error + or lock the file + + fail_when_locked is useful when multiple threads/processes can race + when creating a file. If set to true than the system will wait till + the lock was acquired and then return an AlreadyLocked exception. + + Note that the file is opened first and locked later. So using 'w' as + mode will result in truncate _BEFORE_ the lock is checked. 
+ ''' + + self.fh = None + self.filename = filename + self.mode = mode + self.truncate = truncate + self.timeout = timeout + self.check_interval = check_interval + self.fail_when_locked = fail_when_locked + + assert 'w' not in mode, 'Mode "w" clears the file before locking' + + def acquire(self, timeout=None, check_interval=None, fail_when_locked=None): + '''Acquire the locked filehandle''' + if timeout is None: + timeout = self.timeout + + if check_interval is None: + check_interval = self.check_interval + + if fail_when_locked is None: + fail_when_locked = self.fail_when_locked + + # If we already have a filehandle, return it + fh = self.fh + if fh: + return fh + + # Get a new filehandler + fh = self._get_fh() + try: + # Try to lock + fh = self._get_lock(fh) + except portalocker.LockException as exception: + # Try till the timeout is 0 + while timeout > 0: + # Wait a bit + time.sleep(check_interval) + timeout -= check_interval + + # Try again + try: + fh = self._get_lock(fh) + + # We've got the lock, now return an error if + # fail_when_locked is True or break if not + if fail_when_locked: + self._release_lock() + raise AlreadyLocked(*exception) + else: + break + except portalocker.LockException: + pass + + else: + # We got a timeout... reraising + raise portalocker.LockException(*exception) + + # Prepare the filehandle (truncate if needed) + fh = self._prepare_fh(fh) + + self.fh = fh + return fh + + def _get_fh(self): + '''Get a new filehandle''' + return open(self.filename, self.mode) + + def _get_lock(self, fh): + ''' + Try to lock the given filehandle + + returns LockException if it fails''' + portalocker.lock(fh, LOCK_METHOD) + return fh + + def _prepare_fh(self, fh, truncate=None): + ''' + Prepare the filehandle for usage + + If truncate is a number, the file will be truncated to that amount of + bytes + ''' + if truncate is None: + truncate = self.truncate + + if truncate is not None: + fh.seek(truncate) + fh.truncate(truncate) + + return fh + + def __enter__(self): + self.fh = self.acquire() + return self.fh + + def __exit__(self, type, value, tb): + if self.fh: + self.fh.close() + diff --git a/semantic_search/script/repo_file_index_local.py b/semantic_search/script/repo_file_index_local.py new file mode 100644 index 00000000..16f4a08c --- /dev/null +++ b/semantic_search/script/repo_file_index_local.py @@ -0,0 +1,275 @@ +import os +import sys +import time +import queue +import logging +import argparse +import threading + +from seafobj import commit_mgr, fs_mgr, block_mgr + +import config +from seafevents.semantic_search.utils import init_logging +from seafevents.repo_data import repo_data +from seafevents.semantic_search.index_store.index_manager import IndexManager +from seafevents.semantic_search.utils.seasearch_api import SeaSearchAPI +from seafevents.semantic_search.index_store.repo_status_index import RepoStatusIndex +from seafevents.semantic_search.index_store.repo_file_index import RepoFileIndex +from seafevents.semantic_search.utils.constants import REPO_STATUS_FILE_INDEX_NAME +from seafevents.semantic_search.utils.sea_embedding_api import SeaEmbeddingAPI + + +MAX_ERRORS_ALLOWED = 1000 +logger = logging.getLogger('semantic_search') + +UPDATE_FILE_LOCK = os.path.join(os.path.dirname(__file__), 'update.lock') +lockfile = None +NO_TASKS = False + + +class RepoFileIndexLocal(object): + """ Independent update repo file index. 
+ """ + def __init__(self, index_manager, repo_status_index, repo_file_index, embedding_api, repo_data, workers=3): + self.index_manager = index_manager + self.repo_status_index = repo_status_index + self.repo_file_index = repo_file_index + self.embedding_api = embedding_api + self.repo_data = repo_data + self.error_counter = 0 + self.worker_list = [] + self.workers = workers + + def clear_worker(self): + for th in self.worker_list: + th.join() + logger.info("All worker threads has stopped.") + + def run(self): + time_start = time.time() + repos_queue = queue.Queue(0) + for i in range(self.workers): + thread_name = "worker" + str(i) + logger.info("starting %s worker threads for repo file indexing" + % thread_name) + t = threading.Thread(target=self.thread_task, args=(repos_queue, ), name=thread_name) + t.start() + self.worker_list.append(t) + + start, per_size = 0, 1000 + need_deleted_index_repos = [] + while True: + global NO_TASKS + try: + index_repos = list(self.index_manager.get_index_repos_by_size(start, per_size)) + except Exception as e: + logger.error("Error: %s" % e) + NO_TASKS = True + self.clear_worker() + break + else: + if len(index_repos) == 0: + NO_TASKS = True + break + + for index_repo in index_repos: + repo_id = index_repo[0] + commit_id = self.repo_data.get_repo_head_commit(repo_id) + if not commit_id: + # repo has deleted, delete repo index + need_deleted_index_repos.append(repo_id) + continue + repos_queue.put((repo_id, commit_id)) + + start += per_size + + self.clear_worker() + logger.info("repo file index updated, total time %s seconds" % str(time.time() - time_start)) + try: + self.clear_deleted_repo(need_deleted_index_repos) + except Exception as e: + logger.exception('Delete Repo Error: %s' % e) + self.incr_error() + + def thread_task(self, repos_queue): + while True: + try: + queue_data = repos_queue.get(False) + except queue.Empty: + if NO_TASKS: + logger.debug( + "Queue is empty, %s worker threads stop" + % (threading.currentThread().getName()) + ) + break + else: + time.sleep(2) + else: + repo_id = queue_data[0] + commit_id = queue_data[1] + try: + self.index_manager.create_library_sdoc_index(repo_id, self.embedding_api, self.repo_file_index, self.repo_status_index, commit_id) + except Exception as e: + logger.exception('Repo file index error: %s, repo_id: %s' % (e, repo_id), exc_info=True) + self.incr_error() + + logger.info( + "%s worker updated at %s time" + % (threading.currentThread().getName(), + time.strftime("%Y-%m-%d %H:%M", time.localtime(time.time()))) + ) + logger.info( + "%s worker get %s error" + % (threading.currentThread().getName(), + str(self.error_counter)) + ) + + def clear_deleted_repo(self, repos): + logger.info("start to clear deleted repo") + logger.info("%d repos need to be deleted." % len(repos)) + + for repo_id in repos: + self.delete_repo(repo_id) + logger.info('Repo %s has been deleted from index.' 
% repo_id) + logger.info("deleted repo has been cleared") + + def incr_error(self): + self.error_counter += 1 + + def delete_repo(self, repo_id): + if len(repo_id) != 36: + return + self.index_manager.delete_index_repo_db(repo_id) + + +def start_index_local(): + if not check_concurrent_update(): + return + + seasearch_api = SeaSearchAPI(config.SEASEARCH_SERVER, config.SEASEARCH_TOKEN) + index_manager = IndexManager() + repo_status_index = RepoStatusIndex(seasearch_api, REPO_STATUS_FILE_INDEX_NAME) + repo_file_index = RepoFileIndex(seasearch_api) + + embedding_api = SeaEmbeddingAPI(config.APP_NAME, config.SEA_EMBEDDING_SERVER) + workers = config.INDEX_MANAGER_WORKERS + + try: + index_local = RepoFileIndexLocal(index_manager, repo_status_index, repo_file_index, embedding_api, repo_data, workers) + except Exception as e: + logger.error("Index repo file process init error: %s." % e) + return + + logger.info("Index repo file process initialized.") + index_local.run() + + logger.info('\n\nRepo file index updated, statistic report:\n') + logger.info('[commit read] %s', commit_mgr.read_count()) + logger.info('[dir read] %s', fs_mgr.dir_read_count()) + logger.info('[file read] %s', fs_mgr.file_read_count()) + logger.info('[block read] %s', block_mgr.read_count()) + + +def delete_indices(): + seasearch_api = SeaSearchAPI(config.SEASEARCH_SERVER, config.SEASEARCH_TOKEN) + repo_status_index = RepoStatusIndex(seasearch_api, REPO_STATUS_FILE_INDEX_NAME) + repo_file_index = RepoFileIndex(seasearch_api) + index_manager = IndexManager() + + start, per_size = 0, 1000 + while True: + index_repos = list(index_manager.get_index_repos_by_size(start, per_size)) + + if len(index_repos) == 0: + break + + for index_repo in index_repos: + repo_file_index.delete_index_by_index_name(index_repo[0]) + start += per_size + + repo_status_index.delete_index_by_index_name() + + +def main(): + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers(title='subcommands', description='') + + parser.add_argument( + '--logfile', + default=sys.stdout, + type=argparse.FileType('a'), + help='log file') + + parser.add_argument( + '--loglevel', + default='info', + help='log level') + + # update index + parser_update = subparsers.add_parser('update', help='update repo file index') + parser_update.set_defaults(func=start_index_local) + + # clear + parser_clear = subparsers.add_parser('clear', help='clear all repo file index') + parser_clear.set_defaults(func=delete_indices) + + if len(sys.argv) == 1: + print(parser.format_help()) + return + + args = parser.parse_args() + init_logging(args) + + logger.info('storage: using ' + commit_mgr.get_backend_name()) + + args.func() + + +def do_lock(fn): + if os.name == 'nt': + return do_lock_win32(fn) + else: + return do_lock_linux(fn) + + +def do_lock_win32(fn): + import ctypes + + CreateFileW = ctypes.windll.kernel32.CreateFileW + GENERIC_WRITE = 0x40000000 + OPEN_ALWAYS = 4 + + def lock_file(path): + lock_file_handle = CreateFileW(path, GENERIC_WRITE, 0, None, OPEN_ALWAYS, 0, None) + + return lock_file_handle + + global lockfile + + lockfile = lock_file(fn) + + return lockfile != -1 + + +def do_lock_linux(fn): + from seafevents.semantic_search import portalocker + global lockfile + lockfile = open(fn, 'w') + try: + portalocker.lock(lockfile, portalocker.LOCK_NB | portalocker.LOCK_EX) + return True + except portalocker.LockException: + return False + + +def check_concurrent_update(): + """Use a lock file to ensure only one task can be running""" + if not do_lock(UPDATE_FILE_LOCK): 
+ logger.error('another index task is running, quit now') + return False + + return True + + +if __name__ == "__main__": + main() diff --git a/semantic_search/script/repo_filename_index_local.py b/semantic_search/script/repo_filename_index_local.py new file mode 100644 index 00000000..538529cd --- /dev/null +++ b/semantic_search/script/repo_filename_index_local.py @@ -0,0 +1,271 @@ +import os +import sys +import time +import queue +import logging +import argparse +import threading + +from seafobj import commit_mgr, fs_mgr, block_mgr +import config +from seafevents.semantic_search.utils import init_logging +from seafevents.repo_data import repo_data +from seafevents.semantic_search.index_store.index_manager import IndexManager +from seafevents.semantic_search.utils.seasearch_api import SeaSearchAPI +from seafevents.semantic_search.index_store.repo_status_index import RepoStatusIndex +from seafevents.semantic_search.utils.constants import REPO_STATUS_FILENAME_INDEX_NAME, REPO_FILENAME_INDEX_PREFIX +from seafevents.semantic_search.index_store.repo_file_name_index import RepoFileNameIndex + +MAX_ERRORS_ALLOWED = 1000 +logger = logging.getLogger('semantic_search') + +UPDATE_FILE_LOCK = os.path.join(os.path.dirname(__file__), 'update.lock') +lockfile = None +NO_TASKS = False + + +class RepoFileNameIndexLocal(object): + """ Independent update repo file name index. + """ + def __init__(self, index_manager, repo_status_filename_index, repo_filename_index, repo_data, workers=3): + self.index_manager = index_manager + self.repo_status_filename_index = repo_status_filename_index + self.repo_filename_index = repo_filename_index + self.repo_data = repo_data + self.error_counter = 0 + self.worker_list = [] + self.workers = workers + + def clear_worker(self): + for th in self.worker_list: + th.join() + logger.info("All worker threads has stopped.") + + def run(self): + time_start = time.time() + repos_queue = queue.Queue(0) + for i in range(self.workers): + thread_name = "worker" + str(i) + logger.info("starting %s worker threads for repo filename indexing" + % thread_name) + t = threading.Thread(target=self.thread_task, args=(repos_queue, ), name=thread_name) + t.start() + self.worker_list.append(t) + + start, per_size = 0, 1000 + repos = {} + while True: + global NO_TASKS + try: + repo_commits = self.repo_data.get_repo_id_commit_id(start, per_size) + except Exception as e: + logger.error("Error: %s" % e) + NO_TASKS = True + self.clear_worker() + return + else: + if len(repo_commits) == 0: + NO_TASKS = True + break + for repo_id, commit_id in repo_commits.items(): + repos_queue.put((repo_id, commit_id)) + repos[repo_id] = commit_id + start += per_size + + self.clear_worker() + logger.info("repo filename index updated, total time %s seconds" % str(time.time() - time_start)) + try: + self.clear_deleted_repo(list(repos.keys())) + except Exception as e: + logger.exception('Delete Repo Error: %s' % e) + self.incr_error() + + def thread_task(self, repos_queue): + while True: + try: + queue_data = repos_queue.get(False) + except queue.Empty: + if NO_TASKS: + logger.debug( + "Queue is empty, %s worker threads stop" + % (threading.currentThread().getName()) + ) + break + else: + time.sleep(2) + else: + repo_id = queue_data[0] + commit_id = queue_data[1] + try: + self.index_manager.update_library_filename_index(repo_id, commit_id, self.repo_filename_index, self.repo_status_filename_index) + except Exception as e: + logger.exception('Repo filename index error: %s, repo_id: %s' % (e, repo_id), exc_info=True) + 
self.incr_error() + + logger.info( + "%s worker updated at %s time" + % (threading.currentThread().getName(), + time.strftime("%Y-%m-%d %H:%M", time.localtime(time.time()))) + ) + logger.info( + "%s worker get %s error" + % (threading.currentThread().getName(), + str(self.error_counter)) + ) + + def clear_deleted_repo(self, repos): + logger.info("start to clear deleted repo") + repo_all = [e.get('repo_id') for e in self.repo_status_filename_index.get_all_repos_from_index()] + + repo_deleted = set(repo_all) - set(repos) + logger.info("%d repos need to be deleted." % len(repo_deleted)) + + for repo_id in repo_deleted: + self.delete_repo(repo_id) + logger.info('Repo %s has been deleted from index.' % repo_id) + logger.info("deleted repo has been cleared") + + def incr_error(self): + self.error_counter += 1 + + def delete_repo(self, repo_id): + if len(repo_id) != 36: + return + self.index_manager.delete_repo_filename_index(repo_id, self.repo_filename_index, self.repo_status_filename_index) + + +def start_index_local(): + if not check_concurrent_update(): + return + + index_manager = IndexManager() + seasearch_api = SeaSearchAPI(config.SEASEARCH_SERVER, config.SEASEARCH_TOKEN) + repo_status_filename_index = RepoStatusIndex(seasearch_api, REPO_STATUS_FILENAME_INDEX_NAME) + + repo_filename_index = RepoFileNameIndex(seasearch_api, repo_data) + + workers = config.INDEX_MANAGER_WORKERS + + try: + index_local = RepoFileNameIndexLocal(index_manager, repo_status_filename_index, repo_filename_index,repo_data, workers) + except Exception as e: + logger.error("Index repo filename process init error: %s." % e) + return + + logger.info("Index repo filename process initialized.") + index_local.run() + + logger.info('\n\nRepo filename index updated, statistic report:\n') + logger.info('[commit read] %s', commit_mgr.read_count()) + logger.info('[dir read] %s', fs_mgr.dir_read_count()) + logger.info('[file read] %s', fs_mgr.file_read_count()) + logger.info('[block read] %s', block_mgr.read_count()) + + +def delete_indices(): + seasearch_api = SeaSearchAPI(config.SEASEARCH_SERVER, config.SEASEARCH_TOKEN) + repo_status_filename_index = RepoStatusIndex(seasearch_api, REPO_STATUS_FILENAME_INDEX_NAME) + repo_filename_index = RepoFileNameIndex(seasearch_api, repo_data) + + start, count = 0, 1000 + while True: + try: + repo_commits = repo_data.get_repo_id_commit_id(start, count) + except Exception as e: + logger.error("Error: %s" % e) + return + start += 1000 + + if len(repo_commits) == 0: + break + + for repo_id, commit_id in repo_commits.items(): + repo_filename_index_name = REPO_FILENAME_INDEX_PREFIX + repo_id + repo_filename_index.delete_index_by_index_name(repo_filename_index_name) + + repo_status_filename_index.delete_index_by_index_name() + + +def main(): + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers(title='subcommands', description='') + + parser.add_argument( + '--logfile', + default=sys.stdout, + type=argparse.FileType('a'), + help='log file') + + parser.add_argument( + '--loglevel', + default='info', + help='log level') + + # update index + parser_update = subparsers.add_parser('update', help='update seafile repo filename index') + parser_update.set_defaults(func=start_index_local) + + # clear + parser_clear = subparsers.add_parser('clear', help='clear all repo filename index') + parser_clear.set_defaults(func=delete_indices) + + if len(sys.argv) == 1: + print(parser.format_help()) + return + + args = parser.parse_args() + init_logging(args) + + logger.info('storage: using ' + 
commit_mgr.get_backend_name()) + + args.func() + + +def do_lock(fn): + if os.name == 'nt': + return do_lock_win32(fn) + else: + return do_lock_linux(fn) + + +def do_lock_win32(fn): + import ctypes + + CreateFileW = ctypes.windll.kernel32.CreateFileW + GENERIC_WRITE = 0x40000000 + OPEN_ALWAYS = 4 + + def lock_file(path): + lock_file_handle = CreateFileW(path, GENERIC_WRITE, 0, None, OPEN_ALWAYS, 0, None) + + return lock_file_handle + + global lockfile + + lockfile = lock_file(fn) + + return lockfile != -1 + + +def do_lock_linux(fn): + from seafevents.semantic_search import portalocker + global lockfile + lockfile = open(fn, 'w') + try: + portalocker.lock(lockfile, portalocker.LOCK_NB | portalocker.LOCK_EX) + return True + except portalocker.LockException: + return False + + +def check_concurrent_update(): + """Use a lock file to ensure only one task can be running""" + if not do_lock(UPDATE_FILE_LOCK): + logger.error('another index task is running, quit now') + return False + + return True + + +if __name__ == "__main__": + main() diff --git a/semantic_search/semantic_search.py b/semantic_search/semantic_search.py new file mode 100644 index 00000000..d3dadbd7 --- /dev/null +++ b/semantic_search/semantic_search.py @@ -0,0 +1,50 @@ +import logging + +from seafevents.semantic_search.index_task.index_task_manager import index_task_manager +from seafevents.semantic_search.index_task.filename_index_updater import repo_filename_index_updater +from seafevents.semantic_search.index_store.index_manager import IndexManager +from seafevents.semantic_search.index_store.repo_status_index import RepoStatusIndex +from seafevents.semantic_search.index_store.repo_file_index import RepoFileIndex +from seafevents.semantic_search.index_store.repo_file_name_index import RepoFileNameIndex +from seafevents.semantic_search.utils.seasearch_api import SeaSearchAPI +from seafevents.semantic_search.utils.sea_embedding_api import SeaEmbeddingAPI +from seafevents.semantic_search.utils.constants import REPO_STATUS_FILE_INDEX_NAME, REPO_STATUS_FILENAME_INDEX_NAME +from seafevents.repo_data import repo_data +from seafevents.semantic_search import config + +logger = logging.getLogger(__name__) + +class SemanticSearch(): + def __init__(self): + self.index_manager = None + self.seasearch_api = None + self.repo_data = None + self.embedding_api = None + + # for semantic search + self.repo_status_index = None + self.repo_file_index = None + + # for keyword search + self.repo_status_filename_index = None + self.repo_filename_index = None + self.index_task_manager = None + self.repo_filename_index_updater = None + + def init(self): + self.index_manager = IndexManager() + self.seasearch_api = SeaSearchAPI(config.SEASEARCH_SERVER, config.SEASEARCH_TOKEN) + self.repo_data = repo_data + self.embedding_api = SeaEmbeddingAPI(config.APP_NAME, config.SEA_EMBEDDING_SERVER) + + # for semantic search + self.repo_status_index = RepoStatusIndex(self.seasearch_api, REPO_STATUS_FILE_INDEX_NAME) + self.repo_file_index = RepoFileIndex(self.seasearch_api) + + # for keyword search + self.repo_status_filename_index = RepoStatusIndex(self.seasearch_api, REPO_STATUS_FILENAME_INDEX_NAME) + self.repo_filename_index = RepoFileNameIndex(self.seasearch_api, self.repo_data) + self.index_task_manager = index_task_manager + self.repo_filename_index_updater = repo_filename_index_updater + +sem_app = SemanticSearch() diff --git a/semantic_search/semantic_search_settings.py b/semantic_search/semantic_search_settings.py new file mode 100644 index 00000000..9e514029 
--- /dev/null +++ b/semantic_search/semantic_search_settings.py @@ -0,0 +1,22 @@ + + +# ENABLE_SYS_LOG = True + +INDEX_STORAGE_PATH = '/data/dev/static/index/' + +MODEL_CACHE_DIR = '/data/dev/static/' + +RETRIEVAL_NUM = 50 + +MODEL_VOCAB_PATH = '/data/dev/static/damo/nlp_corom_sentence-embedding_chinese-base/' + + + +## sea-embedding +SEA_EMBEDDING_SERVER = 'http://host.docker.internal:8889' +SEA_EMBEDDING_KEY = '123' + + +#seasech +SEASEARCH_SERVER = 'http://host.docker.internal:4080' +SEASEARCH_TOKEN = 'YWRtaW46Q29tcGxleHBhc3MjMTIz' diff --git a/semantic_search/utils/__init__.py b/semantic_search/utils/__init__.py new file mode 100644 index 00000000..50fa98da --- /dev/null +++ b/semantic_search/utils/__init__.py @@ -0,0 +1,88 @@ +import logging +import hashlib + +from seafevents.semantic_search.utils.commit_differ import CommitDiffer + +from seafobj import fs_mgr, commit_mgr +from seafobj.exceptions import GetObjectError +from seaserv import seafile_api + + +logger = logging.getLogger(__name__) + +SYS_DIRS = ['images', '_Internal'] + + +def get_library_diff_files(repo_id, old_commit_id, new_commit_id): + if old_commit_id == new_commit_id: + return [], [], [], [], [] + + old_root = None + if old_commit_id: + try: + old_commit = commit_mgr.load_commit(repo_id, 0, old_commit_id) + old_root = old_commit.root_id + except GetObjectError as e: + logger.debug(e) + old_root = None + + try: + new_commit = commit_mgr.load_commit(repo_id, 0, new_commit_id) + except GetObjectError as e: + # new commit should exists in the obj store + logger.warning(e) + return [], [], [], [], [] + + new_root = new_commit.root_id + version = new_commit.get_version() + + try: + differ = CommitDiffer(repo_id, version, old_root, new_root) + added_files, deleted_files, added_dirs, deleted_dirs, modified_files = differ.diff(new_commit.ctime) + except Exception as e: + logger.warning('differ error: %s' % e) + return [], [], [], [], [] + + return added_files, deleted_files, modified_files, added_dirs, deleted_dirs + + +def init_logging(args): + level = args.loglevel + + if level == 'debug': + level = logging.DEBUG + elif level == 'info': + level = logging.INFO + elif level == 'warning': + level = logging.WARNING + else: + level = logging.INFO + + try: + # set boto3 log level + import boto3 + boto3.set_stream_logger(level=logging.WARNING) + except: + pass + + kw = { + 'format': '%(asctime)s [%(levelname)s] %(name)s:%(lineno)s %(funcName)s: %(message)s', + 'datefmt': '%m/%d/%Y %H:%M:%S', + 'level': level, + 'stream': args.logfile + } + + logging.basicConfig(**kw) + logging.getLogger('oss_util').setLevel(logging.WARNING) + logging.getLogger('urllib3').setLevel(logging.WARNING) + logging.getLogger("requests").setLevel(logging.WARNING) + + +def md5(text): + return hashlib.md5(text.encode()).hexdigest() + + +def is_sys_dir_or_file(path): + if path.split('/')[1] in SYS_DIRS: + return True + return False diff --git a/semantic_search/utils/commit_differ.py b/semantic_search/utils/commit_differ.py new file mode 100644 index 00000000..a8957ffe --- /dev/null +++ b/semantic_search/utils/commit_differ.py @@ -0,0 +1,105 @@ +# coding: UTF-8 +from seafevents.semantic_search.utils.constants import ZERO_OBJ_ID + +from seafobj import fs_mgr + + +class CommitDiffer(object): + def __init__(self, repo_id, version, root1, root2): + self.repo_id = repo_id + self.version = version + self.root1 = root1 + self.root2 = root2 + + def diff(self, root2_time): # noqa: C901 + added_files = [] + deleted_files = [] + deleted_dirs = [] + modified_files = [] + 
added_dirs = [] + + new_dirs = [] # (path, dir_id) + queued_dirs = [] # (path, dir_id1, dir_id2) + + if ZERO_OBJ_ID == self.root1: + self.root1 = None + if ZERO_OBJ_ID == self.root2: + self.root2 = None + + if self.root1 == self.root2: + return (added_files, deleted_files, added_dirs, deleted_dirs, + modified_files) + elif not self.root1: + new_dirs.append(('/', self.root2, root2_time, None)) + elif not self.root2: + deleted_dirs.append('/') + else: + queued_dirs.append(('/', self.root1, self.root2)) + + while True: + path = old_id = new_id = None + try: + path, old_id, new_id = queued_dirs.pop(0) + except IndexError: + break + + dir1 = fs_mgr.load_seafdir(self.repo_id, self.version, old_id) + dir2 = fs_mgr.load_seafdir(self.repo_id, self.version, new_id) + + for dent in dir1.get_files_list(): + new_dent = dir2.lookup_dent(dent.name) + if not new_dent or new_dent.type != dent.type: + deleted_files.append((make_path(path, dent.name), )) + else: + dir2.remove_entry(dent.name) + if new_dent.id == dent.id: + pass + else: + modified_files.append((make_path(path, dent.name), new_dent.id, new_dent.mtime, new_dent.size)) + + added_files.extend([(make_path(path, dent.name), dent.id, dent.mtime, dent.size) for dent in dir2.get_files_list()]) + + for dent in dir1.get_subdirs_list(): + new_dent = dir2.lookup_dent(dent.name) + if not new_dent or new_dent.type != dent.type: + deleted_dirs.append(make_path(path, dent.name)) + else: + dir2.remove_entry(dent.name) + if new_dent.id == dent.id: + pass + else: + queued_dirs.append((make_path(path, dent.name), dent.id, new_dent.id)) + + new_dirs.extend([(make_path(path, dent.name), dent.id, dent.mtime, dent.size) for dent in dir2.get_subdirs_list()]) + + while True: + # Process newly added dirs and its sub-dirs, all files under + # these dirs should be marked as added. 
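+            # (added comment) new_dirs acts as a FIFO queue: each directory popped
+            # below contributes its files to added_files and pushes its
+            # sub-directories back onto new_dirs, so the whole added subtree is
+            # walked breadth-first.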
+ path = obj_id = None + try: + path, obj_id, mtime, size = new_dirs.pop(0) + added_dirs.append((path, obj_id, mtime, size)) + except IndexError: + break + d = fs_mgr.load_seafdir(self.repo_id, self.version, obj_id) + added_files.extend([(make_path(path, dent.name), dent.id, dent.mtime, dent.size) for dent in d.get_files_list()]) + + new_dirs.extend([(make_path(path, dent.name), dent.id, dent.mtime, dent.size) for dent in d.get_subdirs_list()]) + + return (added_files, deleted_files, added_dirs, deleted_dirs, + modified_files) + + +def search_entry(entries, entryname): + for name, obj_id in entries: + if name == entryname: + entries.remove((name, obj_id)) + return (name, obj_id) + return (None, None) + + +def make_path(dirname, filename): + if dirname == '/': + return dirname + filename + else: + return '/'.join((dirname, filename)) diff --git a/semantic_search/utils/constants.py b/semantic_search/utils/constants.py new file mode 100644 index 00000000..272bf49c --- /dev/null +++ b/semantic_search/utils/constants.py @@ -0,0 +1,5 @@ +ZERO_OBJ_ID = '0000000000000000000000000000000000000000' + +REPO_STATUS_FILE_INDEX_NAME = 'repo_status_file' +REPO_STATUS_FILENAME_INDEX_NAME = 'repo_status_filename' +REPO_FILENAME_INDEX_PREFIX = 'filename_' diff --git a/semantic_search/utils/sea_embedding_api.py b/semantic_search/utils/sea_embedding_api.py new file mode 100644 index 00000000..002a5aa0 --- /dev/null +++ b/semantic_search/utils/sea_embedding_api.py @@ -0,0 +1,44 @@ +import logging +import requests +import jwt +import time +import json + +from seafevents.semantic_search.config import SEA_EMBEDDING_KEY + +logger = logging.getLogger(__name__) + + +def parse_response(response): + if response.status_code >= 400: + raise ConnectionError(response.status_code, response.text) + else: + try: + data = json.loads(response.text) + return data + except: + pass + + +class SeaEmbeddingAPI(object): + + def __init__(self, username, sea_embedding_url, time_out=180): + self.username = username + self.sea_embedding_url = sea_embedding_url.rstrip('/') + self.time_out = time_out + + def gen_headers(self): + payload = {'exp': int(time.time()) + 300, } + token = jwt.encode(payload, SEA_EMBEDDING_KEY, algorithm='HS256') + return {"Authorization": "Token %s" % token} + + def embeddings(self, input): + url = self.sea_embedding_url + '/api/v1/embeddings/' + params = { + 'input': input, + } + headers = self.gen_headers() + + response = requests.post(url, headers=headers, json=params, timeout=self.time_out) + data = parse_response(response) + return data diff --git a/semantic_search/utils/seasearch_api.py b/semantic_search/utils/seasearch_api.py new file mode 100644 index 00000000..103b6d02 --- /dev/null +++ b/semantic_search/utils/seasearch_api.py @@ -0,0 +1,136 @@ +import json +import logging +import requests +import ndjson + +logger = logging.getLogger(__name__) + + +def parse_response(response): + if response.status_code == 400: + logger.warning('seasearch error: %s', response.text) + if response.status_code > 400: + raise ConnectionError(response.status_code, response.text) + else: + try: + return json.loads(response.text) + except: + pass + + +class SeaSearchAPI(): + + def __init__(self, server, token, timeout=180): + self.token = token + self.server = server + self.timeout = timeout + self.gen_header() + + def gen_header(self): + self.headers = { + 'Authorization': 'Basic ' + self.token + } + + def create_index(self, index_name, data): + url = self.server + '/api/index/' + index_name + response = requests.put(url, 
headers=self.headers, json=data, timeout=self.timeout) + if response.status_code == 400: + raise Exception('create index: %s, error: %s' % (index_name, response.text)) + data = parse_response(response) + + return data + + def create_document_by_id(self, index_name, doc_id, date): + url = self.server + '/api/' + index_name + '/_doc/' + doc_id + response = requests.put(url, headers=self.headers, json=date, timeout=self.timeout) + if response.status_code == 400: + raise Exception('index: %s, add document: %s, error: %s' % (index_name, doc_id, response.text)) + data = parse_response(response) + + return data + + def bulk(self, index_name, data): + """ + this option includes add, update and delete index or document + """ + url = self.server + '/es/' + index_name + '/_bulk' + data = ndjson.dumps(data) + response = requests.post(url, headers=self.headers, data=data, timeout=self.timeout) + data = parse_response(response) + error = data.get('error') + if error: + raise Exception(error) + return data + + def vector_search(self, index_name, data): + url = self.server + '/api/' + index_name + '/_search/vector' + response = requests.post(url, headers=self.headers, json=data, timeout=self.timeout) + + return parse_response(response) + + def normal_search(self, index_name, data): + url = self.server + '/es/' + index_name + '/_search' + response = requests.post(url, headers=self.headers, json=data, timeout=self.timeout) + + return parse_response(response) + + def m_search(self, data, unify_score=True): + url = self.server + '/es/_msearch' + if unify_score: + url += '?unify_score=true' + data = ndjson.dumps(data) + response = requests.post(url, headers=self.headers, data=data, timeout=self.timeout) + return parse_response(response) + + def check_index_mapping(self, index_name): + url = self.server + '/es/' + index_name + '/_mapping' + + response = requests.get(url, headers=self.headers, timeout=self.timeout) + if response.status_code == 400: + return {'is_exist': False} + elif response.status_code > 400: + raise ConnectionError(response.status_code, response.text) + + return {'is_exist': True} + + def check_document_by_id(self, index_name, doc_id): + url = self.server + '/api/' + index_name + '/_doc/' + doc_id + response = requests.get(url, headers=self.headers, timeout=self.timeout) + if response.status_code == 400: + return {'is_exist': False} + elif response.status_code > 400: + raise ConnectionError(response.status_code, response.text) + + return {'is_exist': True} + + def get_document_by_id(self, index_name, doc_id): + url = self.server + '/api/' + index_name + '/_doc/' + doc_id + response = requests.get(url, headers=self.headers, timeout=self.timeout) + return parse_response(response) + + def delete_document_by_id(self, index_name, doc_id): + url = self.server + '/api/' + index_name + '/_doc/' + doc_id + response = requests.delete(url, headers=self.headers, timeout=self.timeout) + data = parse_response(response) + error = data.get('error') + if error: + logger.warning('delete_document_by_id error: %s', error) + return data + + def delete_index_by_name(self, index_name): + url = self.server + '/api/index/' + index_name + response = requests.delete(url, headers=self.headers, timeout=self.timeout) + if response.status_code == 400: + logger.warning('index: %s not exist error: %s' % (index_name, response.text)) + elif response.status_code > 400: + raise ConnectionError(response.status_code, response.text) + return json.loads(response.text) + + def update_document_by_id(self, index_name, doc_id, data): + 
url = self.server + '/api/' + index_name + '/_doc/' + doc_id + response = requests.put(url, headers=self.headers, json=data, timeout=self.timeout) + data = parse_response(response) + error = data.get('error') + if error: + raise Exception(error) + return data diff --git a/semantic_search/utils/text_splitter.py b/semantic_search/utils/text_splitter.py new file mode 100644 index 00000000..fb12ae17 --- /dev/null +++ b/semantic_search/utils/text_splitter.py @@ -0,0 +1,282 @@ +import re +import logging +from transformers import AutoTokenizer + +from seafevents.semantic_search.config import MODEL_VOCAB_PATH +logger = logging.getLogger(__name__) + + +tokenizer = AutoTokenizer.from_pretrained(MODEL_VOCAB_PATH) + + +class MarkdownHeaderTextSplitter: + """Splitting markdown files based on specified headers.""" + + def __init__( + self, headers_to_split_on + ): + """Create a new MarkdownHeaderTextSplitter. + + Args: + headers_to_split_on: Headers we want to track + """ + # Given the headers we want to split on, + # (e.g., "#, ##, etc") order by length + self.headers_to_split_on = sorted( + headers_to_split_on, key=lambda split: len(split[0]), reverse=True + ) + + def aggregate_lines_to_chunks(self, lines): + """Combine lines with common metadata into chunks + Args: + lines: Line of text / associated header metadata + """ + aggregated_chunks = [] + + for line in lines: + if ( + aggregated_chunks + and aggregated_chunks[-1]["header_name"] == line["header_name"] + ): + # If the last line in the aggregated list + # has the same metadata as the current line, + # append the current content to the last lines's content + aggregated_chunks[-1]["content"] += " \n" + line["content"] + else: + # Otherwise, append the current line to the aggregated list + aggregated_chunks.append(line) + + return aggregated_chunks + + def split_text(self, text): + """Split markdown file + Args: + text: Markdown file""" + + # Split the input text by newline character ("\n"). 
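+        # (added comment) Lines inside ``` code fences are skipped; a line matching one
+        # of headers_to_split_on flushes the chunk collected so far (tagged with the
+        # previous header) and becomes the header_name of the following chunk.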
+ lines = text.split("\n") + # Final output + lines_with_metadata = [] + # Content and metadata of the chunk currently being processed + current_content = [] + current_header_name = '' + header_name = '' + + in_code_block = False + headers = [] + for line in lines: + stripped_line = line.strip() + + if not stripped_line: + continue + + if stripped_line.startswith("```"): + in_code_block = not in_code_block + continue + + if in_code_block: + continue + + # Check each line against each of the header types (e.g., #, ##) + for sep in self.headers_to_split_on: + # Check if line starts with a header that we intend to split on + if stripped_line.startswith(sep) and ( + # Header with no text OR header is followed by space + # Both are valid conditions that sep is being used a header + len(stripped_line) == len(sep) + or stripped_line[len(sep)] == " " + ): + header_name = stripped_line[len(sep):].strip() + headers.append(header_name) + + # Add the previous line to the lines_with_metadata + # only if current_content is not empty + if current_content: + lines_with_metadata.append( + { + "content": "\n".join(current_content), + "header_name": current_header_name, + } + ) + current_content.clear() + break + else: + if stripped_line: + current_content.extend(headers) + current_content.append(stripped_line) + headers = [] + elif current_content: + lines_with_metadata.append( + { + "content": "\n".join(current_content), + "header_name": current_header_name, + } + ) + current_content.clear() + current_header_name = header_name + + # If it ends with a separate title, headers is not empty + if headers: + lines_with_metadata.append( + {"content": "\n".join(headers), "header_name": current_header_name} + ) + + if current_content: + lines_with_metadata.append( + {"content": "\n".join(current_content), "header_name": current_header_name} + ) + + # lines_with_metadata has each line with associated header name + # aggregate these into chunks based on common header name + return self.aggregate_lines_to_chunks(lines_with_metadata) + + +def _split_text_with_regex(text, separator, keep_separator): + # Now that we have the separator, split the text + if separator: + if keep_separator: + # The parentheses in the pattern keep the delimiters in the result. + _splits = re.split(f"({separator})", text) + splits = [_splits[i - 1] + _splits[i] for i in range(1, len(_splits), 2)] + splits = splits + _splits[-1:] + else: + splits = re.split(separator, text) + else: + splits = list(text) + return [s for s in splits if s != ""] + + +def tokenizer_length(text): + tokens = tokenizer.tokenize(text) + return len(tokens) + + +class RecursiveCharacterTextSplitter(object): + """Splitting text by recursively look at characters. + + Recursively tries to split by different characters to find one + that works. + """ + + def __init__( + self, + separators=None, + keep_separator: bool = True, + is_separator_regex: bool = True, + + chunk_size: int = 4000, + chunk_overlap: int = 200, + length_function=len, + strip_whitespace: bool = True + ) -> None: + """Create a new TextSplitter.""" + self._separators = separators or ["\n\n", "\n", " ", "。|!|?", "\.|\!|\?"] + self._is_separator_regex = is_separator_regex + + if chunk_overlap > chunk_size: + raise ValueError( + f"Got a larger chunk overlap ({chunk_overlap}) than chunk size " + f"({chunk_size}), should be smaller." 
+ ) + self._chunk_size = chunk_size + self._chunk_overlap = chunk_overlap + self._length_function = length_function + self._keep_separator = keep_separator + self._strip_whitespace = strip_whitespace + + def _split_text(self, text: str, separators): + """Split incoming text and return chunks.""" + final_chunks = [] + # Get appropriate separator to use + separator = separators[-1] + new_separators = [] + for i, _s in enumerate(separators): + _separator = _s if self._is_separator_regex else re.escape(_s) + if _s == "": + separator = _s + break + if re.search(_separator, text): + separator = _s + new_separators = separators[i + 1:] + break + + _separator = separator if self._is_separator_regex else re.escape(separator) + splits = _split_text_with_regex(text, _separator, self._keep_separator) + + # Now go merging things, recursively splitting longer texts. + _good_splits = [] + _separator = "" if self._keep_separator else separator + for s in splits: + if self._length_function(s) < self._chunk_size: + _good_splits.append(s) + else: + if _good_splits: + merged_text = self._merge_splits(_good_splits, _separator) + final_chunks.extend(merged_text) + _good_splits = [] + if not new_separators: + final_chunks.append(s) + else: + other_info = self._split_text(s, new_separators) + final_chunks.extend(other_info) + if _good_splits: + merged_text = self._merge_splits(_good_splits, _separator) + final_chunks.extend(merged_text) + return final_chunks + + def split_text(self, text): + return self._split_text(text, self._separators) + + def _join_docs(self, docs, separator): + if self._keep_separator: + text = separator.join(docs) + else: + text = ''.join(docs) + if self._strip_whitespace: + text = text.strip() + if text == "": + return None + else: + return text + + def _merge_splits(self, splits, separator): + # We now want to combine these smaller pieces into medium size + # chunks to send to the LLM. 
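+        # (added comment) Greedy packing with a sliding overlap: splits are appended to
+        # current_doc until the next one would exceed chunk_size; the packed chunk is
+        # emitted, then leading splits are dropped until at most chunk_overlap of length
+        # remains (and the new split fits), seeding the overlap for the next chunk.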
+ separator_len = self._length_function(separator) + + docs = [] + current_doc = [] + total = 0 + for d in splits: + _len = self._length_function(d) + if ( + total + _len + (separator_len if len(current_doc) > 0 else 0) + > self._chunk_size + ): + if total > self._chunk_size: + logger.warning( + f"Created a chunk of size {total}, " + f"which is longer than the specified {self._chunk_size}" + ) + if len(current_doc) > 0: + doc = self._join_docs(current_doc, separator) + if doc is not None: + docs.append(doc) + # Keep on popping if: + # - we have a larger chunk than in the chunk overlap + # - or if we still have any chunks and the length is long + while total > self._chunk_overlap or ( + total + _len + (separator_len if len(current_doc) > 0 else 0) + > self._chunk_size + and total > 0 + ): + total -= self._length_function(current_doc[0]) + ( + separator_len if len(current_doc) > 1 else 0 + ) + current_doc = current_doc[1:] + current_doc.append(d) + total += _len + (separator_len if len(current_doc) > 1 else 0) + doc = self._join_docs(current_doc, separator) + if doc is not None: + docs.append(doc) + return docs From 7a122fd3c5c438f1660a1039b8ddfb7518b91bcf Mon Sep 17 00:00:00 2001 From: cir9no <44470218+cir9no@users.noreply.github.com> Date: Wed, 17 Jul 2024 11:20:07 +0800 Subject: [PATCH 2/4] optimize conf --- semantic_search/config.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/semantic_search/config.py b/semantic_search/config.py index ad3a33d2..3156096a 100644 --- a/semantic_search/config.py +++ b/semantic_search/config.py @@ -1,8 +1,6 @@ import os import logging -from seafevents.app.config import get_config - logger = logging.getLogger(__name__) @@ -46,9 +44,6 @@ '.pdf', ] - -CONF_DIR = '/opt/seafile/conf/' - try: import seahub.settings as seahub_settings SEA_EMBEDDING_SERVER = getattr(seahub_settings, 'SEA_EMBEDDING_SERVER', '') From f3f8bc158dbe354935e52d402bfccf0f90bc1c3d Mon Sep 17 00:00:00 2001 From: cir9no <44470218+cir9no@users.noreply.github.com> Date: Wed, 17 Jul 2024 17:54:49 +0800 Subject: [PATCH 3/4] seg semantic search --- app/app.py | 6 +- seafevent_server/request_handler.py | 16 ++-- seafevent_server/seafevent_server.py | 8 +- semantic_search/config.py | 81 +++++-------------- semantic_search/db.py | 81 ------------------- semantic_search/index_store/index_manager.py | 5 +- semantic_search/index_store/models.py | 2 +- .../index_task/filename_index_updater.py | 18 +++-- .../index_task/index_task_manager.py | 43 +++++++--- .../script/repo_file_index_local.py | 2 +- semantic_search/semantic_search.py | 50 ------------ semantic_search/semantic_search_settings.py | 22 ----- semantic_search/utils/sea_embedding_api.py | 3 +- 13 files changed, 83 insertions(+), 254 deletions(-) delete mode 100644 semantic_search/db.py delete mode 100644 semantic_search/semantic_search.py delete mode 100644 semantic_search/semantic_search_settings.py diff --git a/app/app.py b/app/app.py index ec1b1898..da05a7cc 100644 --- a/app/app.py +++ b/app/app.py @@ -3,12 +3,12 @@ VirusScanner, Statistics, CountUserActivity, CountTrafficInfo, ContentScanner,\ WorkWinxinNoticeSender, FileUpdatesSender, RepoOldFileAutoDelScanner,\ DeletedFilesCountCleaner -from seafevents.semantic_search.semantic_search import SemanticSearch from seafevents.repo_metadata.index_master import RepoMetadataIndexMaster from seafevents.repo_metadata.index_worker import RepoMetadataIndexWorker from seafevents.seafevent_server.seafevent_server import SeafEventServer from seafevents.app.config import 
ENABLE_METADATA_MANAGEMENT, ENABLE_SEAFILE_AI +from seafevents.semantic_search.index_task.filename_index_updater import repo_filename_index_updater class App(object): @@ -40,7 +40,7 @@ def __init__(self, config, ccnet_config, seafile_config, self._index_master = RepoMetadataIndexMaster(config) self._index_worker = RepoMetadataIndexWorker(config) if ENABLE_SEAFILE_AI: - self._sem_app = SemanticSearch() + repo_filename_index_updater.init() def serve_forever(self): if self._fg_tasks_enabled: @@ -63,3 +63,5 @@ def serve_forever(self): if ENABLE_METADATA_MANAGEMENT: self._index_master.start() self._index_worker.start() + if ENABLE_SEAFILE_AI: + repo_filename_index_updater.start() diff --git a/seafevent_server/request_handler.py b/seafevent_server/request_handler.py index 15a81df7..e101c84f 100644 --- a/seafevent_server/request_handler.py +++ b/seafevent_server/request_handler.py @@ -8,7 +8,7 @@ from seafevents.seafevent_server.export_task_manager import event_export_task_manager from seafevents.semantic_search.index_task.index_task_manager import index_task_manager -from seafevents.semantic_search.semantic_search import sem_app +from seafevents.repo_data import repo_data app = Flask(__name__) logger = logging.getLogger(__name__) @@ -117,13 +117,13 @@ def library_sdoc_indexes(): if not repo_id: return {'error_msg': 'repo_id invalid.'}, 400 - commit_id = sem_app.repo_data.get_repo_head_commit(repo_id) + commit_id = repo_data.get_repo_head_commit(repo_id) if not commit_id: return {'error_msg': 'repo invalid.'}, 400 try: - is_exist = sem_app.repo_file_index.check_index(repo_id) + is_exist = index_task_manager.repo_file_index.check_index(repo_id) except Exception as e: logger.exception(e) return {'error_msg': 'Internet server error.'}, 500 @@ -137,7 +137,7 @@ def library_sdoc_indexes(): return {'task_id': task.id}, 200 try: - sem_app.index_manager.create_index_repo_db(repo_id) + index_task_manager.index_manager.create_index_repo_db(repo_id) except Exception as e: logger.exception(e) return {'error_msg': 'Internet server error.'}, 500 @@ -201,7 +201,7 @@ def library_sdoc_index(): return {'error_msg': 'repo_id invalid'}, 400 try: - index_repo = sem_app.index_manager.get_index_repo_by_repo_id(repo_id) + index_repo = index_task_manager.index_manager.get_index_repo_by_repo_id(repo_id) except Exception as e: logger.exception(e) return {'error_msg': 'Internet server error.'}, 500 @@ -216,7 +216,7 @@ def library_sdoc_index(): return {'error_msg': 'library sdoc index is running'}, 400 try: - sem_app.index_manager.delete_library_sdoc_index_by_repo_id(repo_id, sem_app.repo_file_index, sem_app.repo_status_index) + index_task_manager.index_manager.delete_library_sdoc_index_by_repo_id(repo_id, index_task_manager.repo_file_index, index_task_manager.repo_status_index) except Exception as e: logger.exception(e) return {'error_msg': 'Internet server error.'}, 500 @@ -224,7 +224,7 @@ def library_sdoc_index(): return {'success': True}, 200 elif request.method == 'PUT': - commit_id = sem_app.repo_data.get_repo_head_commit(repo_id) + commit_id = repo_data.get_repo_head_commit(repo_id) if not commit_id: return {'error_msg': 'repo invalid.'}, 400 @@ -272,7 +272,7 @@ def query_library_index_state(): return {'error_msg': 'repo_id invalid'}, 400 try: - is_exist = sem_app.index_manager.get_index_repo_by_repo_id(repo_id) + is_exist = index_task_manager.index_manager.get_index_repo_by_repo_id(repo_id) except Exception as e: logger.exception(e) return {'error_msg': 'Internet server error.'}, 500 diff --git 
a/seafevent_server/seafevent_server.py b/seafevent_server/seafevent_server.py index d176a3f7..0f197e1b 100644 --- a/seafevent_server/seafevent_server.py +++ b/seafevent_server/seafevent_server.py @@ -4,9 +4,7 @@ from seafevents.seafevent_server.request_handler import app as application from seafevents.seafevent_server.task_manager import task_manager from seafevents.seafevent_server.export_task_manager import event_export_task_manager -from seafevents.semantic_search.semantic_search import sem_app from seafevents.semantic_search.index_task.index_task_manager import index_task_manager -from seafevents.semantic_search.index_task.filename_index_updater import repo_filename_index_updater from seafevents.app.config import ENABLE_SEAFILE_AI @@ -24,12 +22,8 @@ def __init__(self, app, config): if ENABLE_SEAFILE_AI: # semantic search index task - sem_app.init() - index_task_manager.init(sem_app) - repo_filename_index_updater.init(sem_app) - + index_task_manager.init() index_task_manager.start() - repo_filename_index_updater.start() self._server = WSGIServer((self._host, int(self._port)), application) diff --git a/semantic_search/config.py b/semantic_search/config.py index 3156096a..024f4eb4 100644 --- a/semantic_search/config.py +++ b/semantic_search/config.py @@ -1,66 +1,27 @@ import os import logging +from seafevents.app.config import get_config logger = logging.getLogger(__name__) - -APP_NAME = 'semantic-search' - -# sections -## indexManager worker count -INDEX_MANAGER_WORKERS = 2 -INDEX_TASK_EXPIRE_TIME = 30 * 60 - -RETRIEVAL_NUM = 20 - -# embedding dimension -DIMENSION = 768 - -MODEL_VOCAB_PATH = '' -FILE_SENTENCE_LIMIT = 1000 - -THRESHOLD = 0.01 - -## seasearch -SEASEARCH_SERVER = 'http://127.0.0.1:4080' -SEASEARCH_TOKEN = '' -VECTOR_M = 256 -SHARD_NUM = 1 - -## sea-embedding -SEA_EMBEDDING_SERVER = '' -SEA_EMBEDDING_KEY = '' - - -# repo file index support file types -SUPPORT_INDEX_FILE_TYPES = [ - '.sdoc', - '.md', - '.markdown', - '.doc', - '.docx', - '.ppt', - '.pptx', - '.pdf', -] - try: - import seahub.settings as seahub_settings - SEA_EMBEDDING_SERVER = getattr(seahub_settings, 'SEA_EMBEDDING_SERVER', '') - SEA_EMBEDDING_KEY = getattr(seahub_settings, 'SEA_EMBEDDING_KEY', '') - SEASEARCH_SERVER = getattr(seahub_settings, 'SEASEARCH_SERVER', '') - SEASEARCH_TOKEN = getattr(seahub_settings, 'SEASEARCH_TOKEN', '') - MODEL_VOCAB_PATH = getattr(seahub_settings, 'MODEL_VOCAB_PATH', '') - MODEL_CACHE_DIR = getattr(seahub_settings, 'MODEL_CACHE_DIR', '') - INDEX_STORAGE_PATH = getattr(seahub_settings, 'INDEX_STORAGE_PATH', '') -except ImportError: - logger.critical("Can not import seahub settings.") - raise RuntimeError("Can not import seahub settings.") - - -try: - - if os.path.exists('/data/dev/seafevents/semantic_search/semantic_search_settings.py'): - from seafevents.semantic_search.semantic_search_settings import * -except: - pass + evtconf = os.environ['EVENTS_CONFIG_FILE'] + conf = get_config(evtconf) + sem_conf = conf['SEMANTIC_SEARCH'] + + INDEX_MANAGER_WORKERS = int(sem_conf['index_manager_workers']) + INDEX_TASK_EXPIRE_TIME = int(sem_conf['index_task_expire_time']) + RETRIEVAL_NUM = int(sem_conf['retrieval_num']) + DIMENSION = int(sem_conf['embedding_dimension']) + MODEL_VOCAB_PATH = sem_conf['embedding_model_vocab_path'] + FILE_SENTENCE_LIMIT = int(sem_conf['embedding_file_sentence_limit']) + THRESHOLD = float(sem_conf['threshold']) + SEASEARCH_SERVER = sem_conf['seasearch_server'] + SEASEARCH_TOKEN = sem_conf['seasearch_token'] + VECTOR_M = int(sem_conf['seasearch_vector_m']) + 
SHARD_NUM = int(sem_conf['seasearch_shard_num']) + SUPPORT_INDEX_FILE_TYPES = sem_conf['suppport_index_file_types'].split(', ') + SEA_EMBEDDING_SERVER = sem_conf['sea_embedding_server'] + SEA_EMBEDDING_KEY = sem_conf['sea_embedding_key'] +except Exception as e: + logger.warning(e) diff --git a/semantic_search/db.py b/semantic_search/db.py deleted file mode 100644 index 01b39f6f..00000000 --- a/semantic_search/db.py +++ /dev/null @@ -1,81 +0,0 @@ -import logging -import configparser - -from urllib.parse import quote_plus - -from sqlalchemy import create_engine -from sqlalchemy.event import contains as has_event_listener, listen as add_event_listener -from sqlalchemy.exc import DisconnectionError -from sqlalchemy.orm import DeclarativeBase -from sqlalchemy.orm import sessionmaker -from sqlalchemy.pool import Pool - - -# base class of model classes in events.models and stats.models -class Base(DeclarativeBase): - pass - - -logger = logging.getLogger('seafevents') - - -def create_engine_from_conf(config_file): - seaf_conf = configparser.ConfigParser() - seaf_conf.read(config_file) - backend = seaf_conf.get('DATABASE', 'type') - - if backend == 'mysql': - db_server = 'localhost' - db_port = 3306 - - if seaf_conf.has_option('DATABASE', 'host'): - db_server = seaf_conf.get('DATABASE', 'host') - if seaf_conf.has_option('DATABASE', 'port'): - db_port = seaf_conf.getint('DATABASE', 'port') - db_username = seaf_conf.get('DATABASE', 'username') - db_passwd = seaf_conf.get('DATABASE', 'password') - db_name = seaf_conf.get('DATABASE', 'name') - db_url = "mysql+pymysql://%s:%s@%s:%s/%s?charset=utf8" % \ - (db_username, quote_plus(db_passwd), - db_server, db_port, db_name) - else: - logger.critical("Unknown Database backend: %s" % backend) - raise RuntimeError("Unknown Database backend: %s" % backend) - - kwargs = dict(pool_recycle=300, echo=False, echo_pool=False) - - engine = create_engine(db_url, **kwargs) - if not has_event_listener(Pool, 'checkout', ping_connection): - # We use has_event_listener to double check in case we call create_engine - # multipe times in the same process. - add_event_listener(Pool, 'checkout', ping_connection) - - return engine - -def init_db_session_class(config_file): - """Configure Session class for mysql according to the config file.""" - try: - engine = create_engine_from_conf(config_file) - except (configparser.NoOptionError, configparser.NoSectionError) as e: - logger.error(e) - raise RuntimeError("invalid config file %s", config_file) - - Session = sessionmaker(bind=engine) - return Session - -# This is used to fix the problem of "MySQL has gone away" that happens when -# mysql server is restarted or the pooled connections are closed by the mysql -# server beacause being idle for too long. 
-# -# See http://stackoverflow.com/a/17791117/1467959 -def ping_connection(dbapi_connection, connection_record, connection_proxy): # pylint: disable=unused-argument - cursor = dbapi_connection.cursor() - try: - cursor.execute("SELECT 1") - cursor.close() - except: - logger.info('fail to ping database server, disposing all cached connections') - connection_proxy._pool.dispose() # pylint: disable=protected-access - - # Raise DisconnectionError so the pool would create a new connection - raise DisconnectionError() diff --git a/semantic_search/index_store/index_manager.py b/semantic_search/index_store/index_manager.py index f9350459..ad34a9ec 100644 --- a/semantic_search/index_store/index_manager.py +++ b/semantic_search/index_store/index_manager.py @@ -5,9 +5,10 @@ from sqlalchemy.sql import text +from seafevents.app.config import get_config +from seafevents.db import init_db_session_class from seafevents.semantic_search import config from seafevents.semantic_search.utils.constants import ZERO_OBJ_ID, REPO_FILENAME_INDEX_PREFIX -from seafevents.semantic_search.db import init_db_session_class from seafevents.semantic_search.index_store.models import IndexRepo from seafevents.semantic_search.index_store.utils import rank_fusion, filter_hybrid_searched_files @@ -17,7 +18,7 @@ class IndexManager(): def __init__(self): self.evtconf = os.environ['EVENTS_CONFIG_FILE'] - self._db_session_class = init_db_session_class(self.evtconf) + self._db_session_class = init_db_session_class(get_config(self.evtconf)) def create_index_repo_db(self, repo_id): with self._db_session_class() as db_session: diff --git a/semantic_search/index_store/models.py b/semantic_search/index_store/models.py index 4142e59f..afdddddb 100644 --- a/semantic_search/index_store/models.py +++ b/semantic_search/index_store/models.py @@ -1,5 +1,5 @@ from sqlalchemy import Column, Integer, String, DateTime -from seafevents.semantic_search.db import Base +from seafevents.db import Base class IndexRepo(Base): diff --git a/semantic_search/index_task/filename_index_updater.py b/semantic_search/index_task/filename_index_updater.py index 8cb180b0..8eab9eb4 100644 --- a/semantic_search/index_task/filename_index_updater.py +++ b/semantic_search/index_task/filename_index_updater.py @@ -4,6 +4,13 @@ from apscheduler.triggers.cron import CronTrigger from apscheduler.schedulers.gevent import GeventScheduler +from seafevents.semantic_search.index_store.index_manager import IndexManager +from seafevents.semantic_search.index_store.repo_status_index import RepoStatusIndex +from seafevents.semantic_search.index_store.repo_file_name_index import RepoFileNameIndex +from seafevents.semantic_search.utils.constants import REPO_STATUS_FILENAME_INDEX_NAME +from seafevents.semantic_search.utils.seasearch_api import SeaSearchAPI +from seafevents.repo_data import repo_data +from seafevents.semantic_search import config logger = logging.getLogger(__name__) @@ -15,11 +22,12 @@ def __init__(self): self._index_manager = None self._repo_data = None - def init(self, app): - self._repo_status_filename_index = app.repo_status_filename_index - self._repo_filename_index = app.repo_filename_index - self._index_manager = app.index_manager - self._repo_data = app.repo_data + def init(self): + self.seasearch_api = SeaSearchAPI(config.SEASEARCH_SERVER, config.SEASEARCH_TOKEN) + self._repo_status_filename_index = RepoStatusIndex(self.seasearch_api, REPO_STATUS_FILENAME_INDEX_NAME) + self._repo_filename_index = RepoFileNameIndex(self.seasearch_api, repo_data) + self._index_manager 
= IndexManager() + self._repo_data = repo_data def start(self): RepoFilenameIndexUpdaterTimer( diff --git a/semantic_search/index_task/index_task_manager.py b/semantic_search/index_task/index_task_manager.py index 8c15df95..cfed20be 100644 --- a/semantic_search/index_task/index_task_manager.py +++ b/semantic_search/index_task/index_task_manager.py @@ -7,9 +7,16 @@ from apscheduler.triggers.cron import CronTrigger from apscheduler.schedulers.gevent import GeventScheduler +from seafevents.semantic_search.index_store.index_manager import IndexManager +from seafevents.semantic_search.index_store.repo_status_index import RepoStatusIndex +from seafevents.semantic_search.index_store.repo_file_index import RepoFileIndex +from seafevents.semantic_search.index_store.repo_file_name_index import RepoFileNameIndex +from seafevents.semantic_search.utils.seasearch_api import SeaSearchAPI +from seafevents.semantic_search.utils.sea_embedding_api import SeaEmbeddingAPI +from seafevents.semantic_search.utils.constants import REPO_STATUS_FILE_INDEX_NAME +from seafevents.repo_data import repo_data from seafevents.semantic_search import config - logger = logging.getLogger(__name__) @@ -78,9 +85,19 @@ def __init__(self): } self.sched.add_job(self.clear_expired_tasks, CronTrigger(minute='*/10')) self.sched.add_job(self.cron_update_library_sdoc_indexes, CronTrigger(hour='*')) - - def init(self, app): - self.app = app + self.index_manager = None + self.repo_file_index = None + + def init(self): + self.index_manager = IndexManager() + self.seasearch_api = SeaSearchAPI(config.SEASEARCH_SERVER, config.SEASEARCH_TOKEN) + self.repo_data = repo_data + self.embedding_api = SeaEmbeddingAPI(config.SEA_EMBEDDING_SERVER) + # for semantic search + self.repo_status_index = RepoStatusIndex(self.seasearch_api, REPO_STATUS_FILE_INDEX_NAME) + self.repo_file_index = RepoFileIndex(self.seasearch_api) + # for keyword search + self.repo_filename_index = RepoFileNameIndex(self.seasearch_api, self.repo_data) def get_pending_or_running_task(self, readable_id): task = self.readable_id2task_map.get(readable_id) @@ -94,8 +111,8 @@ def add_library_sdoc_index_task(self, repo_id, commit_id): return task.id task_id = str(uuid.uuid4()) - task = IndexTask(task_id, readable_id, self.app.index_manager.create_library_sdoc_index, - (repo_id, self.app.embedding_api, self.app.repo_file_index, self.app.repo_status_index, commit_id) + task = IndexTask(task_id, readable_id, self.index_manager.create_library_sdoc_index, + (repo_id, self.embedding_api, self.repo_file_index, self.repo_status_index, commit_id) ) self.tasks_map[task_id] = task @@ -105,11 +122,11 @@ def add_library_sdoc_index_task(self, repo_id, commit_id): return task_id def keyword_search(self, query, repos, count, suffixes): - return self.app.index_manager.keyword_search(query, repos, self.app.repo_filename_index, count, suffixes) + return self.index_manager.keyword_search(query, repos, self.repo_filename_index, count, suffixes) def hybrid_search(self, query, repo, count): - return self.app.index_manager.hybrid_search(query, repo, self.app.repo_filename_index, - self.app.embedding_api, self.app.repo_file_index, count) + return self.index_manager.hybrid_search(query, repo, self.repo_filename_index, + self.embedding_api, self.repo_file_index, count) def add_update_a_library_sdoc_index_task(self, repo_id, commit_id): readable_id = repo_id @@ -119,8 +136,8 @@ def add_update_a_library_sdoc_index_task(self, repo_id, commit_id): return task.id task_id = str(uuid.uuid4()) - task = IndexTask(task_id, 
readable_id, self.app.index_manager.update_library_sdoc_index, - (repo_id, self.app.embedding_api, self.app.repo_file_index, self.app.repo_status_index, + task = IndexTask(task_id, readable_id, self.index_manager.update_library_sdoc_index, + (repo_id, self.embedding_api, self.repo_file_index, self.repo_status_index, commit_id) ) self.tasks_map[task_id] = task @@ -130,10 +147,10 @@ def add_update_a_library_sdoc_index_task(self, repo_id, commit_id): return task_id def update_library_sdoc_indexes(self): - index_repos = self.app.index_manager.list_index_repos() + index_repos = self.index_manager.list_index_repos() for repo in index_repos: repo_id = repo[0] - commit_id = self.app.repo_data.get_repo_head_commit(repo_id) + commit_id = self.repo_data.get_repo_head_commit(repo_id) self.add_update_a_library_sdoc_index_task(repo_id, commit_id) def cron_update_library_sdoc_indexes(self): diff --git a/semantic_search/script/repo_file_index_local.py b/semantic_search/script/repo_file_index_local.py index 16f4a08c..421eef59 100644 --- a/semantic_search/script/repo_file_index_local.py +++ b/semantic_search/script/repo_file_index_local.py @@ -151,7 +151,7 @@ def start_index_local(): repo_status_index = RepoStatusIndex(seasearch_api, REPO_STATUS_FILE_INDEX_NAME) repo_file_index = RepoFileIndex(seasearch_api) - embedding_api = SeaEmbeddingAPI(config.APP_NAME, config.SEA_EMBEDDING_SERVER) + embedding_api = SeaEmbeddingAPI(config.SEA_EMBEDDING_SERVER) workers = config.INDEX_MANAGER_WORKERS try: diff --git a/semantic_search/semantic_search.py b/semantic_search/semantic_search.py deleted file mode 100644 index d3dadbd7..00000000 --- a/semantic_search/semantic_search.py +++ /dev/null @@ -1,50 +0,0 @@ -import logging - -from seafevents.semantic_search.index_task.index_task_manager import index_task_manager -from seafevents.semantic_search.index_task.filename_index_updater import repo_filename_index_updater -from seafevents.semantic_search.index_store.index_manager import IndexManager -from seafevents.semantic_search.index_store.repo_status_index import RepoStatusIndex -from seafevents.semantic_search.index_store.repo_file_index import RepoFileIndex -from seafevents.semantic_search.index_store.repo_file_name_index import RepoFileNameIndex -from seafevents.semantic_search.utils.seasearch_api import SeaSearchAPI -from seafevents.semantic_search.utils.sea_embedding_api import SeaEmbeddingAPI -from seafevents.semantic_search.utils.constants import REPO_STATUS_FILE_INDEX_NAME, REPO_STATUS_FILENAME_INDEX_NAME -from seafevents.repo_data import repo_data -from seafevents.semantic_search import config - -logger = logging.getLogger(__name__) - -class SemanticSearch(): - def __init__(self): - self.index_manager = None - self.seasearch_api = None - self.repo_data = None - self.embedding_api = None - - # for semantic search - self.repo_status_index = None - self.repo_file_index = None - - # for keyword search - self.repo_status_filename_index = None - self.repo_filename_index = None - self.index_task_manager = None - self.repo_filename_index_updater = None - - def init(self): - self.index_manager = IndexManager() - self.seasearch_api = SeaSearchAPI(config.SEASEARCH_SERVER, config.SEASEARCH_TOKEN) - self.repo_data = repo_data - self.embedding_api = SeaEmbeddingAPI(config.APP_NAME, config.SEA_EMBEDDING_SERVER) - - # for semantic search - self.repo_status_index = RepoStatusIndex(self.seasearch_api, REPO_STATUS_FILE_INDEX_NAME) - self.repo_file_index = RepoFileIndex(self.seasearch_api) - - # for keyword search - 
self.repo_status_filename_index = RepoStatusIndex(self.seasearch_api, REPO_STATUS_FILENAME_INDEX_NAME) - self.repo_filename_index = RepoFileNameIndex(self.seasearch_api, self.repo_data) - self.index_task_manager = index_task_manager - self.repo_filename_index_updater = repo_filename_index_updater - -sem_app = SemanticSearch() diff --git a/semantic_search/semantic_search_settings.py b/semantic_search/semantic_search_settings.py deleted file mode 100644 index 9e514029..00000000 --- a/semantic_search/semantic_search_settings.py +++ /dev/null @@ -1,22 +0,0 @@ - - -# ENABLE_SYS_LOG = True - -INDEX_STORAGE_PATH = '/data/dev/static/index/' - -MODEL_CACHE_DIR = '/data/dev/static/' - -RETRIEVAL_NUM = 50 - -MODEL_VOCAB_PATH = '/data/dev/static/damo/nlp_corom_sentence-embedding_chinese-base/' - - - -## sea-embedding -SEA_EMBEDDING_SERVER = 'http://host.docker.internal:8889' -SEA_EMBEDDING_KEY = '123' - - -#seasech -SEASEARCH_SERVER = 'http://host.docker.internal:4080' -SEASEARCH_TOKEN = 'YWRtaW46Q29tcGxleHBhc3MjMTIz' diff --git a/semantic_search/utils/sea_embedding_api.py b/semantic_search/utils/sea_embedding_api.py index 002a5aa0..4bbc98a7 100644 --- a/semantic_search/utils/sea_embedding_api.py +++ b/semantic_search/utils/sea_embedding_api.py @@ -22,8 +22,7 @@ def parse_response(response): class SeaEmbeddingAPI(object): - def __init__(self, username, sea_embedding_url, time_out=180): - self.username = username + def __init__(self, sea_embedding_url, time_out=180): self.sea_embedding_url = sea_embedding_url.rstrip('/') self.time_out = time_out From 715e09e3fa60e80dd2167b16fdbb493fffc014a9 Mon Sep 17 00:00:00 2001 From: cir9no <44470218+cir9no@users.noreply.github.com> Date: Thu, 18 Jul 2024 16:18:10 +0800 Subject: [PATCH 4/4] stash integrate seafile-ai search --- app/app.py | 6 +- seafevent_server/seafevent_server.py | 2 +- semantic_search/config.py | 13 +- semantic_search/index_store/index_manager.py | 6 +- .../index_store/repo_file_index.py | 51 ++-- .../index_store/repo_file_name_index.py | 5 +- .../index_store/repo_status_index.py | 4 - .../index_task/filename_index_updater.py | 23 +- .../index_task/index_task_manager.py | 55 +++- .../script/repo_file_index_local.py | 275 ------------------ .../script/repo_filename_index_local.py | 271 ----------------- semantic_search/utils/constants.py | 1 + semantic_search/utils/sea_embedding_api.py | 6 +- semantic_search/utils/text_splitter.py | 2 +- 14 files changed, 93 insertions(+), 627 deletions(-) delete mode 100644 semantic_search/script/repo_file_index_local.py delete mode 100644 semantic_search/script/repo_filename_index_local.py diff --git a/app/app.py b/app/app.py index da05a7cc..9a51d694 100644 --- a/app/app.py +++ b/app/app.py @@ -8,7 +8,7 @@ from seafevents.repo_metadata.index_worker import RepoMetadataIndexWorker from seafevents.seafevent_server.seafevent_server import SeafEventServer from seafevents.app.config import ENABLE_METADATA_MANAGEMENT, ENABLE_SEAFILE_AI -from seafevents.semantic_search.index_task.filename_index_updater import repo_filename_index_updater +from seafevents.semantic_search.index_task.filename_index_updater import RepoFilenameIndexUpdater class App(object): @@ -40,7 +40,7 @@ def __init__(self, config, ccnet_config, seafile_config, self._index_master = RepoMetadataIndexMaster(config) self._index_worker = RepoMetadataIndexWorker(config) if ENABLE_SEAFILE_AI: - repo_filename_index_updater.init() + self._repo_filename_index_updater = RepoFilenameIndexUpdater(config) def serve_forever(self): if self._fg_tasks_enabled: @@ -64,4 
+64,4 @@ def serve_forever(self): self._index_master.start() self._index_worker.start() if ENABLE_SEAFILE_AI: - repo_filename_index_updater.start() + self._repo_filename_index_updater.start() diff --git a/seafevent_server/seafevent_server.py b/seafevent_server/seafevent_server.py index 0f197e1b..8e1babb8 100644 --- a/seafevent_server/seafevent_server.py +++ b/seafevent_server/seafevent_server.py @@ -22,7 +22,7 @@ def __init__(self, app, config): if ENABLE_SEAFILE_AI: # semantic search index task - index_task_manager.init() + index_task_manager.init(config) index_task_manager.start() self._server = WSGIServer((self._host, int(self._port)), application) diff --git a/semantic_search/config.py b/semantic_search/config.py index 024f4eb4..6325f80d 100644 --- a/semantic_search/config.py +++ b/semantic_search/config.py @@ -9,19 +9,8 @@ conf = get_config(evtconf) sem_conf = conf['SEMANTIC_SEARCH'] - INDEX_MANAGER_WORKERS = int(sem_conf['index_manager_workers']) - INDEX_TASK_EXPIRE_TIME = int(sem_conf['index_task_expire_time']) - RETRIEVAL_NUM = int(sem_conf['retrieval_num']) - DIMENSION = int(sem_conf['embedding_dimension']) MODEL_VOCAB_PATH = sem_conf['embedding_model_vocab_path'] - FILE_SENTENCE_LIMIT = int(sem_conf['embedding_file_sentence_limit']) - THRESHOLD = float(sem_conf['threshold']) - SEASEARCH_SERVER = sem_conf['seasearch_server'] - SEASEARCH_TOKEN = sem_conf['seasearch_token'] - VECTOR_M = int(sem_conf['seasearch_vector_m']) - SHARD_NUM = int(sem_conf['seasearch_shard_num']) SUPPORT_INDEX_FILE_TYPES = sem_conf['suppport_index_file_types'].split(', ') - SEA_EMBEDDING_SERVER = sem_conf['sea_embedding_server'] - SEA_EMBEDDING_KEY = sem_conf['sea_embedding_key'] + except Exception as e: logger.warning(e) diff --git a/semantic_search/index_store/index_manager.py b/semantic_search/index_store/index_manager.py index ad34a9ec..e4b5276a 100644 --- a/semantic_search/index_store/index_manager.py +++ b/semantic_search/index_store/index_manager.py @@ -7,7 +7,6 @@ from seafevents.app.config import get_config from seafevents.db import init_db_session_class -from seafevents.semantic_search import config from seafevents.semantic_search.utils.constants import ZERO_OBJ_ID, REPO_FILENAME_INDEX_PREFIX from seafevents.semantic_search.index_store.models import IndexRepo from seafevents.semantic_search.index_store.utils import rank_fusion, filter_hybrid_searched_files @@ -16,9 +15,10 @@ class IndexManager(): - def __init__(self): + def __init__(self, retrieval_num): self.evtconf = os.environ['EVENTS_CONFIG_FILE'] self._db_session_class = init_db_session_class(get_config(self.evtconf)) + self.retrieval_num = retrieval_num def create_index_repo_db(self, repo_id): with self._db_session_class() as db_session: @@ -73,7 +73,7 @@ def create_library_sdoc_index(self, repo_id, embedding_api, repo_file_index, rep logger.info('library: %s, save library file to SeaSearch success', repo_id) def search_children_in_library(self, query, repo, embedding_api, repo_file_index, count=20): - return repo_file_index.search_files(repo, config.RETRIEVAL_NUM, embedding_api, query)[:count] + return repo_file_index.search_files(repo, self.retrieval_num, embedding_api, query)[:count] def update_library_sdoc_index(self, repo_id, embedding_api, repo_file_index, repo_status_index, new_commit_id): try: diff --git a/semantic_search/index_store/repo_file_index.py b/semantic_search/index_store/repo_file_index.py index 30d15541..78d176bb 100644 --- a/semantic_search/index_store/repo_file_index.py +++ 
b/semantic_search/index_store/repo_file_index.py @@ -1,7 +1,6 @@ import os import logging -from seafevents.semantic_search import config from seafevents.semantic_search.index_store.utils import parse_file_to_sentences, bulk_add_sentences_to_index from seafevents.semantic_search.utils import get_library_diff_files, is_sys_dir_or_file @@ -13,31 +12,33 @@ class RepoFileIndex(object): - """ - index name is repo id - """ - mapping = { - "properties": { - "vec": { - "type": "vector", - "dims": config.DIMENSION, - "vec_index_type": "ivf_pq", - "nbits": 4, - "m": config.VECTOR_M - }, - "path": { - "type": "keyword" - }, - 'content': { - 'type': 'text' + + def __init__(self, seasearch_api, dimension, vector_m, shard_num, threshold, file_sentence_limit): + self.seasearch_api = seasearch_api + """ + index name is repo id + """ + self.mapping = { + "properties": { + "vec": { + "type": "vector", + "dims": dimension, + "vec_index_type": "ivf_pq", + "nbits": 4, + "m": vector_m + }, + "path": { + "type": "keyword" + }, + 'content': { + 'type': 'text' + } } } - } - shard_num = config.SHARD_NUM - - def __init__(self, seasearch_api): - self.seasearch_api = seasearch_api + self.shard_num = shard_num + self.threshold = threshold + self.file_sentence_limit = file_sentence_limit def create_index(self, index_name): data = { @@ -84,7 +85,7 @@ def search_files(self, repo, k, embedding_api, query): if origin_path and not path.startswith(origin_path): continue - if score < config.THRESHOLD: + if score < self.threshold: continue if searched_result.get(path): @@ -211,6 +212,6 @@ def add_files(self, index_name, files, embedding_api, commit_id): def add_file(self, index_name, file_info, commit_id, embedding_api, path): sentences = parse_file_to_sentences(index_name, file_info, commit_id) - sentences = sentences[0: config.FILE_SENTENCE_LIMIT] + sentences = sentences[0: self.file_sentence_limit] limit = int(SEASEARCH_BULK_OPETATE_LIMIT / 2) bulk_add_sentences_to_index(self.seasearch_api, embedding_api, index_name, path, sentences, limit) diff --git a/semantic_search/index_store/repo_file_name_index.py b/semantic_search/index_store/repo_file_name_index.py index cea0e4a7..42482f5c 100644 --- a/semantic_search/index_store/repo_file_name_index.py +++ b/semantic_search/index_store/repo_file_name_index.py @@ -3,7 +3,6 @@ import logging from seafevents.semantic_search.utils import get_library_diff_files, md5, is_sys_dir_or_file -from seafevents.semantic_search import config from seafevents.semantic_search.utils.constants import REPO_FILENAME_INDEX_PREFIX logger = logging.getLogger(__name__) @@ -60,11 +59,11 @@ class RepoFileNameIndex(object): } } - shard_num = config.SHARD_NUM - def __init__(self, seasearch_api, repo_data): + def __init__(self, seasearch_api, repo_data, shard_num): self.seasearch_api = seasearch_api self.repo_data = repo_data + self.shard_num = shard_num def create_index_if_missing(self, index_name): if not self.seasearch_api.check_index_mapping(index_name).get('is_exist'): diff --git a/semantic_search/index_store/repo_status_index.py b/semantic_search/index_store/repo_status_index.py index adfce0d8..8985996f 100644 --- a/semantic_search/index_store/repo_status_index.py +++ b/semantic_search/index_store/repo_status_index.py @@ -1,6 +1,3 @@ -from seafevents.semantic_search import config - - class RepoStatus(object): def __init__(self, repo_id, from_commit, to_commit): self.repo_id = repo_id @@ -41,7 +38,6 @@ class RepoStatusIndex(): }, } - shard_num = config.SHARD_NUM def __init__(self, seasearch_api, index_name): 
self.index_name = index_name diff --git a/semantic_search/index_task/filename_index_updater.py b/semantic_search/index_task/filename_index_updater.py index 8eab9eb4..cf50ca70 100644 --- a/semantic_search/index_task/filename_index_updater.py +++ b/semantic_search/index_task/filename_index_updater.py @@ -10,25 +10,24 @@ from seafevents.semantic_search.utils.constants import REPO_STATUS_FILENAME_INDEX_NAME from seafevents.semantic_search.utils.seasearch_api import SeaSearchAPI from seafevents.repo_data import repo_data -from seafevents.semantic_search import config logger = logging.getLogger(__name__) -class RepoFilenameIndexUpdater(): - def __init__(self): - self._repo_status_filename_index = None - self._repo_filename_index = None - self._index_manager = None - self._repo_data = None - - def init(self): - self.seasearch_api = SeaSearchAPI(config.SEASEARCH_SERVER, config.SEASEARCH_TOKEN) +class RepoFilenameIndexUpdater(object): + def __init__(self, config): + self._parse_config(config) + self.seasearch_api = SeaSearchAPI(self.SEASEARCH_SERVER, self.SEASEARCH_TOKEN) self._repo_status_filename_index = RepoStatusIndex(self.seasearch_api, REPO_STATUS_FILENAME_INDEX_NAME) - self._repo_filename_index = RepoFileNameIndex(self.seasearch_api, repo_data) + self._repo_filename_index = RepoFileNameIndex(self.seasearch_api, repo_data, self.shard_num) self._index_manager = IndexManager() self._repo_data = repo_data + def _parse_config(self, config): + self.seasearch_server = config['SEMANTIC_SEARCH']['seasearch_server'] + self.seasearch_token = config['SEMANTIC_SEARCH']['seasearch_token'] + self.shard_num = int(config['SEMANTIC_SEARCH']['seasearch_shard_num']) + def start(self): RepoFilenameIndexUpdaterTimer( self._repo_status_filename_index, self._repo_filename_index, self._index_manager, self._repo_data @@ -92,5 +91,3 @@ def run(self): logging.exception('periodical update filename index error: %s', e) sched.start() - -repo_filename_index_updater = RepoFilenameIndexUpdater() diff --git a/semantic_search/index_task/index_task_manager.py b/semantic_search/index_task/index_task_manager.py index cfed20be..62bb34bd 100644 --- a/semantic_search/index_task/index_task_manager.py +++ b/semantic_search/index_task/index_task_manager.py @@ -78,27 +78,56 @@ def __init__(self): self.readable_id2task_map = {} # {task_readable_id: task} in queue or running self.check_task_lock = Lock() # lock access to readable_id2task_map self.sched = GeventScheduler() - self.app = None - self.conf = { - 'workers': config.INDEX_MANAGER_WORKERS, - 'expire_time': config.INDEX_TASK_EXPIRE_TIME - } + self.sched.add_job(self.clear_expired_tasks, CronTrigger(minute='*/10')) self.sched.add_job(self.cron_update_library_sdoc_indexes, CronTrigger(hour='*')) - self.index_manager = None - self.repo_file_index = None - def init(self): - self.index_manager = IndexManager() - self.seasearch_api = SeaSearchAPI(config.SEASEARCH_SERVER, config.SEASEARCH_TOKEN) + def init(self, config): + self._parse_config(config) + self.conf = { + 'workers': self.index_manager_workers, + 'expire_time': self.index_task_expire_time + } + self.index_manager = IndexManager(self.retrieval_num) + self.seasearch_api = SeaSearchAPI(self.seasearch_server, self.seasearch_token) self.repo_data = repo_data - self.embedding_api = SeaEmbeddingAPI(config.SEA_EMBEDDING_SERVER) + self.embedding_api = SeaEmbeddingAPI(self.sea_embedding_server, self.sea_embedding_key) # for semantic search - self.repo_status_index = RepoStatusIndex(self.seasearch_api, REPO_STATUS_FILE_INDEX_NAME) - 
self.repo_file_index = RepoFileIndex(self.seasearch_api) + self.repo_status_index = RepoStatusIndex( + self.seasearch_api, REPO_STATUS_FILE_INDEX_NAME + ) + self.repo_file_index = RepoFileIndex( + self.seasearch_api, + self.dimension, + self.vector_m, + self.shard_num, + self.threshold, + self.file_sentence_limit, + ) # for keyword search self.repo_filename_index = RepoFileNameIndex(self.seasearch_api, self.repo_data) + def _parse_config(self, config): + self.seasearch_server = config['SEMANTIC_SEARCH']['seasearch_server'] + self.seasearch_token = config['SEMANTIC_SEARCH']['seasearch_token'] + self.sea_embedding_server = config['SEMANTIC_SEARCH']['sea_embedding_server'] + self.sea_embedding_key = config['SEMANTIC_SEARCH']['sea_embedding_key'] + + self.index_manager_workers = int( + config['SEMANTIC_SEARCH']['index_manager_workers'] + ) + self.index_task_expire_time = int( + config['SEMANTIC_SEARCH']['index_task_expire_time'] + ) + self.retrieval_num = int(config['SEMANTIC_SEARCH']['retrieval_num']) + self.dimension = int(config['SEMANTIC_SEARCH']['embedding_dimension']) + self.vector_m = int(config['SEMANTIC_SEARCH']['seasearch_vector_m']) + self.shard_num = int(config['SEMANTIC_SEARCH']['seasearch_shard_num']) + self.threshold = float(config['SEMANTIC_SEARCH']['threshold']) + self.file_sentence_limit = int( + config['SEMANTIC_SEARCH']['embedding_file_sentence_limit'] + ) + def get_pending_or_running_task(self, readable_id): task = self.readable_id2task_map.get(readable_id) return task diff --git a/semantic_search/script/repo_file_index_local.py b/semantic_search/script/repo_file_index_local.py deleted file mode 100644 index 421eef59..00000000 --- a/semantic_search/script/repo_file_index_local.py +++ /dev/null @@ -1,275 +0,0 @@ -import os -import sys -import time -import queue -import logging -import argparse -import threading - -from seafobj import commit_mgr, fs_mgr, block_mgr - -import config -from seafevents.semantic_search.utils import init_logging -from seafevents.repo_data import repo_data -from seafevents.semantic_search.index_store.index_manager import IndexManager -from seafevents.semantic_search.utils.seasearch_api import SeaSearchAPI -from seafevents.semantic_search.index_store.repo_status_index import RepoStatusIndex -from seafevents.semantic_search.index_store.repo_file_index import RepoFileIndex -from seafevents.semantic_search.utils.constants import REPO_STATUS_FILE_INDEX_NAME -from seafevents.semantic_search.utils.sea_embedding_api import SeaEmbeddingAPI - - -MAX_ERRORS_ALLOWED = 1000 -logger = logging.getLogger('semantic_search') - -UPDATE_FILE_LOCK = os.path.join(os.path.dirname(__file__), 'update.lock') -lockfile = None -NO_TASKS = False - - -class RepoFileIndexLocal(object): - """ Independent update repo file index. 
- """ - def __init__(self, index_manager, repo_status_index, repo_file_index, embedding_api, repo_data, workers=3): - self.index_manager = index_manager - self.repo_status_index = repo_status_index - self.repo_file_index = repo_file_index - self.embedding_api = embedding_api - self.repo_data = repo_data - self.error_counter = 0 - self.worker_list = [] - self.workers = workers - - def clear_worker(self): - for th in self.worker_list: - th.join() - logger.info("All worker threads has stopped.") - - def run(self): - time_start = time.time() - repos_queue = queue.Queue(0) - for i in range(self.workers): - thread_name = "worker" + str(i) - logger.info("starting %s worker threads for repo file indexing" - % thread_name) - t = threading.Thread(target=self.thread_task, args=(repos_queue, ), name=thread_name) - t.start() - self.worker_list.append(t) - - start, per_size = 0, 1000 - need_deleted_index_repos = [] - while True: - global NO_TASKS - try: - index_repos = list(self.index_manager.get_index_repos_by_size(start, per_size)) - except Exception as e: - logger.error("Error: %s" % e) - NO_TASKS = True - self.clear_worker() - break - else: - if len(index_repos) == 0: - NO_TASKS = True - break - - for index_repo in index_repos: - repo_id = index_repo[0] - commit_id = self.repo_data.get_repo_head_commit(repo_id) - if not commit_id: - # repo has deleted, delete repo index - need_deleted_index_repos.append(repo_id) - continue - repos_queue.put((repo_id, commit_id)) - - start += per_size - - self.clear_worker() - logger.info("repo file index updated, total time %s seconds" % str(time.time() - time_start)) - try: - self.clear_deleted_repo(need_deleted_index_repos) - except Exception as e: - logger.exception('Delete Repo Error: %s' % e) - self.incr_error() - - def thread_task(self, repos_queue): - while True: - try: - queue_data = repos_queue.get(False) - except queue.Empty: - if NO_TASKS: - logger.debug( - "Queue is empty, %s worker threads stop" - % (threading.currentThread().getName()) - ) - break - else: - time.sleep(2) - else: - repo_id = queue_data[0] - commit_id = queue_data[1] - try: - self.index_manager.create_library_sdoc_index(repo_id, self.embedding_api, self.repo_file_index, self.repo_status_index, commit_id) - except Exception as e: - logger.exception('Repo file index error: %s, repo_id: %s' % (e, repo_id), exc_info=True) - self.incr_error() - - logger.info( - "%s worker updated at %s time" - % (threading.currentThread().getName(), - time.strftime("%Y-%m-%d %H:%M", time.localtime(time.time()))) - ) - logger.info( - "%s worker get %s error" - % (threading.currentThread().getName(), - str(self.error_counter)) - ) - - def clear_deleted_repo(self, repos): - logger.info("start to clear deleted repo") - logger.info("%d repos need to be deleted." % len(repos)) - - for repo_id in repos: - self.delete_repo(repo_id) - logger.info('Repo %s has been deleted from index.' 
% repo_id) - logger.info("deleted repo has been cleared") - - def incr_error(self): - self.error_counter += 1 - - def delete_repo(self, repo_id): - if len(repo_id) != 36: - return - self.index_manager.delete_index_repo_db(repo_id) - - -def start_index_local(): - if not check_concurrent_update(): - return - - seasearch_api = SeaSearchAPI(config.SEASEARCH_SERVER, config.SEASEARCH_TOKEN) - index_manager = IndexManager() - repo_status_index = RepoStatusIndex(seasearch_api, REPO_STATUS_FILE_INDEX_NAME) - repo_file_index = RepoFileIndex(seasearch_api) - - embedding_api = SeaEmbeddingAPI(config.SEA_EMBEDDING_SERVER) - workers = config.INDEX_MANAGER_WORKERS - - try: - index_local = RepoFileIndexLocal(index_manager, repo_status_index, repo_file_index, embedding_api, repo_data, workers) - except Exception as e: - logger.error("Index repo file process init error: %s." % e) - return - - logger.info("Index repo file process initialized.") - index_local.run() - - logger.info('\n\nRepo file index updated, statistic report:\n') - logger.info('[commit read] %s', commit_mgr.read_count()) - logger.info('[dir read] %s', fs_mgr.dir_read_count()) - logger.info('[file read] %s', fs_mgr.file_read_count()) - logger.info('[block read] %s', block_mgr.read_count()) - - -def delete_indices(): - seasearch_api = SeaSearchAPI(config.SEASEARCH_SERVER, config.SEASEARCH_TOKEN) - repo_status_index = RepoStatusIndex(seasearch_api, REPO_STATUS_FILE_INDEX_NAME) - repo_file_index = RepoFileIndex(seasearch_api) - index_manager = IndexManager() - - start, per_size = 0, 1000 - while True: - index_repos = list(index_manager.get_index_repos_by_size(start, per_size)) - - if len(index_repos) == 0: - break - - for index_repo in index_repos: - repo_file_index.delete_index_by_index_name(index_repo[0]) - start += per_size - - repo_status_index.delete_index_by_index_name() - - -def main(): - parser = argparse.ArgumentParser() - subparsers = parser.add_subparsers(title='subcommands', description='') - - parser.add_argument( - '--logfile', - default=sys.stdout, - type=argparse.FileType('a'), - help='log file') - - parser.add_argument( - '--loglevel', - default='info', - help='log level') - - # update index - parser_update = subparsers.add_parser('update', help='update repo file index') - parser_update.set_defaults(func=start_index_local) - - # clear - parser_clear = subparsers.add_parser('clear', help='clear all repo file index') - parser_clear.set_defaults(func=delete_indices) - - if len(sys.argv) == 1: - print(parser.format_help()) - return - - args = parser.parse_args() - init_logging(args) - - logger.info('storage: using ' + commit_mgr.get_backend_name()) - - args.func() - - -def do_lock(fn): - if os.name == 'nt': - return do_lock_win32(fn) - else: - return do_lock_linux(fn) - - -def do_lock_win32(fn): - import ctypes - - CreateFileW = ctypes.windll.kernel32.CreateFileW - GENERIC_WRITE = 0x40000000 - OPEN_ALWAYS = 4 - - def lock_file(path): - lock_file_handle = CreateFileW(path, GENERIC_WRITE, 0, None, OPEN_ALWAYS, 0, None) - - return lock_file_handle - - global lockfile - - lockfile = lock_file(fn) - - return lockfile != -1 - - -def do_lock_linux(fn): - from seafevents.semantic_search import portalocker - global lockfile - lockfile = open(fn, 'w') - try: - portalocker.lock(lockfile, portalocker.LOCK_NB | portalocker.LOCK_EX) - return True - except portalocker.LockException: - return False - - -def check_concurrent_update(): - """Use a lock file to ensure only one task can be running""" - if not do_lock(UPDATE_FILE_LOCK): - 
logger.error('another index task is running, quit now') - return False - - return True - - -if __name__ == "__main__": - main() diff --git a/semantic_search/script/repo_filename_index_local.py b/semantic_search/script/repo_filename_index_local.py deleted file mode 100644 index 538529cd..00000000 --- a/semantic_search/script/repo_filename_index_local.py +++ /dev/null @@ -1,271 +0,0 @@ -import os -import sys -import time -import queue -import logging -import argparse -import threading - -from seafobj import commit_mgr, fs_mgr, block_mgr -import config -from seafevents.semantic_search.utils import init_logging -from seafevents.repo_data import repo_data -from seafevents.semantic_search.index_store.index_manager import IndexManager -from seafevents.semantic_search.utils.seasearch_api import SeaSearchAPI -from seafevents.semantic_search.index_store.repo_status_index import RepoStatusIndex -from seafevents.semantic_search.utils.constants import REPO_STATUS_FILENAME_INDEX_NAME, REPO_FILENAME_INDEX_PREFIX -from seafevents.semantic_search.index_store.repo_file_name_index import RepoFileNameIndex - -MAX_ERRORS_ALLOWED = 1000 -logger = logging.getLogger('semantic_search') - -UPDATE_FILE_LOCK = os.path.join(os.path.dirname(__file__), 'update.lock') -lockfile = None -NO_TASKS = False - - -class RepoFileNameIndexLocal(object): - """ Independent update repo file name index. - """ - def __init__(self, index_manager, repo_status_filename_index, repo_filename_index, repo_data, workers=3): - self.index_manager = index_manager - self.repo_status_filename_index = repo_status_filename_index - self.repo_filename_index = repo_filename_index - self.repo_data = repo_data - self.error_counter = 0 - self.worker_list = [] - self.workers = workers - - def clear_worker(self): - for th in self.worker_list: - th.join() - logger.info("All worker threads has stopped.") - - def run(self): - time_start = time.time() - repos_queue = queue.Queue(0) - for i in range(self.workers): - thread_name = "worker" + str(i) - logger.info("starting %s worker threads for repo filename indexing" - % thread_name) - t = threading.Thread(target=self.thread_task, args=(repos_queue, ), name=thread_name) - t.start() - self.worker_list.append(t) - - start, per_size = 0, 1000 - repos = {} - while True: - global NO_TASKS - try: - repo_commits = self.repo_data.get_repo_id_commit_id(start, per_size) - except Exception as e: - logger.error("Error: %s" % e) - NO_TASKS = True - self.clear_worker() - return - else: - if len(repo_commits) == 0: - NO_TASKS = True - break - for repo_id, commit_id in repo_commits.items(): - repos_queue.put((repo_id, commit_id)) - repos[repo_id] = commit_id - start += per_size - - self.clear_worker() - logger.info("repo filename index updated, total time %s seconds" % str(time.time() - time_start)) - try: - self.clear_deleted_repo(list(repos.keys())) - except Exception as e: - logger.exception('Delete Repo Error: %s' % e) - self.incr_error() - - def thread_task(self, repos_queue): - while True: - try: - queue_data = repos_queue.get(False) - except queue.Empty: - if NO_TASKS: - logger.debug( - "Queue is empty, %s worker threads stop" - % (threading.currentThread().getName()) - ) - break - else: - time.sleep(2) - else: - repo_id = queue_data[0] - commit_id = queue_data[1] - try: - self.index_manager.update_library_filename_index(repo_id, commit_id, self.repo_filename_index, self.repo_status_filename_index) - except Exception as e: - logger.exception('Repo filename index error: %s, repo_id: %s' % (e, repo_id), exc_info=True) - 
self.incr_error() - - logger.info( - "%s worker updated at %s time" - % (threading.currentThread().getName(), - time.strftime("%Y-%m-%d %H:%M", time.localtime(time.time()))) - ) - logger.info( - "%s worker get %s error" - % (threading.currentThread().getName(), - str(self.error_counter)) - ) - - def clear_deleted_repo(self, repos): - logger.info("start to clear deleted repo") - repo_all = [e.get('repo_id') for e in self.repo_status_filename_index.get_all_repos_from_index()] - - repo_deleted = set(repo_all) - set(repos) - logger.info("%d repos need to be deleted." % len(repo_deleted)) - - for repo_id in repo_deleted: - self.delete_repo(repo_id) - logger.info('Repo %s has been deleted from index.' % repo_id) - logger.info("deleted repo has been cleared") - - def incr_error(self): - self.error_counter += 1 - - def delete_repo(self, repo_id): - if len(repo_id) != 36: - return - self.index_manager.delete_repo_filename_index(repo_id, self.repo_filename_index, self.repo_status_filename_index) - - -def start_index_local(): - if not check_concurrent_update(): - return - - index_manager = IndexManager() - seasearch_api = SeaSearchAPI(config.SEASEARCH_SERVER, config.SEASEARCH_TOKEN) - repo_status_filename_index = RepoStatusIndex(seasearch_api, REPO_STATUS_FILENAME_INDEX_NAME) - - repo_filename_index = RepoFileNameIndex(seasearch_api, repo_data) - - workers = config.INDEX_MANAGER_WORKERS - - try: - index_local = RepoFileNameIndexLocal(index_manager, repo_status_filename_index, repo_filename_index,repo_data, workers) - except Exception as e: - logger.error("Index repo filename process init error: %s." % e) - return - - logger.info("Index repo filename process initialized.") - index_local.run() - - logger.info('\n\nRepo filename index updated, statistic report:\n') - logger.info('[commit read] %s', commit_mgr.read_count()) - logger.info('[dir read] %s', fs_mgr.dir_read_count()) - logger.info('[file read] %s', fs_mgr.file_read_count()) - logger.info('[block read] %s', block_mgr.read_count()) - - -def delete_indices(): - seasearch_api = SeaSearchAPI(config.SEASEARCH_SERVER, config.SEASEARCH_TOKEN) - repo_status_filename_index = RepoStatusIndex(seasearch_api, REPO_STATUS_FILENAME_INDEX_NAME) - repo_filename_index = RepoFileNameIndex(seasearch_api, repo_data) - - start, count = 0, 1000 - while True: - try: - repo_commits = repo_data.get_repo_id_commit_id(start, count) - except Exception as e: - logger.error("Error: %s" % e) - return - start += 1000 - - if len(repo_commits) == 0: - break - - for repo_id, commit_id in repo_commits.items(): - repo_filename_index_name = REPO_FILENAME_INDEX_PREFIX + repo_id - repo_filename_index.delete_index_by_index_name(repo_filename_index_name) - - repo_status_filename_index.delete_index_by_index_name() - - -def main(): - parser = argparse.ArgumentParser() - subparsers = parser.add_subparsers(title='subcommands', description='') - - parser.add_argument( - '--logfile', - default=sys.stdout, - type=argparse.FileType('a'), - help='log file') - - parser.add_argument( - '--loglevel', - default='info', - help='log level') - - # update index - parser_update = subparsers.add_parser('update', help='update seafile repo filename index') - parser_update.set_defaults(func=start_index_local) - - # clear - parser_clear = subparsers.add_parser('clear', help='clear all repo filename index') - parser_clear.set_defaults(func=delete_indices) - - if len(sys.argv) == 1: - print(parser.format_help()) - return - - args = parser.parse_args() - init_logging(args) - - logger.info('storage: using ' + 
commit_mgr.get_backend_name()) - - args.func() - - -def do_lock(fn): - if os.name == 'nt': - return do_lock_win32(fn) - else: - return do_lock_linux(fn) - - -def do_lock_win32(fn): - import ctypes - - CreateFileW = ctypes.windll.kernel32.CreateFileW - GENERIC_WRITE = 0x40000000 - OPEN_ALWAYS = 4 - - def lock_file(path): - lock_file_handle = CreateFileW(path, GENERIC_WRITE, 0, None, OPEN_ALWAYS, 0, None) - - return lock_file_handle - - global lockfile - - lockfile = lock_file(fn) - - return lockfile != -1 - - -def do_lock_linux(fn): - from seafevents.semantic_search import portalocker - global lockfile - lockfile = open(fn, 'w') - try: - portalocker.lock(lockfile, portalocker.LOCK_NB | portalocker.LOCK_EX) - return True - except portalocker.LockException: - return False - - -def check_concurrent_update(): - """Use a lock file to ensure only one task can be running""" - if not do_lock(UPDATE_FILE_LOCK): - logger.error('another index task is running, quit now') - return False - - return True - - -if __name__ == "__main__": - main() diff --git a/semantic_search/utils/constants.py b/semantic_search/utils/constants.py index 272bf49c..e15216bb 100644 --- a/semantic_search/utils/constants.py +++ b/semantic_search/utils/constants.py @@ -3,3 +3,4 @@ REPO_STATUS_FILE_INDEX_NAME = 'repo_status_file' REPO_STATUS_FILENAME_INDEX_NAME = 'repo_status_filename' REPO_FILENAME_INDEX_PREFIX = 'filename_' +MODEL_VOCAB_PATH = '/data/dev/static/damo/nlp_corom_sentence-embedding_chinese-base/' \ No newline at end of file diff --git a/semantic_search/utils/sea_embedding_api.py b/semantic_search/utils/sea_embedding_api.py index 4bbc98a7..2d6608e2 100644 --- a/semantic_search/utils/sea_embedding_api.py +++ b/semantic_search/utils/sea_embedding_api.py @@ -4,7 +4,6 @@ import time import json -from seafevents.semantic_search.config import SEA_EMBEDDING_KEY logger = logging.getLogger(__name__) @@ -22,13 +21,14 @@ def parse_response(response): class SeaEmbeddingAPI(object): - def __init__(self, sea_embedding_url, time_out=180): + def __init__(self, sea_embedding_url, sea_embedding_key, time_out=180): self.sea_embedding_url = sea_embedding_url.rstrip('/') self.time_out = time_out + self.sea_embedding_key = sea_embedding_key def gen_headers(self): payload = {'exp': int(time.time()) + 300, } - token = jwt.encode(payload, SEA_EMBEDDING_KEY, algorithm='HS256') + token = jwt.encode(payload, self.sea_embedding_key, algorithm='HS256') return {"Authorization": "Token %s" % token} def embeddings(self, input): diff --git a/semantic_search/utils/text_splitter.py b/semantic_search/utils/text_splitter.py index fb12ae17..27bfeb54 100644 --- a/semantic_search/utils/text_splitter.py +++ b/semantic_search/utils/text_splitter.py @@ -2,7 +2,7 @@ import logging from transformers import AutoTokenizer -from seafevents.semantic_search.config import MODEL_VOCAB_PATH +from seafevents.semantic_search.utils.constants import MODEL_VOCAB_PATH logger = logging.getLogger(__name__)
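
Configuration note: with patches 3 and 4 applied, all semantic-search settings are read from a [SEMANTIC_SEARCH] section of the seafevents config file pointed to by EVENTS_CONFIG_FILE; nothing is taken from seahub settings or the removed semantic_search_settings.py any more. Below is a minimal sketch of what that section could look like. The key names are exactly the ones read by semantic_search/config.py and by the _parse_config methods in index_task_manager.py and filename_index_updater.py; the values are only examples that mirror the defaults removed from semantic_search/config.py in patch 3, and the server URLs, token, key and vocab path are placeholders that must be adapted to the deployment.

[SEMANTIC_SEARCH]
# worker threads and task expiry for the index task manager
index_manager_workers = 2
index_task_expire_time = 1800
# number of candidates fetched per semantic query and the score cut-off
retrieval_num = 20
threshold = 0.01
# embedding settings; the dimension must match the sea-embedding model
embedding_dimension = 768
embedding_file_sentence_limit = 1000
embedding_model_vocab_path = /opt/seafile/embedding/vocab/
# SeaSearch backend
seasearch_server = http://127.0.0.1:4080
seasearch_token = <seasearch-token>
seasearch_vector_m = 256
seasearch_shard_num = 1
# sea-embedding service
sea_embedding_server = http://127.0.0.1:8889
sea_embedding_key = <jwt-signing-key>
# note: this key is spelled with three p's, matching the lookup in
# semantic_search/config.py, and the value is split on ', ' (comma plus space)
suppport_index_file_types = .sdoc, .md, .markdown, .doc, .docx, .ppt, .pptx, .pdf

Only seasearch_server, seasearch_token and seasearch_shard_num are needed by the filename index updater; the full set above is read when the seafevent server initializes the index task manager. Since SeaEmbeddingAPI now receives the key through its constructor and uses it to sign a short-lived JWT for each request, sea_embedding_key must match the secret expected by the sea-embedding service.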