Skip to content

Commit

Permalink
Fixes #249
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudotensor committed Jun 8, 2023
1 parent 4f4fd98 commit 839b7f5
Showing 1 changed file with 36 additions and 13 deletions.
49 changes: 36 additions & 13 deletions h2oai_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,24 +50,47 @@ def __init__(self, *args, debug=False, chat=False, stream_output=False,
self.max_input_tokens = max_input_tokens # not for generate, so ok that not kwargs

def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
if self.prompt_type not in [PromptType.plain.name, PromptType.plain.value] and \
hasattr(self.tokenizer, 'model_max_length'):
if hasattr(self.tokenizer, 'model_max_length'):
# model_max_length only defined for generate.py, not raw use of h2oai_pipeline.py
model_max_length = self.tokenizer.model_max_length
verbose = False # FIXME: debug
else:
# unknown
model_max_length = None

verbose = False # FIXME: debug
if model_max_length is not None:
num_prompt_tokens = None
# can't wait for "hole" if not plain prompt_type, since would lose prefix like <human>:
# For https:/h2oai/h2ogpt/issues/192
prompt_tokens = self.tokenizer(prompt_text)['input_ids']
num_prompt_tokens = len(prompt_tokens)
if num_prompt_tokens > model_max_length:
# conservative by using int()
chars_per_token = int(len(prompt_text) / num_prompt_tokens)
prompt_text = prompt_text[-model_max_length * chars_per_token:]
for trial in range(0, 3):
prompt_tokens = self.tokenizer(prompt_text)['input_ids']
num_prompt_tokens = len(prompt_tokens)
if num_prompt_tokens > model_max_length:
# conservative by using int()
chars_per_token = int(len(prompt_text) / num_prompt_tokens)
prompt_text = prompt_text[-model_max_length * chars_per_token:]
if verbose:
print("reducing %s tokens, assuming average of %s chars/token for %s characters" % (
num_prompt_tokens, chars_per_token, len(prompt_text)), flush=True)
else:
if verbose:
print("using %s tokens with %s chars" % (num_prompt_tokens, len(prompt_text)), flush=True)
break

# if input prompt is some number of tokens, despite user request, can't have max_new_tokens more
#
if self.prompt_type not in [PromptType.plain.name, PromptType.plain.value]:
# then give room for prompt
fudge = 20
else:
fudge = 0
assert num_prompt_tokens is not None
max_new_tokens = max(0, min(generate_kwargs['max_new_tokens'],
model_max_length - (num_prompt_tokens + fudge)))
if max_new_tokens < generate_kwargs['max_new_tokens']:
if verbose:
print("reducing tokens, assuming average of %s chars/token: %s" % chars_per_token, flush=True)
prompt_tokens2 = self.tokenizer(prompt_text)['input_ids']
num_prompt_tokens2 = len(prompt_tokens2)
print("reduced tokens from %d -> %d" % (num_prompt_tokens, num_prompt_tokens2), flush=True)
print("Reduced max_new_tokens from %s -> %s" % (generate_kwargs['max_new_tokens'], max_new_tokens))
generate_kwargs['max_new_tokens'] = max_new_tokens

data_point = dict(context='', instruction=prompt_text, input='')
if self.prompter is not None:
Expand Down

0 comments on commit 839b7f5

Please sign in to comment.