
Commit

#84 switch to the new version in a more convenient way
kirzharov committed Dec 4, 2020
1 parent 9ffea67 commit e89eaa3
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions bert_score/utils.py
@@ -107,19 +107,27 @@ def sent_encode(tokenizer, sent):
         # for RoBERTa and GPT-2
         import transformers
 
-        if LooseVersion(transformers.__version__) >= LooseVersion("3.0.0"):
+        if LooseVersion(transformers.__version__) >= LooseVersion("4.0.0"):
+            return tokenizer.encode(
+                sent, add_special_tokens=True, add_prefix_space=True, max_length=tokenizer.model_max_length,
+                truncation=True
+            )
+        elif LooseVersion(transformers.__version__) >= LooseVersion("3.0.0"):
             return tokenizer.encode(
                 sent, add_special_tokens=True, add_prefix_space=True, max_length=tokenizer.max_len,
                 truncation=True
             )
         else:
             return tokenizer.encode(sent, add_special_tokens=True, add_prefix_space=True, max_length=tokenizer.max_len)
     else:
         import transformers
 
-        if LooseVersion(transformers.__version__) >= LooseVersion("3.0.0"):
+        if LooseVersion(transformers.__version__) >= LooseVersion("4.0.0"):
+            return tokenizer.encode(sent, add_special_tokens=True, max_length=tokenizer.model_max_length,
+                                    truncation=True)
+        elif LooseVersion(transformers.__version__) >= LooseVersion("3.0.0"):
             return tokenizer.encode(sent, add_special_tokens=True, max_length=tokenizer.max_len,
                                     truncation=True)
         else:
             return tokenizer.encode(sent, add_special_tokens=True, max_length=tokenizer.max_len)

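For context when upgrading, here is a minimal, self-contained sketch of the attribute switch this patch codifies; it is not part of the commit, and the checkpoint name and sentence below are illustrative only. transformers deprecated `tokenizer.max_len` during the 3.x series and removed it in 4.0.0, so on newer installs the limit has to come from `tokenizer.model_max_length`.

# Minimal sketch of the version gate added above (illustrative, not from the commit):
# transformers >= 4.0.0 dropped `tokenizer.max_len`, so the encoding limit must be
# read from `tokenizer.model_max_length` instead.
from distutils.version import LooseVersion  # same check utils.py uses

import transformers
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base", use_fast=False)  # example model choice
sent = "BERTScore compares candidate and reference sentences."

if LooseVersion(transformers.__version__) >= LooseVersion("4.0.0"):
    max_length = tokenizer.model_max_length  # `max_len` no longer exists in 4.x
else:
    max_length = tokenizer.max_len  # pre-4.0 attribute, deprecated during 3.x

ids = tokenizer.encode(
    sent,
    add_special_tokens=True,
    max_length=max_length,
    truncation=True,  # `truncation` keyword is available from transformers 3.0.0 onward
)
print(len(ids), ids)

Note that the `elif` ordering in the patch matters: the 4.0.0 branch must be checked first, because any 4.x install also satisfies the `>= LooseVersion("3.0.0")` comparison.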
