From 7435645ec4ee5bcc77c658138a3e8b6e5c1ea2e3 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 9 May 2022 16:53:01 +0200 Subject: [PATCH 1/2] LogSumExp trick for the `question_answering` pipeline. --- src/transformers/pipelines/question_answering.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/transformers/pipelines/question_answering.py b/src/transformers/pipelines/question_answering.py index c629f703a030f0..bbffa3471f825d 100644 --- a/src/transformers/pipelines/question_answering.py +++ b/src/transformers/pipelines/question_answering.py @@ -398,8 +398,11 @@ def postprocess( end_ = np.where(undesired_tokens_mask, -10000.0, end_) # Normalize logits and spans to retrieve the answer - start_ = np.exp(start_ - np.log(np.sum(np.exp(start_), axis=-1, keepdims=True))) - end_ = np.exp(end_ - np.log(np.sum(np.exp(end_), axis=-1, keepdims=True))) + start_ = np.exp(start_ - start_.max(axis=-1, keepdims=True)) + start_ = start_ / start_.sum(axis=-1, keepdims=True) + + end_ = np.exp(end_ - end_.max(axis=-1, keepdims=True)) + end_ = end_ / end_.sum(axis=-1, keepdims=True) if handle_impossible_answer: min_null_score = min(min_null_score, (start_[0, 0] * end_[0, 0]).item()) From b9b39175f882824f6353eca146a944d9b9bcc82b Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 9 May 2022 18:46:54 +0200 Subject: [PATCH 2/2] Adding a failing test. 
--- .../test_pipelines_question_answering.py | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/pipelines/test_pipelines_question_answering.py b/tests/pipelines/test_pipelines_question_answering.py index e37fa12776835f..844ed0b68335ea 100644 --- a/tests/pipelines/test_pipelines_question_answering.py +++ b/tests/pipelines/test_pipelines_question_answering.py @@ -111,12 +111,47 @@ def test_small_model_pt(self): question_answerer = pipeline( "question-answering", model="sshleifer/tiny-distilbert-base-cased-distilled-squad" ) + outputs = question_answerer( question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris." ) self.assertEqual(nested_simplify(outputs), {"score": 0.01, "start": 0, "end": 11, "answer": "HuggingFace"}) + @require_torch + def test_small_model_pt_softmax_trick(self): + question_answerer = pipeline( + "question-answering", model="sshleifer/tiny-distilbert-base-cased-distilled-squad" + ) + + real_postprocess = question_answerer.postprocess + + # Scale the "start" and "end" logits to make sure we encounter the softmax + # overflow bug. + def ensure_large_logits_postprocess( + model_outputs, + top_k=1, + handle_impossible_answer=False, + max_answer_len=15, + ): + for output in model_outputs: + output["start"] = output["start"] * 1e6 + output["end"] = output["end"] * 1e6 + return real_postprocess( + model_outputs, + top_k=top_k, + handle_impossible_answer=handle_impossible_answer, + max_answer_len=max_answer_len, + ) + + question_answerer.postprocess = ensure_large_logits_postprocess + + outputs = question_answerer( + question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris." + ) + + self.assertEqual(nested_simplify(outputs), {"score": 0.028, "start": 0, "end": 11, "answer": "HuggingFace"}) + @slow + @require_torch + def test_small_model_long_context_cls_slow(self):