feat: separate tasks #1001

Merged
merged 1 commit into from Aug 21, 2024
2 changes: 1 addition & 1 deletion apps/application/serializers/application_serializers.py
@@ -14,6 +14,7 @@
from functools import reduce
from typing import Dict, List

+from django.conf import settings
from django.contrib.postgres.fields import ArrayField
from django.core import cache, validators
from django.core import signing
@@ -46,7 +47,6 @@
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from setting.serializers.provider_serializers import ModelSerializer
from smartdoc.conf import PROJECT_DIR
-from django.conf import settings

chat_cache = cache.caches['chat_cache']

8 changes: 5 additions & 3 deletions apps/application/serializers/chat_serializers.py
@@ -37,8 +37,10 @@
from common.util.file_util import get_file_content
from common.util.lock import try_lock, un_lock
from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
-from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id
+from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id, \
+    get_embedding_model_id_by_dataset_id
from dataset.serializers.paragraph_serializers import ParagraphSerializers
+from embedding.task import embedding_by_paragraph
from smartdoc.conf import PROJECT_DIR

chat_cache = caches['chat_cache']
@@ -516,9 +518,9 @@ def is_valid(self, *, raise_exception=False):

    @staticmethod
    def post_embedding_paragraph(chat_record, paragraph_id, dataset_id):
-        model = get_embedding_model_by_dataset_id(dataset_id)
+        model_id = get_embedding_model_id_by_dataset_id(dataset_id)
        # send the vectorization event
-        ListenerManagement.embedding_by_paragraph_signal.send(paragraph_id, embedding_model=model)
+        embedding_by_paragraph(paragraph_id, model_id)
        return chat_record

    @post(post_function=post_embedding_paragraph)
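Note: the change above is the heart of the "separate tasks" refactor. The old path fired an in-process blinker signal carrying a live Embeddings object, which only works if a receiver was connected at startup (the ListenerManagement().run() call removed further below); the new path calls the task function directly with plain, serializable ids. A minimal sketch of the two dispatch styles, with stand-in names (not MaxKB code):

    from blinker import signal

    embedding_by_paragraph_signal = signal("embedding_by_paragraph")


    def embed_receiver(sender, embedding_model=None):
        # stand-in for the old listener; the real one wrote vectors for the paragraph
        print(f"signal: embed paragraph {sender} with {embedding_model}")


    def embedding_by_paragraph(paragraph_id, model_id):
        # stand-in for the task in apps/embedding/task: it resolves the model by id itself
        print(f"task: embed paragraph {paragraph_id} using model {model_id}")


    # old style: a receiver must be connected in this process before anything can fire
    embedding_by_paragraph_signal.connect(embed_receiver)
    embedding_by_paragraph_signal.send("p-1", embedding_model="<Embeddings instance>")

    # new style: no registration step, and only ids cross the call boundary
    embedding_by_paragraph("p-1", "m-1")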
Empty file.
25 changes: 17 additions & 8 deletions apps/common/config/embedding_config.py
@@ -6,26 +6,35 @@
@date:2023/10/23 16:03
@desc:
"""
+import threading
import time

from common.cache.mem_cache import MemCache

+lock = threading.Lock()


class ModelManage:
    cache = MemCache('model', {})
    up_clear_time = time.time()

    @staticmethod
    def get_model(_id, get_model):
-        model_instance = ModelManage.cache.get(_id)
-        if model_instance is None or not model_instance.is_cache_model():
-            model_instance = get_model(_id)
-            ModelManage.cache.set(_id, model_instance, timeout=60 * 30)
+        # acquire the lock
+        lock.acquire()
+        try:
+            model_instance = ModelManage.cache.get(_id)
+            if model_instance is None or not model_instance.is_cache_model():
+                model_instance = get_model(_id)
+                ModelManage.cache.set(_id, model_instance, timeout=60 * 30)
+                return model_instance
+            # renew the cache entry
+            ModelManage.cache.touch(_id, timeout=60 * 30)
+            ModelManage.clear_timeout_cache()
+            return model_instance
-        # renew the cache entry
-        ModelManage.cache.touch(_id, timeout=60 * 30)
-        ModelManage.clear_timeout_cache()
-        return model_instance
+        finally:
+            # release the lock
+            lock.release()

    @staticmethod
    def clear_timeout_cache():
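The rewritten get_model is a lock-guarded cache read: check, load on miss, renew on hit, all inside one critical section, so two concurrent callers can no longer both pay for an expensive model load. A self-contained sketch of the same pattern with a stand-in cache and loader (not MaxKB code); `with lock:` is equivalent to the diff's explicit acquire/try/finally/release:

    import threading

    lock = threading.Lock()
    cache = {}


    def get_model(_id, load_model):
        with lock:  # same effect as lock.acquire() ... finally: lock.release()
            instance = cache.get(_id)
            if instance is None:
                instance = load_model(_id)  # expensive: loads weights, opens connections
                cache[_id] = instance
            return instance


    print(get_model("m-1", lambda _id: f"model<{_id}>"))

The trade-off accepted here is that one process-wide lock serializes every lookup; that stays cheap as long as the hit path is only a cache get plus a touch.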
1 change: 0 additions & 1 deletion apps/common/event/__init__.py
@@ -12,7 +12,6 @@


def run():
-    listener_manage.ListenerManagement().run()
    QuerySet(Document).filter(status__in=[Status.embedding, Status.queue_up]).update(**{'status': Status.error})
    QuerySet(Model).filter(status=setting.models.Status.DOWNLOAD).update(status=setting.models.Status.ERROR,
                                                                         meta={'message': "下载程序被中断,请重试"})
94 changes: 15 additions & 79 deletions apps/common/event/listener_manage.py
@@ -13,22 +13,20 @@
from typing import List

import django.db.models
-from blinker import signal
from django.db.models import QuerySet
from langchain_core.embeddings import Embeddings

from common.config.embedding_config import VectorStore
from common.db.search import native_search, get_dynamics_model
-from common.event.common import poxy, embedding_poxy
+from common.event.common import embedding_poxy
from common.util.file_util import get_file_content
from common.util.fork import ForkManage, Fork
from common.util.lock import try_lock, un_lock
from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping
-from embedding.models import SourceType
+from embedding.models import SourceType, SearchMode
from smartdoc.conf import PROJECT_DIR

max_kb_error = logging.getLogger("max_kb_error")
max_kb = logging.getLogger("max_kb")
max_kb_error = logging.getLogger(__file__)
max_kb = logging.getLogger(__file__)


class SyncWebDatasetArgs:
@@ -70,23 +68,6 @@ def __init__(self, paragraph_id_list: List[str], target_document_id: str, target


class ListenerManagement:
-    embedding_by_problem_signal = signal("embedding_by_problem")
-    embedding_by_paragraph_signal = signal("embedding_by_paragraph")
-    embedding_by_dataset_signal = signal("embedding_by_dataset")
-    embedding_by_document_signal = signal("embedding_by_document")
-    delete_embedding_by_document_signal = signal("delete_embedding_by_document")
-    delete_embedding_by_document_list_signal = signal("delete_embedding_by_document_list")
-    delete_embedding_by_dataset_signal = signal("delete_embedding_by_dataset")
-    delete_embedding_by_paragraph_signal = signal("delete_embedding_by_paragraph")
-    delete_embedding_by_source_signal = signal("delete_embedding_by_source")
-    enable_embedding_by_paragraph_signal = signal('enable_embedding_by_paragraph')
-    disable_embedding_by_paragraph_signal = signal('disable_embedding_by_paragraph')
-    init_embedding_model_signal = signal('init_embedding_model')
-    sync_web_dataset_signal = signal('sync_web_dataset')
-    sync_web_document_signal = signal('sync_web_document')
-    update_problem_signal = signal('update_problem')
-    delete_embedding_by_source_ids_signal = signal('delete_embedding_by_source_ids')
-    delete_embedding_by_dataset_id_list_signal = signal("delete_embedding_by_dataset_id_list")

    @staticmethod
    def embedding_by_problem(args, embedding_model: Embeddings):
@@ -160,7 +141,6 @@ def is_save_function():
        max_kb.info(f'结束--->向量化段落:{paragraph_id}')

    @staticmethod
-    @embedding_poxy
    def embedding_by_document(document_id, embedding_model: Embeddings):
"""
向量化文档
Expand Down Expand Up @@ -227,7 +207,7 @@ def delete_embedding_by_document(document_id):

    @staticmethod
    def delete_embedding_by_document_list(document_id_list: List[str]):
-        VectorStore.get_embedding_vector().delete_bu_document_id_list(document_id_list)
+        VectorStore.get_embedding_vector().delete_by_document_id_list(document_id_list)

    @staticmethod
    def delete_embedding_by_dataset(dataset_id):
@@ -249,25 +229,6 @@ def disable_embedding_by_paragraph(paragraph_id):
    def enable_embedding_by_paragraph(paragraph_id):
        VectorStore.get_embedding_vector().update_by_paragraph_id(paragraph_id, {'is_active': True})

-    @staticmethod
-    @poxy
-    def sync_web_document(args: SyncWebDocumentArgs):
-        for source_url in args.source_url_list:
-            result = Fork(base_fork_url=source_url, selector_list=args.selector.split(' ')).fork()
-            args.handler(source_url, args.selector, result)
-
-    @staticmethod
-    @poxy
-    def sync_web_dataset(args: SyncWebDatasetArgs):
-        if try_lock('sync_web_dataset' + args.lock_key):
-            try:
-                ForkManage(args.url, args.selector.split(" ") if args.selector is not None else []).fork(2, set(),
-                                                                                                         args.handler)
-            except Exception as e:
-                logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
-            finally:
-                un_lock('sync_web_dataset' + args.lock_key)
-
    @staticmethod
    def update_problem(args: UpdateProblemArgs):
        problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(problem_id=args.problem_id)
@@ -281,6 +242,9 @@ def update_embedding_dataset_id(args: UpdateEmbeddingDatasetIdArgs):
            VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
                                                                       {'dataset_id': args.target_dataset_id})
        else:
+            # delete the old vector data
+            ListenerManagement.delete_embedding_by_paragraph_ids(args.paragraph_id_list)
+            # vectorize the data again
            ListenerManagement.embedding_by_paragraph_list(args.paragraph_id_list,
                                                           embedding_model=args.target_embedding_model)

@@ -306,38 +270,10 @@ def delete_embedding_by_paragraph_ids(paragraph_ids: List[str]):
    def delete_embedding_by_dataset_id_list(source_ids: List[str]):
        VectorStore.get_embedding_vector().delete_by_dataset_id_list(source_ids)

-    def run(self):
-        # add vectors by problem id
-        ListenerManagement.embedding_by_problem_signal.connect(self.embedding_by_problem)
-        # add vectors by paragraph id
-        ListenerManagement.embedding_by_paragraph_signal.connect(self.embedding_by_paragraph)
-        # add vectors by dataset id
-        ListenerManagement.embedding_by_dataset_signal.connect(
-            self.embedding_by_dataset)
-        # add vectors by document id
-        ListenerManagement.embedding_by_document_signal.connect(
-            self.embedding_by_document)
-        # delete vectors by document
-        ListenerManagement.delete_embedding_by_document_signal.connect(self.delete_embedding_by_document)
-        # delete vectors by document id list
-        ListenerManagement.delete_embedding_by_document_list_signal.connect(self.delete_embedding_by_document_list)
-        # delete vectors by dataset id
-        ListenerManagement.delete_embedding_by_dataset_signal.connect(self.delete_embedding_by_dataset)
-        # delete vectors by paragraph id
-        ListenerManagement.delete_embedding_by_paragraph_signal.connect(
-            self.delete_embedding_by_paragraph)
-        # delete vectors by source id
-        ListenerManagement.delete_embedding_by_source_signal.connect(self.delete_embedding_by_source)
-        # disable a paragraph
-        ListenerManagement.disable_embedding_by_paragraph_signal.connect(self.disable_embedding_by_paragraph)
-        # enable paragraph vectors
-        ListenerManagement.enable_embedding_by_paragraph_signal.connect(self.enable_embedding_by_paragraph)
-
-        # sync a web-site dataset
-        ListenerManagement.sync_web_dataset_signal.connect(self.sync_web_dataset)
-        # sync web-site documents
-        ListenerManagement.sync_web_document_signal.connect(self.sync_web_document)
-        # update problem vectors
-        ListenerManagement.update_problem_signal.connect(self.update_problem)
-        ListenerManagement.delete_embedding_by_source_ids_signal.connect(self.delete_embedding_by_source_ids)
-        ListenerManagement.delete_embedding_by_dataset_id_list_signal.connect(self.delete_embedding_by_dataset_id_list)
+    @staticmethod
+    def hit_test(query_text, dataset_id: list[str], exclude_document_id_list: list[str], top_number: int,
+                 similarity: float,
+                 search_mode: SearchMode,
+                 embedding: Embeddings):
+        return VectorStore.get_embedding_vector().hit_test(query_text, dataset_id, exclude_document_id_list, top_number,
+                                                           similarity, search_mode, embedding)
18 changes: 12 additions & 6 deletions apps/common/job/client_access_num_job.py
@@ -13,9 +13,11 @@
from django_apscheduler.jobstores import DjangoJobStore

from application.models.api_key_model import ApplicationPublicAccessClient
+from common.lock.impl.file_lock import FileLock

scheduler = BackgroundScheduler()
scheduler.add_jobstore(DjangoJobStore(), "default")
+lock = FileLock()


def client_access_num_reset_job():
@@ -25,9 +27,13 @@ def client_access_num_reset_job():


def run():
-    scheduler.start()
-    access_num_reset = scheduler.get_job(job_id='access_num_reset')
-    if access_num_reset is not None:
-        access_num_reset.remove()
-    scheduler.add_job(client_access_num_reset_job, 'cron', hour='0', minute='0', second='0',
-                      id='access_num_reset')
+    if lock.try_lock('client_access_num_reset_job', 30 * 30):
+        try:
+            scheduler.start()
+            access_num_reset = scheduler.get_job(job_id='access_num_reset')
+            if access_num_reset is not None:
+                access_num_reset.remove()
+            scheduler.add_job(client_access_num_reset_job, 'cron', hour='0', minute='0', second='0',
+                              id='access_num_reset')
+        finally:
+            lock.un_lock('client_access_num_reset_job')
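Design note: with several server processes, each one imports this module and calls run(), so without a guard every process would start its own scheduler and re-register the same cron job. try_lock('client_access_num_reset_job', 30 * 30) hands the registration to a single process, and the 900-second timeout lets the lock be taken over if that process dies between acquiring and releasing it.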
20 changes: 20 additions & 0 deletions apps/common/lock/base_lock.py
@@ -0,0 +1,20 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: base_lock.py
@date:2024/8/20 10:33
@desc:
"""

from abc import ABC, abstractmethod


class BaseLock(ABC):
    @abstractmethod
    def try_lock(self, key, timeout):
        pass

    @abstractmethod
    def un_lock(self, key):
        pass
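BaseLock only fixes the contract: try_lock is non-blocking and returns a bool, with a timeout after which a stale lock may be taken over, and un_lock releases by key. A hypothetical in-memory implementation (not part of this PR) that could back the same interface, e.g. in tests:

    import time

    from common.lock.base_lock import BaseLock


    class MemoryLock(BaseLock):
        def __init__(self):
            self._held = {}  # key -> acquisition timestamp

        def try_lock(self, key, timeout):
            now = time.time()
            acquired_at = self._held.get(key)
            if acquired_at is None or now - acquired_at > timeout:
                # the key is free, or the previous holder timed out
                self._held[key] = now
                return True
            return False

        def un_lock(self, key):
            self._held.pop(key, None)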
77 changes: 77 additions & 0 deletions apps/common/lock/impl/file_lock.py
@@ -0,0 +1,77 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: file_lock.py
@date:2024/8/20 10:48
@desc:
"""
import errno
import hashlib
import os
import time

import six

from common.lock.base_lock import BaseLock
from smartdoc.const import PROJECT_DIR


def key_to_lock_name(key):
    """
    Combine part of a key with its hash to prevent very long filenames
    """
    MAX_LENGTH = 50
    key_hash = hashlib.md5(six.b(key)).hexdigest()
    lock_name = key[:MAX_LENGTH - len(key_hash) - 1] + '_' + key_hash
    return lock_name


class FileLock(BaseLock):
    """
    File locking backend.
    """

    def __init__(self, settings=None):
        if settings is None:
            settings = {}
        self.location = settings.get('location')
        if self.location is None:
            self.location = os.path.join(PROJECT_DIR, 'data', 'lock')
        try:
            os.makedirs(self.location)
        except OSError as error:
            # Directory exists?
            if error.errno != errno.EEXIST:
                # Re-raise unexpected OSError
                raise

    def _get_lock_path(self, key):
        lock_name = key_to_lock_name(key)
        return os.path.join(self.location, lock_name)

    def try_lock(self, key, timeout):
        lock_path = self._get_lock_path(key)
        try:
            # create the lock file; if it cannot be created, the lock is already held
            fd = os.open(lock_path, os.O_CREAT | os.O_EXCL)
        except OSError as error:
            if error.errno == errno.EEXIST:
                # File already exists, check its modification time
                mtime = os.path.getmtime(lock_path)
                ttl = mtime + timeout - time.time()
                if ttl > 0:
                    return False
                else:
                    # the timeout has expired: take the lock over and continue
                    os.utime(lock_path, None)
                    return True
            else:
                return False
        else:
            os.close(fd)
            return True

    def un_lock(self, key):
        lock_path = self._get_lock_path(key)
        os.remove(lock_path)
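A usage sketch for the new FileLock (hypothetical key, timeout, and workload): try_lock returns True only for the process whose os.open(..., O_CREAT | O_EXCL) actually creates the lock file, and the timeout lets a crashed holder's lock be taken over.

    from common.lock.impl.file_lock import FileLock


    def nightly_cleanup():
        print("doing the guarded work")  # stand-in workload


    lock = FileLock()
    if lock.try_lock('nightly_cleanup', timeout=60 * 15):
        try:
            nightly_cleanup()
        finally:
            # remove the lock file so the next run can acquire it
            lock.un_lock('nightly_cleanup')

Note that un_lock is an unconditional os.remove, so a caller that never acquired the lock could delete another holder's file; the class relies on callers pairing the two calls as above.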
Empty file.
Empty file.