From 2975c93d28ed075a55cdaea74de1272cf7d9a145 Mon Sep 17 00:00:00 2001 From: davebulaval Date: Tue, 1 Dec 2020 13:15:29 -0500 Subject: [PATCH] fixed tokenizer max_len error with transformers update --- bert_score/utils.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/bert_score/utils.py b/bert_score/utils.py index 649ebeb..082476e 100644 --- a/bert_score/utils.py +++ b/bert_score/utils.py @@ -109,17 +109,20 @@ def sent_encode(tokenizer, sent): if LooseVersion(transformers.__version__) >= LooseVersion("3.0.0"): return tokenizer.encode( - sent, add_special_tokens=True, add_prefix_space=True, max_length=tokenizer.max_len, truncation=True + sent, add_special_tokens=True, add_prefix_space=True, max_length=tokenizer.model_max_length, + truncation=True ) else: - return tokenizer.encode(sent, add_special_tokens=True, add_prefix_space=True, max_length=tokenizer.max_len) + return tokenizer.encode(sent, add_special_tokens=True, add_prefix_space=True, + max_length=tokenizer.model_max_length) else: import transformers if LooseVersion(transformers.__version__) >= LooseVersion("3.0.0"): - return tokenizer.encode(sent, add_special_tokens=True, max_length=tokenizer.max_len, truncation=True) + return tokenizer.encode(sent, add_special_tokens=True, max_length=tokenizer.model_max_length, + truncation=True) else: - return tokenizer.encode(sent, add_special_tokens=True, max_length=tokenizer.max_len) + return tokenizer.encode(sent, add_special_tokens=True, max_length=tokenizer.model_max_length) def get_model(model_type, num_layers, all_layers=None):