diff --git a/tests/rc/predictors/transformer_qa_test.py b/tests/rc/predictors/transformer_qa_test.py index c64d8cb5e..6c5847cb5 100644 --- a/tests/rc/predictors/transformer_qa_test.py +++ b/tests/rc/predictors/transformer_qa_test.py @@ -20,12 +20,10 @@ def test_predict_single_instance(self): "What is love?", "Baby don't hurt me, don't hurt me, no more." ) span_start, span_end = prediction["best_span"] - assert 0 <= span_start <= span_end - assert ( - "best_span_str" in prediction - and isinstance(prediction["best_span_str"], str) - and len(prediction["best_span_str"]) > 0 - ) + assert -1 <= span_start <= span_end + assert "best_span_str" in prediction and isinstance(prediction["best_span_str"], str) + if span_start > -1: + assert len(prediction["best_span_str"]) > 0 def test_predict_long_instance(self): # We use a short context and a long context, so that the long context has to be broken into multiple