
Commit

Merge pull request #1590 from h2oai/cleanup_stopping
Clean up stopping to avoid hard-coded handling for llama-3, since it was fixed 11 days ago.
pseudotensor authored Apr 30, 2024
2 parents 832ad2d + 917d612 commit 39b2a1c
Showing 6 changed files with 63 additions and 46 deletions.
15 changes: 8 additions & 7 deletions src/gen.py
@@ -98,8 +98,9 @@

from prompter import Prompter, inv_prompt_type_to_model_lower, non_hf_types, PromptType, get_prompt, generate_prompt, \
openai_gpts, get_vllm_extra_dict, anthropic_gpts, google_gpts, mistralai_gpts, groq_gpts, \
-gradio_to_llm, history_for_llm, is_gradio_vision_model, is_json_model, get_use_chat_template, apply_chat_template
+gradio_to_llm, history_for_llm, is_gradio_vision_model, is_json_model, apply_chat_template
from stopping import get_stopping
+from prompter_utils import get_use_chat_template

langchain_actions = [x.value for x in list(LangChainAction)]

@@ -3263,12 +3264,12 @@ def get_model(
if base_model in non_hf_types:
from gpt4all_llm import get_model_tokenizer_gpt4all
model, tokenizer_llamacpp, device = get_model_tokenizer_gpt4all(base_model,
-n_jobs=n_jobs,
-gpu_id=gpu_id,
-n_gpus=n_gpus,
-max_seq_len=max_seq_len,
-llamacpp_dict=llamacpp_dict,
-llamacpp_path=llamacpp_path)
+n_jobs=n_jobs,
+gpu_id=gpu_id,
+n_gpus=n_gpus,
+max_seq_len=max_seq_len,
+llamacpp_dict=llamacpp_dict,
+llamacpp_path=llamacpp_path)
# give chance to use tokenizer_base_model
if tokenizer is None:
tokenizer = tokenizer_llamacpp
2 changes: 1 addition & 1 deletion src/gpt_langchain.py
@@ -7980,7 +7980,7 @@ def get_chain(query=None,
context=context,
iinput=iinput,
system_prompt=system_prompt)
-if external_handle_chat_conversation or prompter.prompt_type in ['template', 'unknown']:
+if external_handle_chat_conversation or prompter.prompt_type in [template_prompt_type, unknown_prompt_type]:
# should already have attribute, checking sanity
assert hasattr(llm, 'chat_conversation')
llm_kwargs.update(chat_conversation=history_to_use_final)
38 changes: 10 additions & 28 deletions src/prompter.py
@@ -1,10 +1,13 @@
import ast
import time
+import os

# also supports imports from this file from other files
from enums import PromptType, gpt_token_mapping, \
anthropic_mapping, google_mapping, mistralai_mapping, groq_mapping, openai_supports_json_mode, noop_prompt_type, \
-unknown_prompt_type, template_prompt_type, user_prompt_for_fake_system_prompt0
+unknown_prompt_type, user_prompt_for_fake_system_prompt0, template_prompt_type
+from src.prompter_utils import get_use_chat_template
+from src.stopping import update_terminate_responses
from src.utils import get_gradio_tmp

non_hf_types = ['gpt4all_llama', 'llama', 'gptj']
@@ -1676,18 +1679,6 @@ def inject_chatsep(prompt_type, prompt, chat_sep=None):
return prompt


-def get_use_chat_template(tokenizer, prompt_type=None):
-    if tokenizer is None:
-        return False
-    use_chat_template = prompt_type in [None, '', unknown_prompt_type, template_prompt_type] and \
-                        (hasattr(tokenizer, 'chat_template') and
-                         tokenizer.chat_template not in [None, ''] or
-                         hasattr(tokenizer, 'default_chat_template') and
-                         tokenizer.default_chat_template not in [None, '']
-                         )
-    return use_chat_template


class Prompter(object):
def __init__(self, prompt_type, prompt_dict, debug=False, stream_output=False, repeat_penalty=False,
allowed_repeat_line_length=10, system_prompt=None, tokenizer=None, verbose=False):
@@ -1709,20 +1700,11 @@ def __init__(self, prompt_type, prompt_dict, debug=False, stream_output=False, r
system_prompt=system_prompt)
self.use_chat_template = False
self.tokenizer = tokenizer
-if tokenizer is not None:
-    self.use_chat_template = get_use_chat_template(tokenizer, prompt_type=prompt_type)
-    if self.use_chat_template:
-        # add terminations
-        if self.terminate_response is None:
-            self.terminate_response = []
-        # like in stopping.py
-        if hasattr(tokenizer, 'eos_token') and tokenizer.eos_token:
-            self.terminate_response.extend([tokenizer.eos_token])
-        if '<|eot_id|>' in tokenizer.added_tokens_encoder:
-            self.terminate_response.extend(['<|eot_id|>'])
-        if '<|im_end|>' in tokenizer.added_tokens_encoder:
-            self.terminate_response.extend(['<|im_end|>'])
-
+if self.terminate_response is None:
+    self.terminate_response = []
+self.use_chat_template = get_use_chat_template(tokenizer, prompt_type=prompt_type)
+self.terminate_response = update_terminate_responses(self.terminate_response,
+                                                      tokenizer=tokenizer)
self.pre_response = self.PreResponse
self.verbose = verbose

@@ -1743,7 +1725,7 @@ def generate_prompt(self, data_point, reduced=False, context_from_history=None,
In which case we need to put promptA at very front to recover correct behavior
:return:
"""
-if self.prompt_type in ['template', 'unknown']:
+if self.prompt_type in [template_prompt_type, unknown_prompt_type]:
assert self.use_chat_template
assert self.tokenizer is not None
from src.gen import apply_chat_template
13 changes: 13 additions & 0 deletions src/prompter_utils.py
@@ -0,0 +1,13 @@
from src.enums import unknown_prompt_type, template_prompt_type


def get_use_chat_template(tokenizer, prompt_type=None):
    if tokenizer is None:
        return False
    use_chat_template = prompt_type in [None, '', unknown_prompt_type, template_prompt_type] and \
                        (hasattr(tokenizer, 'chat_template') and
                         tokenizer.chat_template not in [None, ''] or
                         hasattr(tokenizer, 'default_chat_template') and
                         tokenizer.default_chat_template not in [None, '']
                         )
    return use_chat_template
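
As a rough orientation for the relocated helper, a minimal usage sketch follows; the model name and the expected return values are illustrative assumptions, not part of this commit:

# Minimal sketch, assuming `transformers` is installed and the chosen checkpoint
# ships a tokenizer-level chat_template (the model name here is illustrative).
from transformers import AutoTokenizer

from src.prompter_utils import get_use_chat_template

tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

# With prompt_type unset (or one of the unknown/template sentinels), the helper
# reports True whenever the tokenizer defines a chat template.
print(get_use_chat_template(tok, prompt_type=None))   # expected True for such a tokenizer
print(get_use_chat_template(None, prompt_type=None))  # always False without a tokenizer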
39 changes: 30 additions & 9 deletions src/stopping.py
@@ -1,9 +1,37 @@
import time

import torch
-from transformers import StoppingCriteria, StoppingCriteriaList
+from transformers import StoppingCriteria, StoppingCriteriaList, GenerationConfig

from enums import PromptType, t5_type
+from src.prompter_utils import get_use_chat_template


+def update_terminate_responses(terminate_response, tokenizer=None):
+    if terminate_response is None:
+        terminate_response = []
+    if tokenizer is not None:
+        # e.g. for dbrx
+        if '<|im_end|>' in tokenizer.added_tokens_encoder:
+            terminate_response.extend(['<|im_end|>'])
+        if hasattr(tokenizer, 'eos_token') and tokenizer.eos_token:
+            if isinstance(tokenizer.eos_token, str):
+                terminate_response.extend([tokenizer.eos_token])
+            elif isinstance(tokenizer.eos_token, list):
+                terminate_response.extend(tokenizer.eos_token)
+
+        if hasattr(tokenizer, 'name_or_path'):
+            reverse_vocab = {v: k for k, v in tokenizer.vocab.items()}
+            generate_eos_token_id = GenerationConfig.from_pretrained(tokenizer.name_or_path).eos_token_id
+            if isinstance(generate_eos_token_id, list):
+                for eos_token_id in generate_eos_token_id:
+                    terminate_response.extend([reverse_vocab[eos_token_id]])
+            else:
+                terminate_response.extend([reverse_vocab[generate_eos_token_id]])
+    terminate_response_tmp = terminate_response.copy()
+    terminate_response.clear()
+    [terminate_response.append(x) for x in terminate_response_tmp if x not in terminate_response]
+    return terminate_response


class StoppingCriteriaSub(StoppingCriteria):
@@ -132,14 +160,7 @@ def get_stopping(prompt_type, prompt_dict, tokenizer, device, base_model,
encounters += [1] * len(stop)
handle_newlines += [False] * len(stop)

-# e.g. for llama-3
-# https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct
-if '<|eot_id|>' in tokenizer.added_tokens_encoder:
-    stop_words.extend(['<|eot_id|>'])
-if '<|im_end|>' in tokenizer.added_tokens_encoder:
-    stop_words.extend(['<|im_end|>'])
-if hasattr(tokenizer, 'eos_token') and tokenizer.eos_token:
-    stop_words.extend([tokenizer.eos_token])
+stop_words = update_terminate_responses(stop_words, tokenizer=tokenizer)

# get stop tokens
stop_words_ids = [
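For orientation, a hedged usage sketch of the consolidated stop-string helper; the model name and the printed stop strings are assumptions for illustration, since the exact values depend on the checkpoint's tokenizer and GenerationConfig:

# Sketch under assumptions: requires `transformers` and a checkpoint whose
# GenerationConfig lists one or more eos_token_id values (as llama-3 instruct does).
from transformers import AutoTokenizer

from src.stopping import update_terminate_responses

tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

# Caller-provided stop strings stay first; the helper appends the tokenizer's
# eos_token plus every eos_token_id named in GenerationConfig, then
# de-duplicates while preserving order.
stops = update_terminate_responses(["\nHuman:"], tokenizer=tok)
print(stops)  # e.g. ['\nHuman:', '<|eot_id|>', '<|end_of_text|>']
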
2 changes: 1 addition & 1 deletion src/version.py
@@ -1 +1 @@
__version__ = "f957748399a19591580f0f99ee13b85f99e3f9fb"
__version__ = "832ad2d4a6b1431105785045a6b218a8451591f9"
