Fixes #270 #272

Merged
6 commits, merged Jun 10, 2023
20 changes: 11 additions & 9 deletions generate.py
@@ -276,7 +276,8 @@ def main(

# allow set token directly
use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
allow_upload_to_user_data = bool(int(os.environ.get("allow_upload_to_user_data", str(int(allow_upload_to_user_data)))))
allow_upload_to_user_data = bool(
int(os.environ.get("allow_upload_to_user_data", str(int(allow_upload_to_user_data)))))
allow_upload_to_my_data = bool(int(os.environ.get("allow_upload_to_my_data", str(int(allow_upload_to_my_data)))))
height = int(os.environ.get("HEIGHT", height))
h2ocolors = bool(int(os.getenv('h2ocolors', h2ocolors)))
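The hunk above only re-wraps the existing environment-override pattern for boolean flags. As a standalone illustration (the default value here is made up, not taken from the repo), the round-trip works like this:

```python
import os

# CLI/function default; the environment variable wins when set.  Booleans are
# serialized as "0"/"1" so that bool(int(...)) recovers them from the env string.
allow_upload_to_user_data = True
allow_upload_to_user_data = bool(
    int(os.environ.get("allow_upload_to_user_data",
                       str(int(allow_upload_to_user_data)))))
```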
@@ -1154,16 +1155,16 @@ def evaluate(
if chat:
# override, ignore user change
num_return_sequences = 1
stopping_criteria = get_stopping(prompt_type, prompt_dict, tokenizer, device)
stopping_criteria = get_stopping(prompt_type, prompt_dict, tokenizer, device,
model_max_length=tokenizer.model_max_length)

# limit prompt using token length from user, implicit, or model
_, _, max_length_tokenize, max_prompt_length = get_cutoffs(memory_restriction_level,
model_max_length=tokenizer.model_max_length)
prompt = prompt[-max_prompt_length:]
inputs = tokenizer(prompt,
return_tensors="pt",
truncation=True,
max_length=max_length_tokenize)
if inputs['input_ids'].shape[1] >= max_length_tokenize - 1:
print("Cutting off input: %s %s" % (inputs['input_ids'].shape[1], max_length_tokenize), flush=True)
from h2oai_pipeline import H2OTextGenerationPipeline
prompt = H2OTextGenerationPipeline.limit_prompt(prompt, tokenizer, max_prompt_length=max_prompt_length)

inputs = tokenizer(prompt, return_tensors="pt")
if debug and len(inputs["input_ids"]) > 0:
print('input_ids length', len(inputs["input_ids"][0]), flush=True)
input_ids = inputs["input_ids"].to(device)
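Pulled out of evaluate(), the new prompt-handling path roughly reduces to the sketch below: character-level capping via the pipeline's limit_prompt, plain tokenization without truncation, and a stopping-criteria object that carries the hard token limit. This is only an outline under the assumption that the PR's modules are importable as shown, not the full function:

```python
def build_inputs(prompt, tokenizer, prompt_type, prompt_dict, device="cpu",
                 memory_restriction_level=0):
    # Sketch of the reworked evaluate() path from this PR.
    from generate import get_cutoffs
    from stopping import get_stopping
    from h2oai_pipeline import H2OTextGenerationPipeline

    # character budget derived from the model's context length
    _, _, _, max_prompt_length = get_cutoffs(
        memory_restriction_level, model_max_length=tokenizer.model_max_length)
    # trim the prompt by characters, keeping the tail where the question lives
    prompt = H2OTextGenerationPipeline.limit_prompt(
        prompt, tokenizer, max_prompt_length=max_prompt_length)
    # no tokenizer-side truncation any more; the stopping criteria enforce
    # model_max_length at generation time instead
    inputs = tokenizer(prompt, return_tensors="pt")
    stopping_criteria = get_stopping(prompt_type, prompt_dict, tokenizer, device,
                                     model_max_length=tokenizer.model_max_length)
    return inputs, stopping_criteria
```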
@@ -1318,6 +1319,7 @@ def get_cutoffs(memory_restriction_level, for_context=False, model_max_length=20
if memory_restriction_level > 0:
max_length_tokenize = 768 - 256 if memory_restriction_level <= 2 else 512 - 256
else:
# at least give room for 1 paragraph output
max_length_tokenize = model_max_length - 256
cutoff_len = max_length_tokenize * 4 # if reaches limit, then can't generate new tokens
output_smallest = 30 * 4
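For the unrestricted branch, the numbers work out as in this small example (2048 is an assumed context length; the ×4 factor is the function's own chars-per-token heuristic):

```python
model_max_length = 2048                       # assumed model context window
max_length_tokenize = model_max_length - 256  # 1792 tokens left for the prompt
cutoff_len = max_length_tokenize * 4          # ~7168 characters of raw text
output_smallest = 30 * 4                      # ~120 characters reserved for a minimal answer
```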
37 changes: 23 additions & 14 deletions gpt_langchain.py
@@ -41,7 +41,7 @@
UnstructuredURLLoader, UnstructuredHTMLLoader, UnstructuredWordDocumentLoader, UnstructuredMarkdownLoader, \
EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.text_splitter import RecursiveCharacterTextSplitter, Language
from langchain.chains.question_answering import load_qa_chain
from langchain.docstore.document import Document
from langchain import PromptTemplate
@@ -331,7 +331,7 @@ def get_llm(use_openai_model=False, model_name=None, model=None,
repetition_penalty=repetition_penalty,
num_return_sequences=num_return_sequences,
return_full_text=True,
handle_long_generation='hole')
handle_long_generation=None)
assert len(set(gen_hyper).difference(gen_kwargs.keys())) == 0

if stream_output:
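handle_long_generation='hole' is the stock transformers behaviour of silently trimming an over-long prompt from the left so that max_new_tokens still fits; with the prompt now pre-trimmed by limit_prompt, the PR turns that off. A hedged sketch of the difference (model name and prompt are placeholders):

```python
from transformers import pipeline

pipe = pipeline("text-generation", model="gpt2")  # placeholder model

long_prompt = "some long retrieved context ... final question?"
# "hole": the pipeline itself may cut the left side of the prompt, which can
# drop a prompt prefix such as "<human>:".  None: the prompt is passed through
# unchanged and length is managed upstream (limit_prompt + stopping criteria).
out = pipe(long_prompt, max_new_tokens=32, handle_long_generation=None)
```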
@@ -396,7 +396,7 @@ def get_wiki_data(title, first_paragraph_only, text_limit=None, take_head=True):
data = json.load(open(filename, "rt"))
page_content = list(data["query"]["pages"].values())[0]["extract"]
if take_head is not None and text_limit is not None:
page_content = page_content[:text_limit] if take_head else page_content[:-text_limit]
page_content = page_content[:text_limit] if take_head else page_content[-text_limit:]
title_url = str(title).replace(' ', '_')
return Document(
page_content=page_content,
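The single-character slice change above is the actual bug fix: with take_head=False the intent is to keep the tail of the article, but page_content[:-text_limit] removes the tail and keeps everything else. A tiny illustration:

```python
page_content = "HEAD" + "x" * 100 + "TAIL"
text_limit = 4

old = page_content[:-text_limit]   # buggy: drops "TAIL", keeps the (long) head
new = page_content[-text_limit:]   # fixed: keeps exactly the last 4 chars, "TAIL"

assert new == "TAIL" and len(new) == text_limit
```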
@@ -591,7 +591,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
elif file.lower().endswith('.html') or file.lower().endswith('.mhtml'):
docs1 = UnstructuredHTMLLoader(file_path=file).load()
add_meta(docs1, file)
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size, language=Language.HTML)
elif (file.lower().endswith('.docx') or file.lower().endswith('.doc')) and have_libreoffice:
docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
add_meta(docs1, file)
@@ -617,7 +617,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
elif file.lower().endswith('.md'):
docs1 = UnstructuredMarkdownLoader(file).load()
add_meta(docs1, file)
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size, language=Language.MARKDOWN)
elif file.lower().endswith('.enex'):
docs1 = EverNoteLoader(file).load()
add_meta(doc1, file)
@@ -682,6 +682,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
with open(file, "r") as f:
doc1 = Document(page_content=f.read(), metadata={"source": file})
add_meta(doc1, file)
doc1 = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size, language=Language.RST)
elif file.lower().endswith('.pdf'):
env_gpt4all_file = ".env_gpt4all"
from dotenv import dotenv_values
@@ -704,6 +705,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
elif file.lower().endswith('.py'):
doc1 = PythonLoader(file).load()
add_meta(doc1, file)
doc1 = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size, language=Language.PYTHON)
elif file.lower().endswith('.toml'):
doc1 = TomlLoader(file).load()
add_meta(doc1, file)
@@ -1589,17 +1591,24 @@ def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, ve
return ret, extra


def chunk_sources(sources, chunk=True, chunk_size=512):
def chunk_sources(sources, chunk=True, chunk_size=512, language=None):
if not chunk:
return sources
source_chunks = []
# Below for known separator
# splitter = CharacterTextSplitter(separator=" ", chunk_size=chunk_size, chunk_overlap=0)
splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0)
for source in sources:
# print(source.metadata['source'], flush=True)
for chunky in splitter.split_text(source.page_content):
source_chunks.append(Document(page_content=chunky, metadata=source.metadata))
if not isinstance(sources, (list, tuple)):
# if just one document
sources = [sources]
if language and False:
# Bug in langchain, keep separator=True not working
# https:/hwchase17/langchain/issues/2836
# so avoid this for now
keep_separator = True
separators = RecursiveCharacterTextSplitter.get_separators_for_language(language)
else:
separators = ["\n\n", "\n", " ", ""]
keep_separator = False
splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0, keep_separator=keep_separator,
separators=separators)
source_chunks = splitter.split_documents(sources)
return source_chunks
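Once the linked langchain issue around keep_separator is resolved, the language-aware branch is meant to behave roughly like this sketch (Markdown chosen as an example; the content and file name are illustrative):

```python
from langchain.text_splitter import RecursiveCharacterTextSplitter, Language
from langchain.docstore.document import Document

md = "# Title\n\nIntro paragraph.\n\n## Usage\n\n- step one\n- step two\n"
docs = [Document(page_content=md, metadata={"source": "example.md"})]

# language-specific separators instead of the generic ["\n\n", "\n", " ", ""]
separators = RecursiveCharacterTextSplitter.get_separators_for_language(Language.MARKDOWN)
splitter = RecursiveCharacterTextSplitter(chunk_size=64, chunk_overlap=0,
                                          keep_separator=True, separators=separators)
chunks = splitter.split_documents(docs)
```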


62 changes: 41 additions & 21 deletions h2oai_pipeline.py
@@ -51,25 +51,37 @@ def __init__(self, *args, debug=False, chat=False, stream_output=False,
self.sanitize_bot_response = sanitize_bot_response
self.max_input_tokens = max_input_tokens # not for generate, so ok that not kwargs

def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
if hasattr(self.tokenizer, 'model_max_length'):
@staticmethod
def limit_prompt(prompt_text, tokenizer, max_prompt_length=None):
verbose = bool(int(os.getenv('VERBOSE_PIPELINE', '0')))

if hasattr(tokenizer, 'model_max_length'):
# model_max_length only defined for generate.py, not raw use of h2oai_pipeline.py
model_max_length = self.tokenizer.model_max_length
model_max_length = tokenizer.model_max_length
if max_prompt_length is not None:
model_max_length = min(model_max_length, max_prompt_length)
# cut at some upper likely limit to avoid excessive tokenization etc
# upper bound of 10 chars/token, e.g. special chars sometimes are long
if len(prompt_text) > model_max_length * 10:
len0 = len(prompt_text)
prompt_text = prompt_text[-model_max_length * 10:]
if verbose:
print("Cut of input: %s -> %s" % (len0, len(prompt_text)), flush=True)
else:
# unknown
model_max_length = None

verbose = bool(int(os.getenv('VERBOSE_PIPELINE', '0')))
if model_max_length is not None:
num_prompt_tokens = None
# can't wait for "hole" if not plain prompt_type, since would lose prefix like <human>:
# For https:/h2oai/h2ogpt/issues/192
for trial in range(0, 3):
prompt_tokens = self.tokenizer(prompt_text)['input_ids']
prompt_tokens = tokenizer(prompt_text)['input_ids']
num_prompt_tokens = len(prompt_tokens)
if num_prompt_tokens > model_max_length:
# conservative by using int()
chars_per_token = int(len(prompt_text) / num_prompt_tokens)
# keep tail, where question is if using langchain
prompt_text = prompt_text[-model_max_length * chars_per_token:]
if verbose:
print("reducing %s tokens, assuming average of %s chars/token for %s characters" % (
@@ -79,28 +91,35 @@ def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **gene
print("using %s tokens with %s chars" % (num_prompt_tokens, len(prompt_text)), flush=True)
break

# if input prompt is some number of tokens, despite user request, can't have max_new_tokens more
#
if self.prompt_type not in [PromptType.plain.name, PromptType.plain.value]:
# then give room for prompt
fudge = 20
else:
fudge = 0
assert num_prompt_tokens is not None
max_new_tokens = max(0, min(generate_kwargs['max_new_tokens'],
model_max_length - (num_prompt_tokens + fudge)))
if max_new_tokens < generate_kwargs['max_new_tokens']:
if verbose:
print("Reduced max_new_tokens from %s -> %s" % (generate_kwargs['max_new_tokens'], max_new_tokens))
generate_kwargs['max_new_tokens'] = max_new_tokens
# Why Below False: don't limit max_new_tokens more, just rely upon stopping to reach limit of model
if False:
# if input prompt is some number of tokens, despite user request, can't have max_new_tokens more
#
assert num_prompt_tokens is not None
if self.prompt_type not in [PromptType.plain.name, PromptType.plain.value]:
# then give room for prompt
fudge = 20
else:
fudge = 0
max_new_tokens = max(0, min(generate_kwargs['max_new_tokens'],
model_max_length - (num_prompt_tokens + fudge)))
if max_new_tokens < generate_kwargs['max_new_tokens']:
if verbose:
print("Reduced max_new_tokens from %s -> %s" % (
generate_kwargs['max_new_tokens'], max_new_tokens))
generate_kwargs['max_new_tokens'] = max_new_tokens
return prompt_text

def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
prompt_text = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer)

data_point = dict(context='', instruction=prompt_text, input='')
if self.prompter is not None:
prompt_text = self.prompter.generate_prompt(data_point)
self.prompt_text = prompt_text
if handle_long_generation is None:
# forces truncation of inputs to avoid critical failure
handle_long_generation = 'hole'
handle_long_generation = None # disable with new approaches
return super().preprocess(prompt_text, prefix=prefix, handle_long_generation=handle_long_generation,
**generate_kwargs)

@@ -123,7 +142,8 @@ def _forward(self, model_inputs, **generate_kwargs):
if self.can_stop:
stopping_criteria = get_stopping(self.prompt_type, self.prompt_dict,
self.tokenizer, self.device,
human=self.human, bot=self.bot)
human=self.human, bot=self.bot,
model_max_length=self.tokenizer.model_max_length)
generate_kwargs['stopping_criteria'] = stopping_criteria
# return super()._forward(model_inputs, **generate_kwargs)
return self.__forward(model_inputs, **generate_kwargs)
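Outside the pipeline class, the chars-per-token trimming that limit_prompt performs boils down to the following standalone sketch (tokenizer name and prompt are placeholders, not part of the PR):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
model_max_length = tokenizer.model_max_length      # 1024 for gpt2

prompt_text = "question " * 5000                   # deliberately over-long

for _ in range(3):
    num_prompt_tokens = len(tokenizer(prompt_text)["input_ids"])
    if num_prompt_tokens <= model_max_length:
        break
    # conservative chars/token estimate, then keep the tail where the question sits
    chars_per_token = int(len(prompt_text) / num_prompt_tokens)
    prompt_text = prompt_text[-model_max_length * chars_per_token:]
```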
11 changes: 8 additions & 3 deletions stopping.py
@@ -6,12 +6,13 @@

class StoppingCriteriaSub(StoppingCriteria):

def __init__(self, stops=[], encounters=[], device="cuda"):
def __init__(self, stops=[], encounters=[], device="cuda", model_max_length=None):
super().__init__()
assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match"
self.encounters = encounters
self.stops = [stop.to(device) for stop in stops]
self.num_stops = [0] * len(stops)
self.model_max_length = model_max_length

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
for stopi, stop in enumerate(self.stops):
@@ -20,12 +21,15 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
# print("Stopped", flush=True)
return True
if self.model_max_length is not None and input_ids[0].shape[0] >= self.model_max_length:
# critical limit
return True
# print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
# print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
return False


def get_stopping(prompt_type, prompt_dict, tokenizer, device, human='<human>:', bot="<bot>:"):
def get_stopping(prompt_type, prompt_dict, tokenizer, device, human='<human>:', bot="<bot>:", model_max_length=None):
# FIXME: prompt_dict unused currently
if prompt_type in [PromptType.human_bot.name, PromptType.instruct_vicuna.name, PromptType.instruct_with_end.name]:
if prompt_type == PromptType.human_bot.name:
@@ -67,7 +71,8 @@ def get_stopping(prompt_type, prompt_dict, tokenizer, device, human='<human>:',
stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
# build stopper
stopping_criteria = StoppingCriteriaList(
[StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters, device=device)])
[StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters, device=device,
model_max_length=model_max_length)])
else:
stopping_criteria = StoppingCriteriaList()
return stopping_criteria
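A hedged end-to-end sketch of the updated signature with the new hard cap at model_max_length (model and tokenizer are placeholders; prompt_dict is unused per the FIXME above):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from stopping import get_stopping

tokenizer = AutoTokenizer.from_pretrained("gpt2")     # placeholder
model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder

stopping_criteria = get_stopping("human_bot", prompt_dict=None,
                                 tokenizer=tokenizer, device="cpu",
                                 model_max_length=tokenizer.model_max_length)

inputs = tokenizer("<human>: Who are you?\n<bot>:", return_tensors="pt")
# Stops on the <human>:/<bot>: markers as before, and now also as soon as the
# running sequence reaches model_max_length tokens.
out = model.generate(**inputs, max_new_tokens=64,
                     stopping_criteria=stopping_criteria)
print(tokenizer.decode(out[0]))
```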
7 changes: 5 additions & 2 deletions tests/test_client_calls.py
@@ -108,7 +108,7 @@ def test_client_chat_nostream_gpt4all_llama():
def test_client_chat_nostream_llama7b():
prompt_type = get_llama()
res_dict, client = run_client_chat_with_server(stream_output=False, base_model='llama', prompt_type=prompt_type)
assert 'I’m a software engineer' in res_dict['response']
assert "I’m a software engineer" in res_dict['response'] or "I'm a software engineer" in res_dict['response']


def run_client_chat_with_server(prompt='Who are you?', stream_output=False, max_new_tokens=256,
@@ -228,7 +228,10 @@ def test_client_chat_stream_langchain_steps(max_new_tokens, top_k_docs):
# odd answer since no whisper docs, but still shows some docs at very low score
assert ('h2oGPT' in res_dict['response'] or
'A chatbot that can whisper to you' in res_dict['response'] or
'whisper is a simple' in res_dict['response']) \
'whisper is a simple' in res_dict['response'] or
'Whisper is a tool for generating text from a model' in res_dict['response'] or
'Whisper is a chatbot platform' in res_dict['response']
) \
and '.md' in res_dict['response']

