Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: 知识库支持上传csv和excel #1083

Merged
merged 1 commit into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions apps/common/handle/base_parse_table_handle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# coding=utf-8
"""
@project: maxkb
@Author:虎
@file: base_parse_qa_handle.py
@date:2024/5/21 14:56
@desc:
"""
from abc import ABC, abstractmethod


class BaseParseTableHandle(ABC):
@abstractmethod
def support(self, file, get_buffer):
pass

@abstractmethod
def handle(self, file, get_buffer):
pass
34 changes: 34 additions & 0 deletions apps/common/handle/impl/table/csv_parse_table_handle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# coding=utf-8
import logging

from charset_normalizer import detect

from common.handle.base_parse_table_handle import BaseParseTableHandle

max_kb = logging.getLogger("max_kb")


class CsvSplitHandle(BaseParseTableHandle):
def support(self, file, get_buffer):
file_name: str = file.name.lower()
if file_name.endswith(".csv"):
return True
return False

def handle(self, file, get_buffer):
buffer = get_buffer(file)
try:
content = buffer.decode(detect(buffer)['encoding'])
except BaseException as e:
max_kb.error(f'csv split handle error: {e}')
return [{'name': file.name, 'paragraphs': []}]

csv_model = content.split('\n')
paragraphs = []
# 第一行为标题
title = csv_model[0].split(',')
for row in csv_model[1:]:
line = '; '.join([f'{key}:{value}' for key, value in zip(title, row.split(','))])
paragraphs.append({'title': '', 'content': line})

return [{'name': file.name, 'paragraphs': paragraphs}]
49 changes: 49 additions & 0 deletions apps/common/handle/impl/table/excel_parse_table_handle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# coding=utf-8
import io
import logging

from openpyxl import load_workbook

from common.handle.base_parse_table_handle import BaseParseTableHandle

max_kb = logging.getLogger("max_kb")


class ExcelSplitHandle(BaseParseTableHandle):
def support(self, file, get_buffer):
file_name: str = file.name.lower()
if file_name.endswith('.xls') or file_name.endswith('.xlsx'):
return True
return False

def handle(self, file, get_buffer):
buffer = get_buffer(file)
try:
wb = load_workbook(io.BytesIO(buffer))
result = []
for sheetname in wb.sheetnames:
paragraphs = []
ws = wb[sheetname]
rows = list(ws.rows)
if not rows: continue
ti = list(rows[0])
for r in list(rows[1:]):
title = []
l = []
for i, c in enumerate(r):
if not c.value:
continue
t = str(ti[i].value) if i < len(ti) else ""
title.append(t)
t += (": " if t else "") + str(c.value)
l.append(t)
l = "; ".join(l)
if sheetname.lower().find("sheet") < 0:
l += " ——" + sheetname
paragraphs.append({'title': '', 'content': l})
result.append({'name': sheetname, 'paragraphs': paragraphs})

except BaseException as e:
max_kb.error(f'excel split handle error: {e}')
return [{'name': file.name, 'paragraphs': []}]
return result
43 changes: 43 additions & 0 deletions apps/dataset/serializers/document_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
from common.handle.impl.qa.csv_parse_qa_handle import CsvParseQAHandle
from common.handle.impl.qa.xls_parse_qa_handle import XlsParseQAHandle
from common.handle.impl.qa.xlsx_parse_qa_handle import XlsxParseQAHandle
from common.handle.impl.table.csv_parse_table_handle import CsvSplitHandle
from common.handle.impl.table.excel_parse_table_handle import ExcelSplitHandle
from common.handle.impl.text_split_handle import TextSplitHandle
from common.mixins.api_mixin import ApiMixin
from common.util.common import post, flat_map
Expand All @@ -51,6 +53,7 @@
from smartdoc.conf import PROJECT_DIR

parse_qa_handle_list = [XlsParseQAHandle(), CsvParseQAHandle(), XlsxParseQAHandle()]
parse_table_handle_list = [CsvSplitHandle(), ExcelSplitHandle()]


class FileBufferHandle:
Expand Down Expand Up @@ -152,6 +155,13 @@ class DocumentInstanceQASerializer(ApiMixin, serializers.Serializer):
error_messages=ErrMessage.file("文件")))


class DocumentInstanceTableSerializer(ApiMixin, serializers.Serializer):
file_list = serializers.ListSerializer(required=True,
error_messages=ErrMessage.list("文件列表"),
child=serializers.FileField(required=True,
error_messages=ErrMessage.file("文件")))


class DocumentSerializers(ApiMixin, serializers.Serializer):
class Export(ApiMixin, serializers.Serializer):
type = serializers.CharField(required=True, validators=[
Expand Down Expand Up @@ -187,6 +197,23 @@ def export(self, with_valid=True):
return HttpResponse(content, status=200, headers={'Content-Type': 'application/vnd.ms-excel',
'Content-Disposition': 'attachment; filename="excel_template.xlsx"'})

def table_export(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)

if self.data.get('type') == 'csv':
file = open(os.path.join(PROJECT_DIR, "apps", "dataset", 'template', 'MaxKB表格模板.csv'), "rb")
content = file.read()
file.close()
return HttpResponse(content, status=200, headers={'Content-Type': 'text/cxv',
'Content-Disposition': 'attachment; filename="csv_template.csv"'})
elif self.data.get('type') == 'excel':
file = open(os.path.join(PROJECT_DIR, "apps", "dataset", 'template', 'MaxKB表格模板.xlsx'), "rb")
content = file.read()
file.close()
return HttpResponse(content, status=200, headers={'Content-Type': 'application/vnd.ms-excel',
'Content-Disposition': 'attachment; filename="excel_template.xlsx"'})

class Migrate(ApiMixin, serializers.Serializer):
dataset_id = serializers.UUIDField(required=True,
error_messages=ErrMessage.char(
Expand Down Expand Up @@ -633,6 +660,14 @@ def parse_qa_file(file):
return parse_qa_handle.handle(file, get_buffer)
raise AppApiException(500, '不支持的文件格式')

@staticmethod
def parse_table_file(file):
get_buffer = FileBufferHandle().get_buffer
for parse_table_handle in parse_table_handle_list:
if parse_table_handle.support(file, get_buffer):
return parse_table_handle.handle(file, get_buffer)
raise AppApiException(500, '不支持的文件格式')

def save_qa(self, instance: Dict, with_valid=True):
if with_valid:
DocumentInstanceQASerializer(data=instance).is_valid(raise_exception=True)
Expand All @@ -641,6 +676,14 @@ def save_qa(self, instance: Dict, with_valid=True):
document_list = flat_map([self.parse_qa_file(file) for file in file_list])
return DocumentSerializers.Batch(data={'dataset_id': self.data.get('dataset_id')}).batch_save(document_list)

def save_table(self, instance: Dict, with_valid=True):
if with_valid:
DocumentInstanceTableSerializer(data=instance).is_valid(raise_exception=True)
self.is_valid(raise_exception=True)
file_list = instance.get('file_list')
document_list = flat_map([self.parse_table_file(file) for file in file_list])
return DocumentSerializers.Batch(data={'dataset_id': self.data.get('dataset_id')}).batch_save(document_list)

@post(post_function=post_embedding)
@transaction.atomic
def save(self, instance: Dict, with_valid=False, **kwargs):
Expand Down
13 changes: 13 additions & 0 deletions apps/dataset/template/MaxKB表格模板.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
职务,报销类型,一线城市报销标准(元),二线城市报销标准(元),三线城市报销标准(元)
普通员工,住宿费,500,400,300
部门主管,住宿费,600,500,400
部门总监,住宿费,700,600,500
区域总经理,住宿费,800,700,600
普通员工,伙食费,50,40,30
部门主管,伙食费,50,40,30
部门总监,伙食费,50,40,30
区域总经理,伙食费,50,40,30
普通员工,交通费,50,40,30
部门主管,交通费,50,40,30
部门总监,交通费,50,40,30
区域总经理,交通费,50,40,30
Binary file added apps/dataset/template/MaxKB表格模板.xlsx
Binary file not shown.
2 changes: 2 additions & 0 deletions apps/dataset/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
path('dataset/<str:dataset_id>/hit_test', views.Dataset.HitTest.as_view()),
path('dataset/<str:dataset_id>/document', views.Document.as_view(), name='document'),
path('dataset/document/template/export', views.Template.as_view()),
path('dataset/document/table_template/export', views.TableTemplate.as_view()),
path('dataset/<str:dataset_id>/document/web', views.WebDocument.as_view()),
path('dataset/<str:dataset_id>/document/qa', views.QaDocument.as_view()),
path('dataset/<str:dataset_id>/document/table', views.TableDocument.as_view()),
path('dataset/<str:dataset_id>/document/_bach', views.Document.Batch.as_view()),
path('dataset/<str:dataset_id>/document/batch_hit_handling', views.Document.BatchEditHitHandling.as_view()),
path('dataset/<str:dataset_id>/document/<int:current_page>/<int:page_size>', views.Document.Page.as_view()),
Expand Down
29 changes: 29 additions & 0 deletions apps/dataset/views/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,17 @@ class Template(APIView):
def get(self, request: Request):
return DocumentSerializers.Export(data={'type': request.query_params.get('type')}).export(with_valid=True)

class TableTemplate(APIView):
authentication_classes = [TokenAuth]

@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取表格模版",
operation_id="获取表格模版",
manual_parameters=DocumentSerializers.Export.get_request_params_api(),
tags=["知识库/文档"])
def get(self, request: Request):
return DocumentSerializers.Export(data={'type': request.query_params.get('type')}).table_export(with_valid=True)


class WebDocument(APIView):
authentication_classes = [TokenAuth]
Expand Down Expand Up @@ -71,6 +82,24 @@ def post(self, request: Request, dataset_id: str):
{'file_list': request.FILES.getlist('file')},
with_valid=True))

class TableDocument(APIView):
authentication_classes = [TokenAuth]
parser_classes = [MultiPartParser]

@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="导入表格并创建文档",
operation_id="导入表格并创建文档",
manual_parameters=DocumentWebInstanceSerializer.get_request_params_api(),
responses=result.get_api_response(DocumentSerializers.Create.get_response_body_api()),
tags=["知识库/文档"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('dataset_id')))
def post(self, request: Request, dataset_id: str):
return result.success(
DocumentSerializers.Create(data={'dataset_id': dataset_id}).save_table(
{'file_list': request.FILES.getlist('file')},
with_valid=True))

class Document(APIView):
authentication_classes = [TokenAuth]
Expand Down
27 changes: 27 additions & 0 deletions ui/src/api/document.ts
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,19 @@ const postQADocument: (
return post(`${prefix}/${dataset_id}/document/qa`, data, undefined, loading)
}

/**
* 导入表格
* @param 参数
* file
*/
const postTableDocument: (
dataset_id: string,
data: any,
loading?: Ref<boolean>
) => Promise<Result<any>> = (dataset_id, data, loading) => {
return post(`${prefix}/${dataset_id}/document/table`, data, undefined, loading)
}

/**
* 批量迁移文档
* @param 参数 dataset_id,target_dataset_id,
Expand Down Expand Up @@ -256,6 +269,18 @@ const exportQATemplate: (fileName: string, type: string, loading?: Ref<boolean>)
return exportExcel(fileName, `${prefix}/document/template/export`, { type }, loading)
}

/**
* 获得table模版
* @param 参数 fileName,type,
*/
const exportTableTemplate: (fileName: string, type: string, loading?: Ref<boolean>) => void = (
fileName,
type,
loading
) => {
return exportExcel(fileName, `${prefix}/document/table_template/export`, { type }, loading)
}

/**
* 导出文档
* @param document_name 文档名称
Expand Down Expand Up @@ -295,6 +320,8 @@ export default {
putMigrateMulDocument,
batchEditHitHandling,
exportQATemplate,
exportTableTemplate,
postQADocument,
postTableDocument,
exportDocument
}
1 change: 1 addition & 0 deletions ui/src/utils/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ export function fileType(name: string) {
*/
const typeList: any = {
txt: ['txt', 'pdf', 'docx', 'csv', 'md', 'html', 'PDF'],
table: ['xlsx', 'xls', 'csv'],
QA: ['xlsx', 'csv', 'xls']
}

Expand Down
15 changes: 15 additions & 0 deletions ui/src/views/dataset/UploadDocumentDataset.vue
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,21 @@ async function next() {
router.push({ path: `/dataset/${id}/document` })
})
}
} else if (documentsType.value === 'table') {
let fd = new FormData()
documentsFiles.value.forEach((item: any) => {
if (item?.raw) {
fd.append('file', item?.raw)
}
})
if (id) {
// table文档上传
documentApi.postTableDocument(id as string, fd, loading).then((res) => {
MsgSuccess('提交成功')
clearStore()
router.push({ path: `/dataset/${id}/document` })
})
}
} else {
if (active.value++ > 2) active.value = 0
}
Expand Down
Loading
Loading