feat: separate tasks
shaohuzhang1 committed Aug 21, 2024
1 parent b596b69 commit 66ce583
Showing 66 changed files with 2,074 additions and 293 deletions.
2 changes: 1 addition & 1 deletion apps/application/serializers/application_serializers.py
@@ -14,6 +14,7 @@
 from functools import reduce
 from typing import Dict, List
 
+from django.conf import settings
 from django.contrib.postgres.fields import ArrayField
 from django.core import cache, validators
 from django.core import signing
@@ -46,7 +47,6 @@
 from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
 from setting.serializers.provider_serializers import ModelSerializer
 from smartdoc.conf import PROJECT_DIR
-from django.conf import settings
 
 chat_cache = cache.caches['chat_cache']
 
8 changes: 5 additions & 3 deletions apps/application/serializers/chat_serializers.py
@@ -37,8 +37,10 @@
 from common.util.file_util import get_file_content
 from common.util.lock import try_lock, un_lock
 from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
-from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id
+from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id, \
+    get_embedding_model_id_by_dataset_id
 from dataset.serializers.paragraph_serializers import ParagraphSerializers
+from embedding.task import embedding_by_paragraph
 from smartdoc.conf import PROJECT_DIR
 
 chat_cache = caches['chat_cache']
@@ -516,9 +518,9 @@ def is_valid(self, *, raise_exception=False):
 
     @staticmethod
     def post_embedding_paragraph(chat_record, paragraph_id, dataset_id):
-        model = get_embedding_model_by_dataset_id(dataset_id)
+        model_id = get_embedding_model_id_by_dataset_id(dataset_id)
         # Send the embedding event
-        ListenerManagement.embedding_by_paragraph_signal.send(paragraph_id, embedding_model=model)
+        embedding_by_paragraph(paragraph_id, model_id)
         return chat_record
 
     @post(post_function=post_embedding_paragraph)
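The hunk above replaces an in-process blinker signal with a direct call to embedding_by_paragraph from apps/embedding/task.py, which this page does not render. A minimal sketch of what such a task might look like, assuming a Celery-style worker and a hypothetical load_embedding_model factory (neither is confirmed by this diff):

# Sketch only: the real apps/embedding/task.py is not shown in this diff.
from celery import shared_task


def load_embedding_model(model_id: str):
    # Hypothetical factory resolving a model id to an Embeddings instance.
    raise NotImplementedError


@shared_task
def embedding_by_paragraph(paragraph_id: str, model_id: str) -> None:
    # Only ids cross the process boundary; the heavyweight (and often
    # unpicklable) model object is rebuilt inside the worker.
    model = load_embedding_model(model_id)
    # ... embed the paragraph with `model` and write the vectors back ...

Whatever the transport, the visible design point is that the serializer now passes model_id instead of a live model object, which is what makes an out-of-process hand-off possible.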
Empty file.
25 changes: 17 additions & 8 deletions apps/common/config/embedding_config.py
@@ -6,26 +6,35 @@
 @date:2023/10/23 16:03
 @desc:
 """
+import threading
 import time
 
 from common.cache.mem_cache import MemCache
 
+lock = threading.Lock()
+
 
 class ModelManage:
     cache = MemCache('model', {})
     up_clear_time = time.time()
 
     @staticmethod
     def get_model(_id, get_model):
-        model_instance = ModelManage.cache.get(_id)
-        if model_instance is None or not model_instance.is_cache_model():
-            model_instance = get_model(_id)
-            ModelManage.cache.set(_id, model_instance, timeout=60 * 30)
+        # Acquire the lock
+        lock.acquire()
+        try:
+            model_instance = ModelManage.cache.get(_id)
+            if model_instance is None or not model_instance.is_cache_model():
+                model_instance = get_model(_id)
+                ModelManage.cache.set(_id, model_instance, timeout=60 * 30)
+                return model_instance
-        # Renew the cache entry
-        ModelManage.cache.touch(_id, timeout=60 * 30)
-        ModelManage.clear_timeout_cache()
-        return model_instance
+            # Renew the cache entry
+            ModelManage.cache.touch(_id, timeout=60 * 30)
+            ModelManage.clear_timeout_cache()
+            return model_instance
+        finally:
+            # Release the lock
+            lock.release()
 
     @staticmethod
     def clear_timeout_cache():
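For reference, a caller of the rewritten get_model passes an id plus a loader callback; build_model below is a stand-in for the real factory, which this hunk does not show:

# Hypothetical caller; build_model stands in for the real model factory.
from common.config.embedding_config import ModelManage


class DummyModel:
    def is_cache_model(self):
        return True


def build_model(model_id):
    return DummyModel()


model = ModelManage.get_model('model-id-123', build_model)

Since the lock is module-global, a cache miss for one id also blocks lookups for every other id while get_model(_id) runs; a per-key lock would allow more concurrency at the cost of extra bookkeeping.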
1 change: 0 additions & 1 deletion apps/common/event/__init__.py
@@ -12,7 +12,6 @@
 
 
 def run():
-    listener_manage.ListenerManagement().run()
     QuerySet(Document).filter(status__in=[Status.embedding, Status.queue_up]).update(**{'status': Status.error})
     QuerySet(Model).filter(status=setting.models.Status.DOWNLOAD).update(status=setting.models.Status.ERROR,
                                                                          meta={'message': "Download was interrupted, please retry"})
94 changes: 15 additions & 79 deletions apps/common/event/listener_manage.py
@@ -13,22 +13,20 @@
 from typing import List
 
 import django.db.models
-from blinker import signal
 from django.db.models import QuerySet
 from langchain_core.embeddings import Embeddings
 
 from common.config.embedding_config import VectorStore
 from common.db.search import native_search, get_dynamics_model
-from common.event.common import poxy, embedding_poxy
+from common.event.common import embedding_poxy
 from common.util.file_util import get_file_content
 from common.util.fork import ForkManage, Fork
 from common.util.lock import try_lock, un_lock
 from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping
-from embedding.models import SourceType
+from embedding.models import SourceType, SearchMode
 from smartdoc.conf import PROJECT_DIR
 
-max_kb_error = logging.getLogger("max_kb_error")
-max_kb = logging.getLogger("max_kb")
+max_kb_error = logging.getLogger(__file__)
+max_kb = logging.getLogger(__file__)
 
 
 class SyncWebDatasetArgs:
@@ -70,23 +68,6 @@ def __init__(self, paragraph_id_list: List[str], target_document_id: str, target
 
 
 class ListenerManagement:
-    embedding_by_problem_signal = signal("embedding_by_problem")
-    embedding_by_paragraph_signal = signal("embedding_by_paragraph")
-    embedding_by_dataset_signal = signal("embedding_by_dataset")
-    embedding_by_document_signal = signal("embedding_by_document")
-    delete_embedding_by_document_signal = signal("delete_embedding_by_document")
-    delete_embedding_by_document_list_signal = signal("delete_embedding_by_document_list")
-    delete_embedding_by_dataset_signal = signal("delete_embedding_by_dataset")
-    delete_embedding_by_paragraph_signal = signal("delete_embedding_by_paragraph")
-    delete_embedding_by_source_signal = signal("delete_embedding_by_source")
-    enable_embedding_by_paragraph_signal = signal('enable_embedding_by_paragraph')
-    disable_embedding_by_paragraph_signal = signal('disable_embedding_by_paragraph')
-    init_embedding_model_signal = signal('init_embedding_model')
-    sync_web_dataset_signal = signal('sync_web_dataset')
-    sync_web_document_signal = signal('sync_web_document')
-    update_problem_signal = signal('update_problem')
-    delete_embedding_by_source_ids_signal = signal('delete_embedding_by_source_ids')
-    delete_embedding_by_dataset_id_list_signal = signal("delete_embedding_by_dataset_id_list")
 
     @staticmethod
     def embedding_by_problem(args, embedding_model: Embeddings):
@@ -160,7 +141,6 @@ def is_save_function():
         max_kb.info(f'Finished ---> embedding paragraph: {paragraph_id}')
 
     @staticmethod
-    @embedding_poxy
     def embedding_by_document(document_id, embedding_model: Embeddings):
         """
         Embed the document
@@ -227,7 +207,7 @@ def delete_embedding_by_document(document_id):
 
     @staticmethod
     def delete_embedding_by_document_list(document_id_list: List[str]):
-        VectorStore.get_embedding_vector().delete_bu_document_id_list(document_id_list)
+        VectorStore.get_embedding_vector().delete_by_document_id_list(document_id_list)
 
     @staticmethod
     def delete_embedding_by_dataset(dataset_id):
@@ -249,25 +229,6 @@ def disable_embedding_by_paragraph(paragraph_id):
     def enable_embedding_by_paragraph(paragraph_id):
         VectorStore.get_embedding_vector().update_by_paragraph_id(paragraph_id, {'is_active': True})
 
-    @staticmethod
-    @poxy
-    def sync_web_document(args: SyncWebDocumentArgs):
-        for source_url in args.source_url_list:
-            result = Fork(base_fork_url=source_url, selector_list=args.selector.split(' ')).fork()
-            args.handler(source_url, args.selector, result)
-
-    @staticmethod
-    @poxy
-    def sync_web_dataset(args: SyncWebDatasetArgs):
-        if try_lock('sync_web_dataset' + args.lock_key):
-            try:
-                ForkManage(args.url, args.selector.split(" ") if args.selector is not None else []).fork(2, set(),
-                                                                                                         args.handler)
-            except Exception as e:
-                logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
-            finally:
-                un_lock('sync_web_dataset' + args.lock_key)
-
     @staticmethod
     def update_problem(args: UpdateProblemArgs):
         problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(problem_id=args.problem_id)
@@ -281,6 +242,9 @@ def update_embedding_dataset_id(args: UpdateEmbeddingDatasetIdArgs):
             VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
                                                                        {'dataset_id': args.target_dataset_id})
         else:
+            # Delete the vector data
+            ListenerManagement.delete_embedding_by_paragraph_ids(args.paragraph_id_list)
+            # Re-embed the data
             ListenerManagement.embedding_by_paragraph_list(args.paragraph_id_list,
                                                            embedding_model=args.target_embedding_model)
 
@@ -306,38 +270,10 @@ def delete_embedding_by_paragraph_ids(paragraph_ids: List[str]):
     def delete_embedding_by_dataset_id_list(source_ids: List[str]):
         VectorStore.get_embedding_vector().delete_by_dataset_id_list(source_ids)
 
-    def run(self):
-        # Add vectors by problem id
-        ListenerManagement.embedding_by_problem_signal.connect(self.embedding_by_problem)
-        # Add vectors by paragraph id
-        ListenerManagement.embedding_by_paragraph_signal.connect(self.embedding_by_paragraph)
-        # Add vectors by dataset id
-        ListenerManagement.embedding_by_dataset_signal.connect(
-            self.embedding_by_dataset)
-        # Add vectors by document id
-        ListenerManagement.embedding_by_document_signal.connect(
-            self.embedding_by_document)
-        # Delete vectors by document
-        ListenerManagement.delete_embedding_by_document_signal.connect(self.delete_embedding_by_document)
-        # Delete vectors by document id list
-        ListenerManagement.delete_embedding_by_document_list_signal.connect(self.delete_embedding_by_document_list)
-        # Delete vectors by dataset id
-        ListenerManagement.delete_embedding_by_dataset_signal.connect(self.delete_embedding_by_dataset)
-        # Delete vectors by paragraph id
-        ListenerManagement.delete_embedding_by_paragraph_signal.connect(
-            self.delete_embedding_by_paragraph)
-        # Delete vectors by source id
-        ListenerManagement.delete_embedding_by_source_signal.connect(self.delete_embedding_by_source)
-        # Disable a paragraph
-        ListenerManagement.disable_embedding_by_paragraph_signal.connect(self.disable_embedding_by_paragraph)
-        # Enable paragraph vectors
-        ListenerManagement.enable_embedding_by_paragraph_signal.connect(self.enable_embedding_by_paragraph)
-
-        # Sync a web-site dataset
-        ListenerManagement.sync_web_dataset_signal.connect(self.sync_web_dataset)
-        # Sync a web-site document
-        ListenerManagement.sync_web_document_signal.connect(self.sync_web_document)
-        # Update problem vectors
-        ListenerManagement.update_problem_signal.connect(self.update_problem)
-        ListenerManagement.delete_embedding_by_source_ids_signal.connect(self.delete_embedding_by_source_ids)
-        ListenerManagement.delete_embedding_by_dataset_id_list_signal.connect(self.delete_embedding_by_dataset_id_list)
+    @staticmethod
+    def hit_test(query_text, dataset_id: list[str], exclude_document_id_list: list[str], top_number: int,
+                 similarity: float,
+                 search_mode: SearchMode,
+                 embedding: Embeddings):
+        return VectorStore.get_embedding_vector().hit_test(query_text, dataset_id, exclude_document_id_list, top_number,
+                                                           similarity, search_mode, embedding)
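With the signal indirection gone, callers now invoke ListenerManagement methods directly. A hypothetical hit_test call site (the id, threshold, and model are placeholders, and SearchMode.embedding is assumed to be a member of the imported enum):

# Hypothetical call site; values below are placeholders.
from common.event.listener_manage import ListenerManagement
from embedding.models import SearchMode

records = ListenerManagement.hit_test(
    'how do I reset my password',
    dataset_id=['<dataset-uuid>'],
    exclude_document_id_list=[],
    top_number=5,
    similarity=0.6,
    search_mode=SearchMode.embedding,  # assumed enum member
    embedding=embedding_model,  # any langchain_core Embeddings instance
)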
18 changes: 12 additions & 6 deletions apps/common/job/client_access_num_job.py
@@ -13,9 +13,11 @@
 from django_apscheduler.jobstores import DjangoJobStore
 
 from application.models.api_key_model import ApplicationPublicAccessClient
+from common.lock.impl.file_lock import FileLock
 
 scheduler = BackgroundScheduler()
 scheduler.add_jobstore(DjangoJobStore(), "default")
+lock = FileLock()
 
 
 def client_access_num_reset_job():
@@ -25,9 +27,13 @@ def client_access_num_reset_job():
 
 
 def run():
-    scheduler.start()
-    access_num_reset = scheduler.get_job(job_id='access_num_reset')
-    if access_num_reset is not None:
-        access_num_reset.remove()
-    scheduler.add_job(client_access_num_reset_job, 'cron', hour='0', minute='0', second='0',
-                      id='access_num_reset')
+    if lock.try_lock('client_access_num_reset_job', 30 * 30):
+        try:
+            scheduler.start()
+            access_num_reset = scheduler.get_job(job_id='access_num_reset')
+            if access_num_reset is not None:
+                access_num_reset.remove()
+            scheduler.add_job(client_access_num_reset_job, 'cron', hour='0', minute='0', second='0',
+                              id='access_num_reset')
+        finally:
+            lock.un_lock('client_access_num_reset_job')
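The guard makes the cron registration single-process: 30 * 30 is a 900-second TTL, after which a stale lock (for example, one left by a crashed process) can be taken over. A quick illustration with the FileLock added later in this commit:

# Illustration of the guard above, using the FileLock from this commit.
from common.lock.impl.file_lock import FileLock

lock = FileLock()
print(lock.try_lock('client_access_num_reset_job', 30 * 30))  # True: acquired
print(lock.try_lock('client_access_num_reset_job', 30 * 30))  # False: TTL not expired
lock.un_lock('client_access_num_reset_job')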
20 changes: 20 additions & 0 deletions apps/common/lock/base_lock.py
@@ -0,0 +1,20 @@
+# coding=utf-8
+"""
+@project: MaxKB
+@Author:虎
+@file: base_lock.py
+@date:2024/8/20 10:33
+@desc:
+"""
+
+from abc import ABC, abstractmethod
+
+
+class BaseLock(ABC):
+    @abstractmethod
+    def try_lock(self, key, timeout):
+        pass
+
+    @abstractmethod
+    def un_lock(self, key):
+        pass
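As used in this commit, the contract is small: try_lock returns a boolean without blocking, and un_lock releases by key. That also admits backends other than files; a hypothetical Redis-backed variant (not part of this commit, assuming the redis-py client) could look like:

# Hypothetical alternative backend, not part of this commit.
import redis

from common.lock.base_lock import BaseLock


class RedisLock(BaseLock):
    def __init__(self, client=None):
        self.client = client or redis.Redis()

    def try_lock(self, key, timeout):
        # SET with nx=True succeeds only if the key is absent; ex= gives it
        # a TTL, mirroring FileLock's mtime-based expiry.
        return bool(self.client.set('lock:' + key, '1', nx=True, ex=int(timeout)))

    def un_lock(self, key):
        self.client.delete('lock:' + key)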
77 changes: 77 additions & 0 deletions apps/common/lock/impl/file_lock.py
@@ -0,0 +1,77 @@
+# coding=utf-8
+"""
+@project: MaxKB
+@Author:虎
+@file: file_lock.py
+@date:2024/8/20 10:48
+@desc:
+"""
+import errno
+import hashlib
+import os
+import time
+
+import six
+
+from common.lock.base_lock import BaseLock
+from smartdoc.const import PROJECT_DIR
+
+
+def key_to_lock_name(key):
+    """
+    Combine part of a key with its hash to prevent very long filenames
+    """
+    MAX_LENGTH = 50
+    key_hash = hashlib.md5(six.b(key)).hexdigest()
+    lock_name = key[:MAX_LENGTH - len(key_hash) - 1] + '_' + key_hash
+    return lock_name
+
+
+class FileLock(BaseLock):
+    """
+    File locking backend.
+    """
+
+    def __init__(self, settings=None):
+        if settings is None:
+            settings = {}
+        self.location = settings.get('location')
+        if self.location is None:
+            self.location = os.path.join(PROJECT_DIR, 'data', 'lock')
+        try:
+            os.makedirs(self.location)
+        except OSError as error:
+            # Directory exists?
+            if error.errno != errno.EEXIST:
+                # Re-raise unexpected OSError
+                raise
+
+    def _get_lock_path(self, key):
+        lock_name = key_to_lock_name(key)
+        return os.path.join(self.location, lock_name)
+
+    def try_lock(self, key, timeout):
+        lock_path = self._get_lock_path(key)
+        try:
+            # Create the lock file; if creation fails, the lock is held elsewhere
+            fd = os.open(lock_path, os.O_CREAT | os.O_EXCL)
+        except OSError as error:
+            if error.errno == errno.EEXIST:
+                # File already exists, check its modification time
+                mtime = os.path.getmtime(lock_path)
+                ttl = mtime + timeout - time.time()
+                if ttl > 0:
+                    return False
+                else:
+                    # The timeout has expired: take the lock over and continue
+                    os.utime(lock_path, None)
+                    return True
+            else:
+                return False
+        else:
+            os.close(fd)
+            return True
+
+    def un_lock(self, key):
+        lock_path = self._get_lock_path(key)
+        os.remove(lock_path)
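Usage follows the pattern in client_access_num_job.py above; the key and the 15-minute TTL here are illustrative:

# Usage sketch; key and timeout are illustrative.
from common.lock.impl.file_lock import FileLock

lock = FileLock()
if lock.try_lock('nightly_cleanup', 60 * 15):
    try:
        pass  # work that must run in at most one process at a time
    finally:
        lock.un_lock('nightly_cleanup')

Because the lock is a file under PROJECT_DIR/data/lock, it only coordinates processes that share that directory. Note also that un_lock removes the file unconditionally, so a holder that overruns its TTL can delete a lock another process has since taken over.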
Empty file.
Empty file.