Skip to content

Commit

Permalink
feat: 语音识别和语音合成
Browse files Browse the repository at this point in the history
  • Loading branch information
liuruibin committed Sep 3, 2024
1 parent da6c1ce commit 3941367
Show file tree
Hide file tree
Showing 17 changed files with 1,172 additions and 8 deletions.
2 changes: 2 additions & 0 deletions apps/setting/models_provider/base_model_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ def encryption(message: str):
class ModelTypeConst(Enum):
    """Kinds of model a provider can expose.

    Each member's value is a dict holding a machine-readable ``code`` and a
    human-readable ``message`` label.
    """
    LLM = dict(code='LLM', message='大语言模型')
    EMBEDDING = dict(code='EMBEDDING', message='向量模型')
    STT = dict(code='STT', message='语音识别')
    TTS = dict(code='TTS', message='语音合成')


class ModelInfo:
Expand Down
14 changes: 14 additions & 0 deletions apps/setting/models_provider/impl/base_stt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# coding=utf-8
from abc import abstractmethod

from pydantic import BaseModel


class BaseSpeechToText(BaseModel):
    """Abstract contract for speech-to-text model implementations."""

    @abstractmethod
    def check_auth(self):
        """Verify the configured credentials; implemented by subclasses."""
        return None

    @abstractmethod
    def speech_to_text(self, audio_file):
        """Transcribe ``audio_file``; implemented by subclasses."""
        return None
14 changes: 14 additions & 0 deletions apps/setting/models_provider/impl/base_tts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# coding=utf-8
from abc import abstractmethod

from pydantic import BaseModel


class BaseTextToSpeech(BaseModel):
    """Abstract contract for text-to-speech model implementations."""

    @abstractmethod
    def check_auth(self):
        """Verify the configured credentials; implemented by subclasses."""
        return None

    @abstractmethod
    def text_to_speech(self, text):
        """Synthesize speech for ``text``; implemented by subclasses."""
        return None
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# coding=utf-8
from typing import Dict

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode


class OpenAISTTModelCredential(BaseForm, BaseModelCredential):
    """Credential form (API base URL + API key) for OpenAI speech models."""
    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate the credential by building the model and calling ``check_auth``.

        :param model_type: model type value looked up in the provider's type list
        :param model_name: name of the model to instantiate
        :param model_credential: credential values; must contain ``api_base`` and ``api_key``
        :param provider: object offering ``get_model_type_list`` and ``get_model``
        :param raise_exception: raise ``AppApiException`` instead of returning ``False``
        :return: ``True`` when the check succeeds, ``False`` on failure when
                 ``raise_exception`` is falsy
        :raises AppApiException: always for an unsupported model type; otherwise
                 only when ``raise_exception`` is truthy, or when the check itself
                 raised an ``AppApiException``
        """
        matching_types = [mt for mt in provider.get_model_type_list() if mt.get('value') == model_type]
        if not any(matching_types):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        # Report the first required key that is absent, in declaration order.
        missing_key = next((name for name in ('api_base', 'api_key') if name not in model_credential), None)
        if missing_key is not None:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'{missing_key} 字段为必填字段')

        try:
            provider.get_model(model_type, model_name, model_credential).check_auth()
        except AppApiException:
            # Application-level errors propagate untouched.
            raise
        except Exception as err:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(err)}')
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` with ``api_key`` passed through ``encryption``."""
        masked_key = super().encryption(model.get('api_key', ''))
        return {**model, 'api_key': masked_key}

    def get_model_params_setting_form(self, model_name):
        """No parameter-setting form is provided; returns ``None``."""
        return None
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import Dict

from openai import OpenAI

from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_stt import BaseSpeechToText


def custom_get_token_ids(text: str):
    """Encode ``text`` with the tokenizer obtained from ``TokenizerManage``."""
    return TokenizerManage.get_tokenizer().encode(text)


class OpenAISpeechToText(MaxKBBaseModel, BaseSpeechToText):
    """Speech-to-text model that talks to an OpenAI-compatible API."""
    openai_api_base: str
    openai_api_key: str
    # Name of the remote model (e.g. ``whisper-1``) used for transcription.
    model: str

    def __init__(self, **kwargs):
        """Build the model from ``api_base``, ``api_key`` and ``model`` keyword arguments.

        Callers pass ``api_base``/``api_key`` while the declared fields are
        ``openai_api_base``/``openai_api_key``. The values are mapped onto the
        field names before the base initializer runs, so the required fields
        are actually populated.
        """
        kwargs.setdefault('openai_api_base', kwargs.get('api_base'))
        kwargs.setdefault('openai_api_key', kwargs.get('api_key'))
        super().__init__(**kwargs)

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create an instance for ``model_name`` from a credential dict.

        :param model_type: accepted for interface compatibility; not used here
        :param model_name: remote model name, stored on the instance as ``model``
        :param model_credential: dict providing ``api_base`` and ``api_key``
        :param model_kwargs: accepted for interface compatibility; this class
               declares no fields for them, so they are not forwarded
        """
        return OpenAISpeechToText(
            model=model_name,
            api_base=model_credential.get('api_base'),
            api_key=model_credential.get('api_key'),
        )

    def _new_client(self):
        """Return an ``OpenAI`` client configured with this model's base URL and key."""
        return OpenAI(
            base_url=self.openai_api_base,
            api_key=self.openai_api_key
        )

    def check_auth(self):
        """Check the credentials by listing models; any client error propagates."""
        self._new_client().models.with_raw_response.list()

    def speech_to_text(self, audio_file):
        """Transcribe ``audio_file`` and return the client's transcription result.

        The transcription language is fixed to ``"zh"``.
        """
        return self._new_client().audio.transcriptions.create(
            model=self.model,
            language="zh",
            file=audio_file
        )
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from typing import Dict

from openai import OpenAI

from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_tts import BaseTextToSpeech


def custom_get_token_ids(text: str):
    """Encode ``text`` with the tokenizer obtained from ``TokenizerManage``."""
    return TokenizerManage.get_tokenizer().encode(text)


class OpenAITextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
    """Text-to-speech model that talks to an OpenAI-compatible API."""
    openai_api_base: str
    openai_api_key: str
    # Name of the remote model (e.g. ``tts-1``) used for synthesis.
    model: str

    def __init__(self, **kwargs):
        """Build the model from ``api_base``, ``api_key`` and ``model`` keyword arguments.

        Callers pass ``api_base``/``api_key`` while the declared fields are
        ``openai_api_base``/``openai_api_key``. The values are mapped onto the
        field names before the base initializer runs, so the required fields
        are actually populated.
        """
        kwargs.setdefault('openai_api_base', kwargs.get('api_base'))
        kwargs.setdefault('openai_api_key', kwargs.get('api_key'))
        super().__init__(**kwargs)

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create an instance for ``model_name`` from a credential dict.

        :param model_type: accepted for interface compatibility; not used here
        :param model_name: remote model name, stored on the instance as ``model``
        :param model_credential: dict providing ``api_base`` and ``api_key``
        :param model_kwargs: accepted for interface compatibility; this class
               declares no fields for them, so they are not forwarded
        """
        return OpenAITextToSpeech(
            model=model_name,
            api_base=model_credential.get('api_base'),
            api_key=model_credential.get('api_key'),
        )

    def _new_client(self):
        """Return an ``OpenAI`` client configured with this model's base URL and key."""
        return OpenAI(
            base_url=self.openai_api_base,
            api_key=self.openai_api_key
        )

    def check_auth(self):
        """Check the credentials by listing models; any client error propagates."""
        self._new_client().models.with_raw_response.list()

    def text_to_speech(self, text):
        """Synthesize ``text`` and return the audio read from the response body.

        The voice is fixed to ``"alloy"``. The caller's ``text`` is sent as the
        input, and the response is read fully inside the ``with`` block, before
        the streaming response is closed.
        """
        with self._new_client().audio.speech.with_streaming_response.create(
                model=self.model,
                voice="alloy",
                input=text,
        ) as response:
            return response.read()
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,15 @@
ModelTypeConst, ModelInfoManage
from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
from setting.models_provider.impl.openai_model_provider.credential.stt import OpenAISTTModelCredential
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
from setting.models_provider.impl.openai_model_provider.model.stt import OpenAISpeechToText
from setting.models_provider.impl.openai_model_provider.model.tts import OpenAITextToSpeech
from smartdoc.conf import PROJECT_DIR

openai_llm_model_credential = OpenAILLMModelCredential()
openai_stt_model_credential = OpenAISTTModelCredential()
model_info_list = [
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
openai_llm_model_credential, OpenAIChatModel
Expand Down Expand Up @@ -58,7 +62,13 @@
OpenAIChatModel),
ModelInfo('gpt-4-1106-preview', '2023年11月6日的gpt-4-turbo快照,支持上下文长度128,000 tokens',
ModelTypeConst.LLM, openai_llm_model_credential,
OpenAIChatModel)
OpenAIChatModel),
ModelInfo('whisper-1', '',
ModelTypeConst.STT, openai_stt_model_credential,
OpenAISpeechToText),
ModelInfo('tts-1', '',
ModelTypeConst.TTS, openai_stt_model_credential,
OpenAITextToSpeech)
]
open_ai_embedding_credential = OpenAIEmbeddingCredential()
model_info_embedding_list = [
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# coding=utf-8

from typing import Dict

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode


class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential):
    """Credential form (API URL, app id, token, cluster) for Volcanic Engine speech models."""
    volcanic_api_url = forms.TextInputField('API 域名', required=True)
    volcanic_app_id = forms.TextInputField('App ID', required=True)
    volcanic_token = forms.PasswordInputField('Token', required=True)
    volcanic_cluster = forms.TextInputField('Cluster', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate the credential by building the model and calling ``check_auth``.

        :param model_type: model type value looked up in the provider's type list
        :param model_name: name of the model to instantiate
        :param model_credential: credential values; must contain all four ``volcanic_*`` keys
        :param provider: object offering ``get_model_type_list`` and ``get_model``
        :param raise_exception: raise ``AppApiException`` instead of returning ``False``
        :return: ``True`` when the check succeeds, ``False`` on failure when
                 ``raise_exception`` is falsy
        :raises AppApiException: always for an unsupported model type; otherwise
                 only when ``raise_exception`` is truthy, or when the check itself
                 raised an ``AppApiException``
        """
        matching_types = [mt for mt in provider.get_model_type_list() if mt.get('value') == model_type]
        if not any(matching_types):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        required_keys = ('volcanic_api_url', 'volcanic_app_id', 'volcanic_token', 'volcanic_cluster')
        # Report the first required key that is absent, in declaration order.
        missing_key = next((name for name in required_keys if name not in model_credential), None)
        if missing_key is not None:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'{missing_key} 字段为必填字段')

        try:
            provider.get_model(model_type, model_name, model_credential).check_auth()
        except AppApiException:
            # Application-level errors propagate untouched.
            raise
        except Exception as err:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(err)}')
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` with ``volcanic_token`` passed through ``encryption``."""
        masked_token = super().encryption(model.get('volcanic_token', ''))
        return {**model, 'volcanic_token': masked_token}

    def get_model_params_setting_form(self, model_name):
        """No parameter-setting form is provided; returns ``None``."""
        return None
Loading

0 comments on commit 3941367

Please sign in to comment.