Give default context to help chatbot #100

Merged: 1 commit, May 1, 2023
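In short: a new chat_context flag flows from main() into evaluate(), and when it is enabled together with prompt_type='human_bot', a hidden default persona context is used whenever the caller supplies no explicit context. A minimal sketch of that fallback logic (get_context appears verbatim in the generate.py diff below; resolve_context is a hypothetical shorthand for the inline branch the PR adds to evaluate()):

def get_context(chat_context, prompt_type):
    # built-in h2oGPT persona for human_bot chats, else '' (full text in the diff below)
    if chat_context and prompt_type == 'human_bot':
        return "<bot>: I am an intelligent, helpful, truthful, and fair assistant named h2oGPT..."
    return ''

def resolve_context(context, chat_context, prompt_type):
    # hypothetical wrapper mirroring the `if not context:` branch added to
    # evaluate(): an explicit context always wins; the default only fills the gap
    if not context:
        context = get_context(chat_context, prompt_type)
    return context

print(resolve_context('', chat_context=True, prompt_type='human_bot'))  # hidden persona
print(resolve_context('my own context', True, 'human_bot'))             # left untouched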
2 changes: 1 addition & 1 deletion finetune.py
@@ -869,7 +869,7 @@ def generate_prompt(data_point, prompt_type, chat, reduced):
     assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
     promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response = get_prompt(prompt_type, chat, context, reduced)

-    prompt = context
+    prompt = context if not reduced else ''

     if input and promptA:
         prompt += f"""{promptA}"""
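Why this one-liner matters: as I read the diff, reduced marks a prompt built for a follow-up chat turn whose history has been trimmed, so re-prepending the context there would duplicate the hidden persona on every turn. A toy illustration (dialog text hypothetical; the real generate_prompt assembles many more pieces):

# Toy demo of `prompt = context if not reduced else ''`
context = "<bot>: I am h2oGPT, a helpful assistant.\n"  # hypothetical persona
for reduced in (False, True):
    prompt = context if not reduced else ''
    prompt += "<human>: What is 2 + 2?\n<bot>:"
    print(repr(prompt))
# the first (non-reduced) prompt carries the context; reduced follow-up
# turns start straight at the dialog, so the persona is not repeated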
32 changes: 25 additions & 7 deletions generate.py
@@ -65,6 +65,7 @@ def main(
         gradio_avoid_processing_markdown: bool = False,
         chat: bool = True,
         chat_history: int = 4096,  # character length of chat context/history
+        chat_context: bool = False,  # use default context if human_bot
         stream_output: bool = True,
         show_examples: bool = None,
         verbose: bool = False,
@@ -182,7 +183,7 @@ def main(
             assert not chat, "No gradio must use chat=False, uses nochat instruct"
             examplenew[eval_func_param_names.index('instruction_nochat')] = instruction
             examplenew[eval_func_param_names.index('iinput_nochat')] = ''  # no input
-            examplenew[eval_func_param_names.index('context')] = ''  # no context
+            examplenew[eval_func_param_names.index('context')] = get_context(chat_context, prompt_type)
             examples.append(examplenew)
             responses.append(output)

@@ -217,7 +218,8 @@ def main(
         model, tokenizer, device = get_model(**locals())
         model_state = [model, tokenizer, device, base_model]
         fun = partial(evaluate, model_state, debug=debug, save_dir=save_dir, is_low_mem=is_low_mem,
-                      raise_generate_gpu_exceptions=raise_generate_gpu_exceptions)
+                      raise_generate_gpu_exceptions=raise_generate_gpu_exceptions,
+                      chat_context=chat_context)
     else:
         assert eval_sharegpt_prompts_only > 0

@@ -256,7 +258,8 @@ def get_response(*args, exi=0):
         if eval_sharegpt_prompts_only > 0:
             # only our own examples have this filled at moment
             assert iinput in [None, ''], iinput  # should be no iinput
-            assert context in [None, ''], context  # should be no context
+            if not (chat_context and prompt_type == 'human_bot'):
+                assert context in [None, ''], context  # should be no context
         prompt = instruction
         cutoff_len = 768 if is_low_mem else 2048
         inputs = stokenizer(prompt, res,
@@ -625,10 +628,12 @@ def evaluate(
         model_state0=None,
         is_low_mem=None,
         raise_generate_gpu_exceptions=None,
+        chat_context=None,
 ):
     # ensure passed these
     assert is_low_mem is not None
     assert raise_generate_gpu_exceptions is not None
+    assert chat_context is not None

     if debug:
         locals_dict = locals().copy()
@@ -670,6 +675,10 @@ def evaluate(
         instruction = instruction_nochat
         iinput = iinput_nochat

+    if not context:
+        # get hidden context if have one
+        context = get_context(chat_context, prompt_type)
+
     data_point = dict(context=context, instruction=instruction, input=iinput)
     prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
     prompt = prompter.generate_prompt(data_point)
@@ -976,17 +985,17 @@ def get_generate_params(model_lower, chat,
         num_return_sequences = min(num_beams, num_return_sequences or 1)
         do_sample = False if do_sample is None else do_sample
     else:
-        temperature = 0.1 if temperature is None else temperature
-        top_p = 0.75 if top_p is None else top_p
-        top_k = 40 if top_k is None else top_k
+        temperature = 0.4 if temperature is None else temperature
+        top_p = 0.85 if top_p is None else top_p
+        top_k = 70 if top_k is None else top_k
         if chat:
             num_beams = num_beams or 1
         else:
             num_beams = num_beams or 4
         max_new_tokens = max_new_tokens or 256
         repetition_penalty = repetition_penalty or 1.07
         num_return_sequences = min(num_beams, num_return_sequences or 1)
-        do_sample = False if do_sample is None else do_sample
+        do_sample = True if do_sample is None else do_sample
     # doesn't include chat, instruction_nochat, iinput_nochat, added later
     params_list = ["", stream_output, prompt_type, temperature, top_p, top_k, num_beams, max_new_tokens, min_new_tokens,
                    early_stopping, max_time, repetition_penalty, num_return_sequences, do_sample]
@@ -1062,6 +1071,15 @@ def languages_covered():
     return covered


+def get_context(chat_context, prompt_type):
+    if chat_context and prompt_type == 'human_bot':
+        context0 = """<bot>: I am an intelligent, helpful, truthful, and fair assistant named h2oGPT, who will give accurate, balanced, and reliable responses. I will not respond with I don't know or I don't understand.
+<human>: I am a human person seeking useful assistance and request all questions be answered completely, and typically expect detailed responses. Give answers in numbered list format if several distinct but related items are being listed."""
+    else:
+        context0 = ''
+    return context0
+
+
 def test_test_prompt(prompt_type='instruct', data_point=0):
     example_data_point = example_data_points[data_point]
     example_data_point.pop('output', None)
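Since get_context's body appears in full above, its contract is easy to pin down with a quick usage check. The same file also relaxes the non-beam chat sampling defaults (temperature 0.1 to 0.4, top_p 0.75 to 0.85, top_k 40 to 70, do_sample False to True); the dict below merely gathers those values in one place and is not how the real code assigns them:

# Quick checks against the get_context shown above: the hidden context
# applies only when the flag is on AND the prompt type is human_bot.
assert get_context(False, 'human_bot') == ''
assert get_context(True, 'instruct') == ''
assert get_context(True, 'human_bot').startswith('<bot>: I am an intelligent')

# Hypothetical summary of the relaxed sampling defaults from the same diff:
sampling_defaults = dict(temperature=0.4, top_p=0.85, top_k=70,
                         do_sample=True, repetition_penalty=1.07,
                         max_new_tokens=256)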
4 changes: 2 additions & 2 deletions gradio_runner.py
@@ -274,7 +274,7 @@ def _postprocess_chat_messages(self, chat_message: str):
                                    visible=not is_public)
             context = gr.Textbox(lines=3, label="System Pre-Context",
                                  info="Directly pre-appended without prompt processing",
-                                 visible=not is_public and not kwargs['chat'])
+                                 visible=not is_public)
             chat = gr.components.Checkbox(label="Chat mode", value=kwargs['chat'],
                                           visible=not is_public)

@@ -862,7 +862,7 @@ def get_system_info():

 input_args_list = ['model_state']
 inputs_kwargs_list = ['debug', 'save_dir', 'hard_stop_list', 'sanitize_bot_response', 'model_state0', 'is_low_mem',
-                      'raise_generate_gpu_exceptions']
+                      'raise_generate_gpu_exceptions', 'chat_context']


 def get_inputs_list(inputs_dict, model_lower):
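The gradio side is small: the System Pre-Context textbox was previously hidden whenever chat mode was on; now only is_public controls it, so chat users can inspect or override the hidden context, and the second hunk registers 'chat_context' among the kwargs forwarded to evaluate(). A standalone sketch of the visibility change (hypothetical demo; the real box lives inside h2oGPT's much larger Blocks layout):

import gradio as gr

is_public = False  # hypothetical stand-in for the kwargs-driven flag
with gr.Blocks() as demo:
    # after this PR, visibility no longer depends on chat mode
    context = gr.Textbox(lines=3, label="System Pre-Context",
                         info="Directly pre-appended without prompt processing",
                         visible=not is_public)

if __name__ == '__main__':
    demo.launch()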