Skip to content

Commit

Permalink
feat(llm): drop default_system_prompt (#1385)
Browse files Browse the repository at this point in the history
As discussed on Discord, the decision has been made to remove the system prompts by default, to better segregate the API and the UI usages.

A concurrent PR (#1353) is enabling the dynamic setting of a system prompt in the UI.

Therefore, if UI users want to use a custom system prompt, they can specify one directly in the UI.
If the API users want to use a custom prompt, they can pass it directly into their messages that they are passing to the API.

In the highlight of the two use case above, it becomes clear that default system_prompt does not need to exist.
  • Loading branch information
lopagela authored Dec 8, 2023
1 parent f235c50 commit a3ed14c
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 96 deletions.
5 changes: 1 addition & 4 deletions private_gpt/components/llm/llm_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,7 @@ def __init__(self, settings: Settings) -> None:
case "local":
from llama_index.llms import LlamaCPP

prompt_style_cls = get_prompt_style(settings.local.prompt_style)
prompt_style = prompt_style_cls(
default_system_prompt=settings.local.default_system_prompt
)
prompt_style = get_prompt_style(settings.local.prompt_style)

self.llm = LlamaCPP(
model_path=str(models_path / settings.local.llm_hf_model_file),
Expand Down
65 changes: 13 additions & 52 deletions private_gpt/components/llm/prompt_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from llama_index.llms import ChatMessage, MessageRole
from llama_index.llms.llama_utils import (
DEFAULT_SYSTEM_PROMPT,
completion_to_prompt,
messages_to_prompt,
)
Expand All @@ -29,7 +28,6 @@ class AbstractPromptStyle(abc.ABC):
series of messages into a prompt.
"""

@abc.abstractmethod
def __init__(self, *args: Any, **kwargs: Any) -> None:
logger.debug("Initializing prompt_style=%s", self.__class__.__name__)

Expand All @@ -52,15 +50,6 @@ def completion_to_prompt(self, completion: str) -> str:
return prompt


class AbstractPromptStyleWithSystemPrompt(AbstractPromptStyle, abc.ABC):
_DEFAULT_SYSTEM_PROMPT = DEFAULT_SYSTEM_PROMPT

def __init__(self, default_system_prompt: str | None) -> None:
super().__init__()
logger.debug("Got default_system_prompt='%s'", default_system_prompt)
self.default_system_prompt = default_system_prompt


class DefaultPromptStyle(AbstractPromptStyle):
"""Default prompt style that uses the defaults from llama_utils.
Expand All @@ -83,7 +72,7 @@ def _completion_to_prompt(self, completion: str) -> str:
return ""


class Llama2PromptStyle(AbstractPromptStyleWithSystemPrompt):
class Llama2PromptStyle(AbstractPromptStyle):
"""Simple prompt style that just uses the default llama_utils functions.
It transforms the sequence of messages into a prompt that should look like:
Expand All @@ -94,18 +83,14 @@ class Llama2PromptStyle(AbstractPromptStyleWithSystemPrompt):
```
"""

def __init__(self, default_system_prompt: str | None = None) -> None:
# If no system prompt is given, the default one of the implementation is used.
super().__init__(default_system_prompt=default_system_prompt)

def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
return messages_to_prompt(messages, self.default_system_prompt)
return messages_to_prompt(messages)

def _completion_to_prompt(self, completion: str) -> str:
return completion_to_prompt(completion, self.default_system_prompt)
return completion_to_prompt(completion)


class TagPromptStyle(AbstractPromptStyleWithSystemPrompt):
class TagPromptStyle(AbstractPromptStyle):
"""Tag prompt style (used by Vigogne) that uses the prompt style `<|ROLE|>`.
It transforms the sequence of messages into a prompt that should look like:
Expand All @@ -119,37 +104,8 @@ class TagPromptStyle(AbstractPromptStyleWithSystemPrompt):
FIXME: should we add surrounding `<s>` and `</s>` tags, like in llama2?
"""

def __init__(self, default_system_prompt: str | None = None) -> None:
# We have to define a default system prompt here as the LLM will not
# use the default llama_utils functions.
default_system_prompt = default_system_prompt or self._DEFAULT_SYSTEM_PROMPT
super().__init__(default_system_prompt)
self.system_prompt: str = default_system_prompt

def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
messages = list(messages)
if messages[0].role != MessageRole.SYSTEM:
logger.info(
"Adding system_promt='%s' to the given messages as there are none given in the session",
self.system_prompt,
)
messages = [
ChatMessage(content=self.system_prompt, role=MessageRole.SYSTEM),
*messages,
]
return self._format_messages_to_prompt(messages)

def _completion_to_prompt(self, completion: str) -> str:
return (
f"<|system|>: {self.system_prompt.strip()}\n"
f"<|user|>: {completion.strip()}\n"
"<|assistant|>: "
)

@staticmethod
def _format_messages_to_prompt(messages: list[ChatMessage]) -> str:
"""Format message to prompt with `<|ROLE|>: MSG` style."""
assert messages[0].role == MessageRole.SYSTEM
prompt = ""
for message in messages:
role = message.role
Expand All @@ -161,19 +117,24 @@ def _format_messages_to_prompt(messages: list[ChatMessage]) -> str:
prompt += "<|assistant|>: "
return prompt

def _completion_to_prompt(self, completion: str) -> str:
return self._messages_to_prompt(
[ChatMessage(content=completion, role=MessageRole.USER)]
)


def get_prompt_style(
prompt_style: Literal["default", "llama2", "tag"] | None
) -> type[AbstractPromptStyle]:
) -> AbstractPromptStyle:
"""Get the prompt style to use from the given string.
:param prompt_style: The prompt style to use.
:return: The prompt style to use.
"""
if prompt_style is None or prompt_style == "default":
return DefaultPromptStyle
return DefaultPromptStyle()
elif prompt_style == "llama2":
return Llama2PromptStyle
return Llama2PromptStyle()
elif prompt_style == "tag":
return TagPromptStyle
return TagPromptStyle()
raise ValueError(f"Unknown prompt_style='{prompt_style}'")
9 changes: 0 additions & 9 deletions private_gpt/settings/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,6 @@ class LocalSettings(BaseModel):
"`llama2` is the historic behaviour. `default` might work better with your custom models."
),
)
default_system_prompt: str | None = Field(
None,
description=(
"The default system prompt to use for the chat engine. "
"If none is given - use the default system prompt (from the llama_index). "
"Please note that the default prompt might not be the same for all prompt styles. "
"Also note that this is only used if the first message is not a system message. "
),
)


class EmbeddingSettings(BaseModel):
Expand Down
34 changes: 3 additions & 31 deletions tests/test_prompt_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
],
)
def test_get_prompt_style_success(prompt_style, expected_prompt_style):
assert get_prompt_style(prompt_style) == expected_prompt_style
assert isinstance(get_prompt_style(prompt_style), expected_prompt_style)


def test_get_prompt_style_failure():
Expand All @@ -45,20 +45,7 @@ def test_tag_prompt_style_format():


def test_tag_prompt_style_format_with_system_prompt():
system_prompt = "This is a system prompt from configuration."
prompt_style = TagPromptStyle(default_system_prompt=system_prompt)
messages = [
ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
]

expected_prompt = (
f"<|system|>: {system_prompt}\n"
"<|user|>: Hello, how are you doing?\n"
"<|assistant|>: "
)

assert prompt_style.messages_to_prompt(messages) == expected_prompt

prompt_style = TagPromptStyle()
messages = [
ChatMessage(
content="FOO BAR Custom sys prompt from messages.", role=MessageRole.SYSTEM
Expand Down Expand Up @@ -94,22 +81,7 @@ def test_llama2_prompt_style_format():


def test_llama2_prompt_style_with_system_prompt():
system_prompt = "This is a system prompt from configuration."
prompt_style = Llama2PromptStyle(default_system_prompt=system_prompt)
messages = [
ChatMessage(content="Hello, how are you doing?", role=MessageRole.USER),
]

expected_prompt = (
"<s> [INST] <<SYS>>\n"
f" {system_prompt} \n"
"<</SYS>>\n"
"\n"
" Hello, how are you doing? [/INST]"
)

assert prompt_style.messages_to_prompt(messages) == expected_prompt

prompt_style = Llama2PromptStyle()
messages = [
ChatMessage(
content="FOO BAR Custom sys prompt from messages.", role=MessageRole.SYSTEM
Expand Down

0 comments on commit a3ed14c

Please sign in to comment.