Skip to content

Commit

Permalink
feat: 模型管理支持向量模型,知识库可以关联向量模型
Browse files Browse the repository at this point in the history
feat:  模型管理支持向量模型,知识库可以关联向量模型
  • Loading branch information
wangdan-fit2cloud authored Jul 19, 2024
2 parents bee11ee + 43fc376 commit d3d09b1
Show file tree
Hide file tree
Showing 74 changed files with 1,562 additions and 544 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class InstanceSerializer(serializers.Serializer):
validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"),
message="类型只支持register|reset_password", code=500)
], error_messages=ErrMessage.char("检索模式"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))

def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]:
return self.InstanceSerializer
Expand All @@ -56,6 +57,7 @@ def _run(self, manage: PipelineManage):
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
search_mode: str = None,
user_id=None,
**kwargs) -> List[ParagraphPipelineModel]:
"""
关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询
Expand All @@ -67,6 +69,7 @@ def execute(self, problem_text: str, dataset_id_list: list[str], exclude_documen
:param exclude_paragraph_id_list: 需要排除段落id
:param padding_problem_text 补全问题
:param search_mode 检索模式
:param user_id 用户id
:return: 段落列表
"""
pass
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,48 @@

from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep
from common.config.embedding_config import VectorStore, EmbeddingModel
from common.config.embedding_config import VectorStore, ModelManage
from common.db.search import native_search
from common.util.file_util import get_file_content
from dataset.models import Paragraph
from dataset.models import Paragraph, DataSet
from embedding.models import SearchMode
from setting.models import Model
from setting.models_provider import get_model
from smartdoc.conf import PROJECT_DIR


def get_model_by_id(_id, user_id):
    """Fetch a model record by id, enforcing private-model ownership.

    :param _id: primary key of the model to load
    :param user_id: id of the requesting user (compared as a string)
    :return: the matching ``Model`` instance
    :raises Exception: when no model exists, or the model is private and
        owned by a different user
    """
    found = QuerySet(Model).filter(id=_id).first()
    if found is None:
        raise Exception("模型不存在")
    is_private = found.permission_type == 'PRIVATE'
    if is_private and str(user_id) != str(found.user_id):
        raise Exception(f"无权限使用此模型:{found.name}")
    return found


def get_embedding_id(dataset_id_list):
    """Return the embedding model id shared by the given datasets.

    All datasets taking part in a search must reference the same embedding
    model, otherwise the query vector and the stored vectors would come
    from incompatible vector spaces.

    :param dataset_id_list: ids of the datasets to search
    :return: the common ``embedding_mode_id``
    :raises Exception: when no dataset is found, or the datasets disagree
        on the embedding model
    """
    # Materialize once: the original indexed the lazy QuerySet after len(),
    # causing repeated evaluation/queries.
    dataset_list = list(QuerySet(DataSet).filter(id__in=dataset_id_list))
    # Guard the empty case first, before inspecting the model ids.
    if len(dataset_list) == 0:
        raise Exception("知识库设置错误,请重新设置知识库")
    embedding_ids = {dataset.embedding_mode_id for dataset in dataset_list}
    if len(embedding_ids) > 1:
        # Fixed garbled message text ("知识库未..." -> "知识库的...").
        raise Exception("知识库的向量模型不一致")
    return dataset_list[0].embedding_mode_id


class BaseSearchDatasetStep(ISearchDatasetStep):

def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
search_mode: str = None,
user_id=None,
**kwargs) -> List[ParagraphPipelineModel]:
if len(dataset_id_list) == 0:
return []
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
embedding_model = EmbeddingModel.get_embedding_model()
model_id = get_embedding_id(dataset_id_list)
model = get_model_by_id(model_id, user_id)
self.context['model_name'] = model.name
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
embedding_value = embedding_model.embed_query(exec_problem_text)
vector = VectorStore.get_embedding_vector()
embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list,
Expand Down Expand Up @@ -101,7 +127,7 @@ def get_details(self, manage, **kwargs):
'run_time': self.context['run_time'],
'problem_text': step_args.get(
'padding_problem_text') if 'padding_problem_text' in step_args else step_args.get('problem_text'),
'model_name': EmbeddingModel.get_embedding_model().model_name,
'model_name': self.context.get('model_name'),
'message_tokens': 0,
'answer_tokens': 0,
'cost': 0
Expand Down
1 change: 1 addition & 0 deletions apps/application/flow/i_step_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ class FlowParamsSerializer(serializers.Serializer):

client_type = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端类型"))

user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("用户id"))
re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("换个答案"))


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,50 @@

from application.flow.i_step_node import NodeResult
from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode
from common.config.embedding_config import EmbeddingModel, VectorStore
from common.config.embedding_config import VectorStore, ModelManage
from common.db.search import native_search
from common.util.file_util import get_file_content
from dataset.models import Document, Paragraph
from dataset.models import Document, Paragraph, DataSet
from embedding.models import SearchMode
from setting.models import Model
from setting.models_provider import get_model
from smartdoc.conf import PROJECT_DIR


def get_model_by_id(_id, user_id):
    """Load a model by primary key and verify the caller may use it.

    Private models are only usable by their owner; non-private models are
    usable by anyone.

    :param _id: model primary key
    :param user_id: id of the requesting user
    :return: the ``Model`` instance
    :raises Exception: missing model, or private model owned by another user
    """
    record = QuerySet(Model).filter(id=_id).first()
    if record is None:
        raise Exception("模型不存在")
    if record.permission_type == 'PRIVATE':
        if str(record.user_id) != str(user_id):
            raise Exception(f"无权限使用此模型:{record.name}")
    return record


def get_embedding_id(dataset_id_list):
    """Return the single embedding model id used by the given datasets.

    The search step embeds the question with one model, so every dataset
    in the query must be bound to the same embedding model.

    :param dataset_id_list: ids of the datasets to search
    :return: the shared ``embedding_mode_id``
    :raises Exception: when no dataset matches, or the datasets reference
        different embedding models
    """
    # Evaluate the QuerySet once instead of re-querying on each access.
    dataset_list = list(QuerySet(DataSet).filter(id__in=dataset_id_list))
    # Check emptiness before the consistency check.
    if len(dataset_list) == 0:
        raise Exception("知识库设置错误,请重新设置知识库")
    distinct_ids = {dataset.embedding_mode_id for dataset in dataset_list}
    if len(distinct_ids) > 1:
        # Fixed garbled message text ("知识库未..." -> "知识库的...").
        raise Exception("知识库的向量模型不一致")
    return dataset_list[0].embedding_mode_id


def get_none_result(question):
    """Build an empty ``NodeResult`` for *question* when no paragraphs can be retrieved."""
    empty_payload = {
        'paragraph_list': [],
        'is_hit_handling_method': [],
        'question': question,
        'data': '',
        'directly_return': '',
    }
    return NodeResult(empty_payload, {})


class BaseSearchDatasetNode(ISearchDatasetStepNode):
def execute(self, dataset_id_list, dataset_setting, question,
exclude_paragraph_id_list=None,
**kwargs) -> NodeResult:
self.context['question'] = question
embedding_model = EmbeddingModel.get_embedding_model()
if len(dataset_id_list) == 0:
return get_none_result(question)
model_id = get_embedding_id(dataset_id_list)
model = get_model_by_id(model_id, self.flow_params_serializer.data.get('user_id'))
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
embedding_value = embedding_model.embed_query(question)
vector = VectorStore.get_embedding_vector()
exclude_document_id_list = [str(document.id) for document in
Expand All @@ -37,7 +67,7 @@ def execute(self, dataset_id_list, dataset_setting, question,
exclude_paragraph_id_list, True, dataset_setting.get('top_n'),
dataset_setting.get('similarity'), SearchMode(dataset_setting.get('search_mode')))
if embedding_list is None:
return NodeResult({'paragraph_list': [], 'is_hit_handling_method': []}, {})
return get_none_result(question)
paragraph_list = self.list_paragraph(embedding_list, vector)
result = [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list]
result = sorted(result, key=lambda p: p.get('similarity'), reverse=True)
Expand Down
19 changes: 19 additions & 0 deletions apps/application/migrations/0010_alter_chatrecord_details.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Generated by Django 4.2.13 on 2024-07-15 15:52

import application.models.application
from django.db import migrations, models


class Migration(migrations.Migration):
    """Alter ``ChatRecord.details`` to a JSONField using the project's DateEncoder.

    The custom encoder lets datetime values embedded in the stored chat
    details serialize cleanly.
    """

    dependencies = [
        ('application', '0009_application_type_application_work_flow_and_more'),
    ]

    operations = [
        migrations.AlterField(
            model_name='chatrecord',
            name='details',
            field=models.JSONField(default=dict, encoder=application.models.application.DateEncoder, verbose_name='对话详情'),
        ),
    ]
13 changes: 8 additions & 5 deletions apps/application/serializers/application_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from application.flow.workflow_manage import Flow
from application.models import Application, ApplicationDatasetMapping, ApplicationTypeChoices, WorkFlowVersion
from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey
from common.config.embedding_config import VectorStore, EmbeddingModel
from common.config.embedding_config import VectorStore
from common.constants.authentication_type import AuthenticationType
from common.db.search import get_dynamics_model, native_search, native_page_search
from common.db.sql_execute import select_list
Expand All @@ -36,7 +36,7 @@
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from dataset.models import DataSet, Document, Image
from dataset.serializers.common_serializers import list_paragraph
from dataset.serializers.common_serializers import list_paragraph, get_embedding_model_by_dataset_id_list
from embedding.models import SearchMode
from setting.models import AuthOperate
from setting.models.model_management import Model
Expand Down Expand Up @@ -415,12 +415,13 @@ def hit_test(self):
QuerySet(Document).filter(
dataset_id__in=dataset_id_list,
is_active=False)]
model = get_embedding_model_by_dataset_id_list(dataset_id_list)
# 向量库检索
hit_list = vector.hit_test(self.data.get('query_text'), dataset_id_list, exclude_document_id_list,
self.data.get('top_number'),
self.data.get('similarity'),
SearchMode(self.data.get('search_mode')),
EmbeddingModel.get_embedding_model())
model)
hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])
return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'),
Expand Down Expand Up @@ -522,12 +523,14 @@ def is_valid(self, *, raise_exception=False):
if not QuerySet(Application).filter(id=self.data.get('application_id')).exists():
raise AppApiException(500, '不存在的应用id')

def list_model(self, model_type=None, with_valid=True):
    """List the models available to this application's owner.

    :param model_type: optional model-type filter; defaults to "LLM"
        when not supplied
    :param with_valid: when True, validate this serializer's data first
    :return: the serialized model list for the application's owner
    """
    if with_valid:
        self.is_valid()
    effective_type = "LLM" if model_type is None else model_type
    app = QuerySet(Application).filter(id=self.data.get("application_id")).first()
    query_data = {'user_id': app.user_id, 'model_type': effective_type}
    return ModelSerializer.Query(data=query_data).list(with_valid=True)

def delete(self, with_valid=True):
Expand Down
8 changes: 6 additions & 2 deletions apps/application/serializers/chat_message_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ def to_base_pipeline_manage_params(self):
'no_references_setting': self.application.dataset_setting.get(
'no_references_setting') if 'no_references_setting' in self.application.dataset_setting else {
'status': 'ai_questioning',
'value': '{question}'}
'value': '{question}',
},
'user_id': self.application.user_id

}

Expand Down Expand Up @@ -221,11 +223,13 @@ def chat_work_flow(self, chat_info: ChatInfo):
stream = self.data.get('stream')
client_id = self.data.get('client_id')
client_type = self.data.get('client_type')
user_id = chat_info.application.user_id
work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.work_flow_version.work_flow),
{'history_chat_record': chat_info.chat_record_list, 'question': message,
'chat_id': chat_info.chat_id, 'chat_record_id': str(uuid.uuid1()),
'stream': stream,
're_chat': re_chat}, WorkFlowPostHandler(chat_info, client_id, client_type))
're_chat': re_chat,
'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type))
r = work_flow_manage.run()
return r

Expand Down
24 changes: 13 additions & 11 deletions apps/application/serializers/chat_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from application.serializers.application_serializers import ModelDatasetAssociation, DatasetSettingSerializer, \
ModelSettingSerializer
from application.serializers.chat_message_serializers import ChatInfo
from common.config.embedding_config import ModelManage
from common.constants.permission_constants import RoleConstants
from common.db.search import native_search, native_page_search, page_search, get_dynamics_model
from common.event import ListenerManagement
Expand All @@ -39,8 +40,10 @@
from common.util.lock import try_lock, un_lock
from common.util.rsa_util import rsa_long_decrypt
from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id
from dataset.serializers.paragraph_serializers import ParagraphSerializers
from setting.models import Model
from setting.models_provider import get_model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from smartdoc.conf import PROJECT_DIR

Expand Down Expand Up @@ -241,12 +244,7 @@ def open_simple(self, application):
application_id=application_id)]
chat_model = None
if model is not None:
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
rsa_long_decrypt(
model.credential)),
streaming=True)

chat_model = ModelManage.get_model(str(model.id), lambda _id: get_model(model))
chat_id = str(uuid.uuid1())
chat_cache.set(chat_id,
ChatInfo(chat_id, chat_model, dataset_id_list,
Expand All @@ -259,6 +257,7 @@ def open_simple(self, application):

class OpenWorkFlowChat(serializers.Serializer):
work_flow = WorkFlowSerializers(error_messages=ErrMessage.uuid("工作流"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))

def open(self):
self.is_valid(raise_exception=True)
Expand All @@ -269,7 +268,8 @@ def open(self):
dataset_setting={},
model_setting={},
problem_optimization=None,
type=ApplicationTypeChoices.WORK_FLOW
type=ApplicationTypeChoices.WORK_FLOW,
user_id=self.data.get('user_id')
)
work_flow_version = WorkFlowVersion(work_flow=work_flow)
chat_cache.set(chat_id,
Expand Down Expand Up @@ -332,7 +332,8 @@ def open(self):
application = Application(id=None, dialogue_number=3, model=model,
dataset_setting=self.data.get('dataset_setting'),
model_setting=self.data.get('model_setting'),
problem_optimization=self.data.get('problem_optimization'))
problem_optimization=self.data.get('problem_optimization'),
user_id=user_id)
chat_cache.set(chat_id,
ChatInfo(chat_id, chat_model, dataset_id_list,
[str(document.id) for document in
Expand Down Expand Up @@ -533,9 +534,10 @@ def is_valid(self, *, raise_exception=False):
raise AppApiException(500, "文档id不正确")

@staticmethod
def post_embedding_paragraph(chat_record, paragraph_id):
def post_embedding_paragraph(chat_record, paragraph_id, dataset_id):
model = get_embedding_model_by_dataset_id(dataset_id)
# 发送向量化事件
ListenerManagement.embedding_by_paragraph_signal.send(paragraph_id)
ListenerManagement.embedding_by_paragraph_signal.send(paragraph_id, embedding_model=model)
return chat_record

@post(post_function=post_embedding_paragraph)
Expand Down Expand Up @@ -573,7 +575,7 @@ def improve(self, instance: Dict, with_valid=True):
chat_record.improve_paragraph_id_list.append(paragraph.id)
# 添加标注
chat_record.save()
return ChatRecordSerializerModel(chat_record).data, paragraph.id
return ChatRecordSerializerModel(chat_record).data, paragraph.id, dataset_id

class Operate(serializers.Serializer):
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))
Expand Down
14 changes: 14 additions & 0 deletions apps/application/swagger_api/application_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,20 @@ def get_response_body_api():
}
)

class Model(ApiMixin):
    """Swagger parameter definitions for the application model-list endpoint."""

    @staticmethod
    def get_request_params_api():
        """Describe the request parameters: required path app id plus optional model type."""
        application_id = openapi.Parameter(
            name='application_id',
            in_=openapi.IN_PATH,
            type=openapi.TYPE_STRING,
            required=True,
            description='应用id')
        model_type = openapi.Parameter(
            name='model_type',
            in_=openapi.IN_QUERY,
            type=openapi.TYPE_STRING,
            required=False,
            description='模型类型')
        return [application_id, model_type]

class ApiKey(ApiMixin):
@staticmethod
def get_request_params_api():
Expand Down
4 changes: 2 additions & 2 deletions apps/application/views/application_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ class Model(APIView):
@swagger_auto_schema(operation_summary="获取模型列表",
operation_id="获取模型列表",
tags=["应用"],
manual_parameters=ApplicationApi.ApiKey.get_request_params_api())
manual_parameters=ApplicationApi.Model.get_request_params_api())
@has_permissions(ViewPermission(
[RoleConstants.ADMIN, RoleConstants.USER],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
Expand All @@ -185,7 +185,7 @@ def get(self, request: Request, application_id: str):
return result.success(
ApplicationSerializer.Operate(
data={'application_id': application_id,
'user_id': request.user.id}).list_model())
'user_id': request.user.id}).list_model(request.query_params.get('model_type')))

class Profile(APIView):
authentication_classes = [TokenAuth]
Expand Down
4 changes: 4 additions & 0 deletions apps/common/cache/mem_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,7 @@ def clear_by_application_id(self, application_id):
delete_keys.append(key)
for key in delete_keys:
self._delete(key)

def clear_timeout_data(self):
    """Evict expired entries from the cache.

    ``self.get`` drops an entry when it has timed out, which mutates the
    underlying dict. Iterating ``self._cache.keys()`` directly while that
    happens raises ``RuntimeError: dictionary changed size during
    iteration`` — so iterate over a snapshot of the keys instead.
    """
    for key in list(self._cache.keys()):
        self.get(key)
Loading

0 comments on commit d3d09b1

Please sign in to comment.