Commit

adopt transformers 4.x API
makcedward committed Dec 11, 2020
1 parent f40cbf3 commit 935e68a
Showing 4 changed files with 29 additions and 5 deletions.
8 changes: 7 additions & 1 deletion nlpaug/model/lang_models/bert.py
@@ -49,7 +49,13 @@ def is_skip_candidate(self, candidate):
         return candidate.startswith(self.SUBWORD_PREFIX)
 
     def token2id(self, token):
-        return self.tokenizer._convert_token_to_id(token)
+        # Issue 181: TokenizerFast has convert_tokens_to_ids but not _convert_token_to_id
+        if 'TokenizerFast' in self.tokenizer.__class__.__name__:
+            # New transformers API
+            return self.tokenizer.convert_tokens_to_ids(token)
+        else:
+            # Old transformers API
+            return self.tokenizer._convert_token_to_id(token)
 
     def id2token(self, _id):
         return self.tokenizer._convert_id_to_token(_id)
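For context, the class-name check above works around a change in transformers 4.x: fast tokenizers (e.g. BertTokenizerFast) no longer expose the private _convert_token_to_id helper, while the public convert_tokens_to_ids accepts a single token and works on both the slow and fast variants. A minimal sketch of the difference (not part of the commit; assumes transformers 4.x and that the bert-base-uncased vocabulary can be downloaded):

from transformers import BertTokenizer, BertTokenizerFast

slow = BertTokenizer.from_pretrained('bert-base-uncased')
fast = BertTokenizerFast.from_pretrained('bert-base-uncased')

# Public API: accepts a single token (or a list) and works for both tokenizers.
print(slow.convert_tokens_to_ids('hello'))  # integer id
print(fast.convert_tokens_to_ids('hello'))  # same id

# Private helper: present on the slow tokenizer, missing on the fast one in 4.x,
# which is what Issue 181 reported.
print(hasattr(slow, '_convert_token_to_id'))  # True
print(hasattr(fast, '_convert_token_to_id'))  # False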
8 changes: 7 additions & 1 deletion nlpaug/model/lang_models/distilbert.py
@@ -51,7 +51,13 @@ def is_skip_candidate(self, candidate):
         return candidate[:2] == self.SUBWORD_PREFIX
 
     def token2id(self, token):
-        return self.tokenizer._convert_token_to_id(token)
+        # Issue 181: TokenizerFast has convert_tokens_to_ids but not _convert_token_to_id
+        if 'TokenizerFast' in self.tokenizer.__class__.__name__:
+            # New transformers API
+            return self.tokenizer.convert_tokens_to_ids(token)
+        else:
+            # Old transformers API
+            return self.tokenizer._convert_token_to_id(token)
 
     def id2token(self, _id):
         return self.tokenizer._convert_id_to_token(_id)
10 changes: 8 additions & 2 deletions nlpaug/model/lang_models/roberta.py
@@ -16,7 +16,7 @@ class Roberta(LanguageModels):
     START_TOKEN = '<s>'
     SEPARATOR_TOKEN = '</s>'
     MASK_TOKEN = '<mask>'
-    PAD_TOKEN = '<pad>',
+    PAD_TOKEN = '<pad>'
     UNKNOWN_TOKEN = '<unk>'
     SUBWORD_PREFIX = 'Ġ'

@@ -48,7 +48,13 @@ def get_max_num_token(self):
         return self.model.config.max_position_embeddings - 2 * 5
 
     def token2id(self, token):
-        return self.tokenizer._convert_token_to_id(token)
+        # Issue 181: TokenizerFast has convert_tokens_to_ids but not _convert_token_to_id
+        if 'TokenizerFast' in self.tokenizer.__class__.__name__:
+            # New transformers API
+            return self.tokenizer.convert_tokens_to_ids(token)
+        else:
+            # Old transformers API
+            return self.tokenizer._convert_token_to_id(token)
 
     def id2token(self, _id):
         return self.tokenizer._convert_id_to_token(_id)
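Worth noting for the first roberta.py hunk above: the old PAD_TOKEN line ended with a stray comma, which in Python makes the constant a one-element tuple rather than a string. A tiny illustration (not from the repository):

PAD_TOKEN = '<pad>',   # evaluates to ('<pad>',), a 1-tuple
PAD_TOKEN = '<pad>'    # evaluates to '<pad>', the intended string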
8 changes: 7 additions & 1 deletion nlpaug/model/lang_models/xlnet.py
@@ -67,7 +67,13 @@ def get_max_num_token(self):
         return 500
 
     def token2id(self, token):
-        return self.tokenizer._convert_token_to_id(token)
+        # Issue 181: TokenizerFast has convert_tokens_to_ids but not _convert_token_to_id
+        if 'TokenizerFast' in self.tokenizer.__class__.__name__:
+            # New transformers API
+            return self.tokenizer.convert_tokens_to_ids(token)
+        else:
+            # Old transformers API
+            return self.tokenizer._convert_token_to_id(token)
 
     def id2token(self, _id):
         return self.tokenizer._convert_id_to_token(_id)
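To exercise the change end to end, a rough usage sketch (assumes this commit of nlpaug is installed alongside transformers 4.x and that bert-base-uncased can be downloaded; the augmenter calls token2id internally):

import nlpaug.augmenter.word as naw

aug = naw.ContextualWordEmbsAug(model_path='bert-base-uncased', action='substitute')
print(aug.augment('The quick brown fox jumps over the lazy dog'))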
