From a61596a46a16ed7f5cace48a6c93e139ad52c20c Mon Sep 17 00:00:00 2001
From: "Jonathan C. McKinney"
Date: Mon, 1 May 2023 02:03:43 -0700
Subject: [PATCH] Give default context to help chatbot

---
 finetune.py      |  2 +-
 generate.py      | 32 +++++++++++++++++++++++++-------
 gradio_runner.py |  4 ++--
 3 files changed, 28 insertions(+), 10 deletions(-)

diff --git a/finetune.py b/finetune.py
index 3b0ae0ed6..f921cd346 100644
--- a/finetune.py
+++ b/finetune.py
@@ -869,7 +869,7 @@ def generate_prompt(data_point, prompt_type, chat, reduced):
     assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
     promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response = get_prompt(prompt_type, chat, context, reduced)
 
-    prompt = context
+    prompt = context if not reduced else ''
 
     if input and promptA:
         prompt += f"""{promptA}"""
diff --git a/generate.py b/generate.py
index cd1acaf9c..44ee59a12 100644
--- a/generate.py
+++ b/generate.py
@@ -65,6 +65,7 @@ def main(
         gradio_avoid_processing_markdown: bool = False,
         chat: bool = True,
         chat_history: int = 4096,  # character length of chat context/history
+        chat_context: bool = False,  # use default context if human_bot
         stream_output: bool = True,
         show_examples: bool = None,
         verbose: bool = False,
@@ -182,7 +183,7 @@ def main(
                 assert not chat, "No gradio must use chat=False, uses nochat instruct"
                 examplenew[eval_func_param_names.index('instruction_nochat')] = instruction
                 examplenew[eval_func_param_names.index('iinput_nochat')] = ''  # no input
-                examplenew[eval_func_param_names.index('context')] = ''  # no context
+                examplenew[eval_func_param_names.index('context')] = get_context(chat_context, prompt_type)
                 examples.append(examplenew)
                 responses.append(output)
 
@@ -217,7 +218,8 @@ def main(
         model, tokenizer, device = get_model(**locals())
         model_state = [model, tokenizer, device, base_model]
         fun = partial(evaluate, model_state, debug=debug, save_dir=save_dir, is_low_mem=is_low_mem,
-                      raise_generate_gpu_exceptions=raise_generate_gpu_exceptions)
+                      raise_generate_gpu_exceptions=raise_generate_gpu_exceptions,
+                      chat_context=chat_context)
     else:
         assert eval_sharegpt_prompts_only > 0
 
@@ -256,7 +258,8 @@ def get_response(*args, exi=0):
         if eval_sharegpt_prompts_only > 0:
             # only our own examples have this filled at moment
             assert iinput in [None, ''], iinput  # should be no iinput
-            assert context in [None, ''], context  # should be no context
+            if not (chat_context and prompt_type == 'human_bot'):
+                assert context in [None, ''], context  # should be no context
             prompt = instruction
         cutoff_len = 768 if is_low_mem else 2048
         inputs = stokenizer(prompt, res,
@@ -625,10 +628,12 @@ def evaluate(
         model_state0=None,
         is_low_mem=None,
         raise_generate_gpu_exceptions=None,
+        chat_context=None,
 ):
     # ensure passed these
     assert is_low_mem is not None
     assert raise_generate_gpu_exceptions is not None
+    assert chat_context is not None
 
     if debug:
         locals_dict = locals().copy()
@@ -670,6 +675,10 @@ def evaluate(
         instruction = instruction_nochat
         iinput = iinput_nochat
 
+    if not context:
+        # get hidden context if have one
+        context = get_context(chat_context, prompt_type)
+
     data_point = dict(context=context, instruction=instruction, input=iinput)
     prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
     prompt = prompter.generate_prompt(data_point)
@@ -976,9 +985,9 @@ def get_generate_params(model_lower, chat,
         num_return_sequences = min(num_beams, num_return_sequences or 1)
         do_sample = False if do_sample is None else do_sample
     else:
-        temperature = 0.1 if temperature is None else temperature
-        top_p = 0.75 if top_p is None else top_p
-        top_k = 40 if top_k is None else top_k
+        temperature = 0.4 if temperature is None else temperature
+        top_p = 0.85 if top_p is None else top_p
+        top_k = 70 if top_k is None else top_k
         if chat:
             num_beams = num_beams or 1
         else:
@@ -986,7 +995,7 @@ def get_generate_params(model_lower, chat,
         max_new_tokens = max_new_tokens or 256
         repetition_penalty = repetition_penalty or 1.07
         num_return_sequences = min(num_beams, num_return_sequences or 1)
-        do_sample = False if do_sample is None else do_sample
+        do_sample = True if do_sample is None else do_sample
     # doesn't include chat, instruction_nochat, iinput_nochat, added later
     params_list = ["", stream_output, prompt_type, temperature, top_p, top_k, num_beams, max_new_tokens, min_new_tokens,
                    early_stopping, max_time, repetition_penalty, num_return_sequences, do_sample]
@@ -1062,6 +1071,15 @@ def languages_covered():
     return covered
 
 
+def get_context(chat_context, prompt_type):
+    if chat_context and prompt_type == 'human_bot':
+        context0 = """<bot>: I am an intelligent, helpful, truthful, and fair assistant named h2oGPT, who will give accurate, balanced, and reliable responses. I will not respond with I don't know or I don't understand.
+<human>: I am a human person seeking useful assistance and request all questions be answered completely, and typically expect detailed responses. Give answers in numbered list format if several distinct but related items are being listed."""
+    else:
+        context0 = ''
+    return context0
+
+
 def test_test_prompt(prompt_type='instruct', data_point=0):
     example_data_point = example_data_points[data_point]
     example_data_point.pop('output', None)
diff --git a/gradio_runner.py b/gradio_runner.py
index 442cc71a3..50367720d 100644
--- a/gradio_runner.py
+++ b/gradio_runner.py
@@ -274,7 +274,7 @@ def _postprocess_chat_messages(self, chat_message: str):
                                          visible=not is_public)
                     context = gr.Textbox(lines=3, label="System Pre-Context",
                                          info="Directly pre-appended without prompt processing",
-                                         visible=not is_public and not kwargs['chat'])
+                                         visible=not is_public)
                     chat = gr.components.Checkbox(label="Chat mode", value=kwargs['chat'],
                                                   visible=not is_public)
@@ -862,7 +862,7 @@ def get_system_info():
 
 input_args_list = ['model_state']
 inputs_kwargs_list = ['debug', 'save_dir', 'hard_stop_list', 'sanitize_bot_response', 'model_state0', 'is_low_mem',
-                      'raise_generate_gpu_exceptions']
+                      'raise_generate_gpu_exceptions', 'chat_context']
 
 
 def get_inputs_list(inputs_dict, model_lower):
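
Illustrative sketch (not part of the patch, and not applied by git): a minimal, self-contained mock of how the new chat_context flag is intended to behave, based only on the logic added above. get_context mirrors the hunk in generate.py; evaluate_stub is a hypothetical stand-in for the fallback inside evaluate().

# Minimal sketch, assuming the patched generate.py behavior; names marked
# below as stubs are hypothetical and exist only for this illustration.

def get_context(chat_context, prompt_type):
    # Same logic the patch adds: a hidden default context is returned only
    # for chat_context=True with the human_bot prompt type.
    if chat_context and prompt_type == 'human_bot':
        return (
            "<bot>: I am an intelligent, helpful, truthful, and fair assistant named h2oGPT, "
            "who will give accurate, balanced, and reliable responses.\n"
            "<human>: I am a human person seeking useful assistance and request all questions "
            "be answered completely, and typically expect detailed responses."
        )
    return ''


def evaluate_stub(instruction, context='', chat_context=False, prompt_type='human_bot'):
    # Hypothetical stand-in for evaluate(): only an empty context is replaced
    # by the hidden default, then pre-pended to the prompt as before.
    if not context:
        context = get_context(chat_context, prompt_type)
    return context + instruction


# With chat_context=True and the human_bot prompt type the persona context is
# prepended; with chat_context=False the prompt is unchanged.
print(evaluate_stub("Who are you?", chat_context=True)[:7])   # "<bot>: "
print(evaluate_stub("Who are you?", chat_context=False))      # "Who are you?"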