diff --git a/bert_score/utils.py b/bert_score/utils.py index 3ec0a7c..455f5b8 100644 --- a/bert_score/utils.py +++ b/bert_score/utils.py @@ -107,19 +107,27 @@ def sent_encode(tokenizer, sent): # for RoBERTa and GPT-2 import transformers - if LooseVersion(transformers.__version__) >= LooseVersion("3.0.0"): + if LooseVersion(transformers.__version__) >= LooseVersion("4.0.0"): return tokenizer.encode( sent, add_special_tokens=True, add_prefix_space=True, max_length=tokenizer.model_max_length, truncation=True ) + elif LooseVersion(transformers.__version__) >= LooseVersion("3.0.0"): + return tokenizer.encode( + sent, add_special_tokens=True, add_prefix_space=True, max_length=tokenizer.max_len, + truncation=True + ) else: return tokenizer.encode(sent, add_special_tokens=True, add_prefix_space=True, max_length=tokenizer.max_len) else: import transformers - if LooseVersion(transformers.__version__) >= LooseVersion("3.0.0"): + if LooseVersion(transformers.__version__) >= LooseVersion("4.0.0"): return tokenizer.encode(sent, add_special_tokens=True, max_length=tokenizer.model_max_length, truncation=True) + elif LooseVersion(transformers.__version__) >= LooseVersion("3.0.0"): + return tokenizer.encode(sent, add_special_tokens=True, max_length=tokenizer.max_len, + truncation=True) else: return tokenizer.encode(sent, add_special_tokens=True, max_length=tokenizer.max_len)