Merge pull request #100 from h2oai/defaultcontext
Give default context to help chatbot
pseudotensor authored May 1, 2023
2 parents 8316cb4 + a61596a commit cefdef8
Showing 3 changed files with 28 additions and 10 deletions.
finetune.py (2 changes: 1 addition & 1 deletion)

@@ -869,7 +869,7 @@ def generate_prompt(data_point, prompt_type, chat, reduced):
     assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
     promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response = get_prompt(prompt_type, chat, context, reduced)

-    prompt = context
+    prompt = context if not reduced else ''

     if input and promptA:
         prompt += f"""{promptA}"""
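The change above keeps the system context from being re-prepended on every continuation ("reduced") chat turn, so it appears only once at the start of a conversation. A minimal sketch of the intended behavior, not part of this commit; build_prompt and the <human>/<bot> tags are illustrative stand-ins for generate_prompt:

# Sketch: the context is included only on the first (non-reduced) turn.
def build_prompt(context, instruction, reduced):
    prompt = context if not reduced else ''
    prompt += f"<human>: {instruction}\n<bot>:"
    return prompt

print(build_prompt("You are h2oGPT.\n", "Hi there", reduced=False))   # context present
print(build_prompt("You are h2oGPT.\n", "Tell me more", reduced=True))  # context omitted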
generate.py (32 changes: 25 additions & 7 deletions)

@@ -65,6 +65,7 @@ def main(
         gradio_avoid_processing_markdown: bool = False,
         chat: bool = True,
         chat_history: int = 4096,  # character length of chat context/history
+        chat_context: bool = False,  # use default context if human_bot
         stream_output: bool = True,
         show_examples: bool = None,
         verbose: bool = False,
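The new chat_context flag defaults to off, so existing deployments are unchanged. Assuming the usual launch path where main() is driven from the command line, enabling it might look like the following sketch (the model name is purely illustrative):

# Hypothetical programmatic launch; equivalent to passing --chat_context=True on the CLI.
from generate import main

main(base_model='h2oai/h2ogpt-oasst1-512-12b',  # illustrative model name
     chat=True,
     chat_context=True)  # inject the default human_bot context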
@@ -182,7 +183,7 @@ def main(
             assert not chat, "No gradio must use chat=False, uses nochat instruct"
             examplenew[eval_func_param_names.index('instruction_nochat')] = instruction
             examplenew[eval_func_param_names.index('iinput_nochat')] = ''  # no input
-            examplenew[eval_func_param_names.index('context')] = ''  # no context
+            examplenew[eval_func_param_names.index('context')] = get_context(chat_context, prompt_type)
             examples.append(examplenew)
             responses.append(output)

@@ -217,7 +218,8 @@ def main(
         model, tokenizer, device = get_model(**locals())
         model_state = [model, tokenizer, device, base_model]
         fun = partial(evaluate, model_state, debug=debug, save_dir=save_dir, is_low_mem=is_low_mem,
-                      raise_generate_gpu_exceptions=raise_generate_gpu_exceptions)
+                      raise_generate_gpu_exceptions=raise_generate_gpu_exceptions,
+                      chat_context=chat_context)
     else:
         assert eval_sharegpt_prompts_only > 0
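As above, server-level settings such as chat_context are bound into evaluate with functools.partial, so the UI callback only supplies per-request inputs. A self-contained sketch of that pattern (evaluate_sketch is a hypothetical stand-in for the real evaluate):

from functools import partial

# Hypothetical stand-in: fixed server settings are bound once,
# per-request arguments are left free for the UI to supply.
def evaluate_sketch(model_state, instruction, chat_context=None):
    assert chat_context is not None  # mirrors the assert added in evaluate
    return f"chat_context={chat_context}: {instruction}"

fun = partial(evaluate_sketch, ['model', 'tokenizer'], chat_context=True)
print(fun("Hello"))  # chat_context is already bound; only instruction varies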

@@ -256,7 +258,8 @@ def get_response(*args, exi=0):
         if eval_sharegpt_prompts_only > 0:
             # only our own examples have this filled at moment
             assert iinput in [None, ''], iinput  # should be no iinput
-            assert context in [None, ''], context  # should be no context
+            if not (chat_context and prompt_type == 'human_bot'):
+                assert context in [None, ''], context  # should be no context
             prompt = instruction
         cutoff_len = 768 if is_low_mem else 2048
         inputs = stokenizer(prompt, res,
@@ -625,10 +628,12 @@ def evaluate(
         model_state0=None,
         is_low_mem=None,
         raise_generate_gpu_exceptions=None,
+        chat_context=None,
 ):
     # ensure passed these
     assert is_low_mem is not None
     assert raise_generate_gpu_exceptions is not None
+    assert chat_context is not None

     if debug:
         locals_dict = locals().copy()
@@ -670,6 +675,10 @@ def evaluate(
         instruction = instruction_nochat
         iinput = iinput_nochat

+    if not context:
+        # get hidden context if have one
+        context = get_context(chat_context, prompt_type)
+
     data_point = dict(context=context, instruction=instruction, input=iinput)
     prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
     prompt = prompter.generate_prompt(data_point)
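The effect of the block above: a user-supplied context always wins, and the hidden default is injected only when the context field is empty. A minimal sketch of that precedence (resolve_context is hypothetical; get_context is the helper this commit adds below):

# Sketch: explicit user context takes precedence over the hidden default.
def resolve_context(user_context, chat_context, prompt_type):
    if not user_context:
        return get_context(chat_context, prompt_type)  # may still be ''
    return user_context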
@@ -976,17 +985,17 @@ def get_generate_params(model_lower, chat,
         num_return_sequences = min(num_beams, num_return_sequences or 1)
         do_sample = False if do_sample is None else do_sample
     else:
-        temperature = 0.1 if temperature is None else temperature
-        top_p = 0.75 if top_p is None else top_p
-        top_k = 40 if top_k is None else top_k
+        temperature = 0.4 if temperature is None else temperature
+        top_p = 0.85 if top_p is None else top_p
+        top_k = 70 if top_k is None else top_k
         if chat:
             num_beams = num_beams or 1
         else:
             num_beams = num_beams or 4
         max_new_tokens = max_new_tokens or 256
         repetition_penalty = repetition_penalty or 1.07
         num_return_sequences = min(num_beams, num_return_sequences or 1)
-        do_sample = False if do_sample is None else do_sample
+        do_sample = True if do_sample is None else do_sample
     # doesn't include chat, instruction_nochat, iinput_nochat, added later
     params_list = ["", stream_output, prompt_type, temperature, top_p, top_k, num_beams, max_new_tokens, min_new_tokens,
                    early_stopping, max_time, repetition_penalty, num_return_sequences, do_sample]
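This hunk also changes the decoding defaults: sampling is now on by default, with a warmer temperature and wider top-p/top-k. Roughly how those values would land in a Hugging Face GenerationConfig, shown only to make the before/after concrete (the repo actually threads them through its own parameter list):

from transformers import GenerationConfig

gen_config = GenerationConfig(
    do_sample=True,          # was False: deterministic decoding by default
    temperature=0.4,         # was 0.1
    top_p=0.85,              # was 0.75
    top_k=70,                # was 40
    num_beams=1,             # chat mode keeps single-beam decoding
    max_new_tokens=256,
    repetition_penalty=1.07,
)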
@@ -1062,6 +1071,15 @@ def languages_covered():
     return covered


+def get_context(chat_context, prompt_type):
+    if chat_context and prompt_type == 'human_bot':
+        context0 = """<bot>: I am an intelligent, helpful, truthful, and fair assistant named h2oGPT, who will give accurate, balanced, and reliable responses. I will not respond with I don't know or I don't understand.
+<human>: I am a human person seeking useful assistance and request all questions be answered completely, and typically expect detailed responses. Give answers in numbered list format if several distinct but related items are being listed."""
+    else:
+        context0 = ''
+    return context0
+
+
 def test_test_prompt(prompt_type='instruct', data_point=0):
     example_data_point = example_data_points[data_point]
     example_data_point.pop('output', None)
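Expected behavior of the new helper, as a quick sketch (these checks are not part of the commit): the hidden context is returned only when the flag is on and the prompt type is human_bot.

# Sketch of expected get_context behavior:
assert get_context(False, 'human_bot') == ''  # flag off: no hidden context
assert get_context(True, 'instruct') == ''    # other prompt types: no hidden context
assert get_context(True, 'human_bot').startswith('<bot>: I am an intelligent')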
gradio_runner.py (4 changes: 2 additions & 2 deletions)

@@ -274,7 +274,7 @@ def _postprocess_chat_messages(self, chat_message: str):
                                      visible=not is_public)
                 context = gr.Textbox(lines=3, label="System Pre-Context",
                                      info="Directly pre-appended without prompt processing",
-                                     visible=not is_public and not kwargs['chat'])
+                                     visible=not is_public)
                 chat = gr.components.Checkbox(label="Chat mode", value=kwargs['chat'],
                                               visible=not is_public)

@@ -862,7 +862,7 @@ def get_system_info():

 input_args_list = ['model_state']
 inputs_kwargs_list = ['debug', 'save_dir', 'hard_stop_list', 'sanitize_bot_response', 'model_state0', 'is_low_mem',
-                      'raise_generate_gpu_exceptions']
+                      'raise_generate_gpu_exceptions', 'chat_context']


 def get_inputs_list(inputs_dict, model_lower):
