Skip to content

Commit

Permalink
Fix #196
Browse files Browse the repository at this point in the history
  • Loading branch information
makcedward committed Jan 4, 2021
1 parent 0a0ef60 commit b06037b
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 42 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ pip install librosa>=0.7.1 matplotlib

### 1.1.2dev, Dec, 2020
* Add NormalizeAug (audio) and PolarityInverseAug (audio)
* Fix [#191](https:/makcedward/nlpaug/issues/191), [#192](https:/makcedward/nlpaug/issues/192), [#194](https:/makcedward/nlpaug/issues/194)
* Fix [#191](https:/makcedward/nlpaug/issues/191), [#192](https:/makcedward/nlpaug/issues/192), [#194](https:/makcedward/nlpaug/issues/194), Fix [#196](https:/makcedward/nlpaug/issues/196)

See [changelog](https:/makcedward/nlpaug/blob/master/CHANGE.md) for more details.

Expand Down
38 changes: 20 additions & 18 deletions nlpaug/augmenter/sentence/context_word_embs_sentence.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
CONTEXT_WORD_EMBS_SENTENCE_MODELS = {}


def init_context_word_embs_sentence_model(model_path, device, force_reload=False, temperature=1.0, top_k=None,
top_p=None, optimize=None, silence=True):
def init_context_word_embs_sentence_model(model_path, model_type, device, force_reload=False, temperature=1.0,
top_k=None, top_p=None, optimize=None, silence=True):

global CONTEXT_WORD_EMBS_SENTENCE_MODELS

model_name = os.path.basename(model_path)
Expand All @@ -30,10 +31,10 @@ def init_context_word_embs_sentence_model(model_path, device, force_reload=False
CONTEXT_WORD_EMBS_SENTENCE_MODELS[model_name].silence = silence
return CONTEXT_WORD_EMBS_SENTENCE_MODELS[model_name]

if 'xlnet' in model_path:
if model_type == 'xlnet':
model = nml.XlNet(model_path, device=device, temperature=temperature, top_k=top_k, top_p=top_p,
optimize=optimize, silence=True)
elif 'gpt2' in model_path:
elif model_type == 'gpt2':
model = nml.Gpt2(model_path, device=device, temperature=temperature, top_k=top_k, top_p=top_p,
optimize=optimize, silence=True)
else:
Expand All @@ -50,6 +51,8 @@ class ContextualWordEmbsForSentenceAug(SentenceAugmenter):
:param str model_path: Model name or model path. It used transformers to load the model. Tested
'xlnet-base-cased', 'gpt2', 'distilgpt2'. If you want to reduce inference time, you may select `distilgpt2`.
:param str model_type: Type of model. For XLNet model, use 'xlnet'. For GPT2 or distilgpt2 model, use 'gpt'. If
no value is provided, will determine from model name.
:param float temperature: Controlling randomness. Default value is 1 and lower temperature results in less random
behavior
:param int top_k: Controlling lucky draw pool. Top k score token will be used for augmentation. Larger k, more
Expand All @@ -71,31 +74,30 @@ class ContextualWordEmbsForSentenceAug(SentenceAugmenter):
>>> aug = nas.ContextualWordEmbsForSentenceAug()
"""

def __init__(self, model_path='distilgpt2', temperature=1.0, top_k=100, top_p=None,
def __init__(self, model_path='distilgpt2', model_type='', temperature=1.0, top_k=100, top_p=None,
name='ContextualWordEmbsForSentence_Aug',
device='cpu', force_reload=False, optimize=None, verbose=0, silence=True):
super().__init__(
action=Action.INSERT, name=name, tokenizer=None, stopwords=None, device=device,
include_detail=False, parallelable=True, verbose=verbose)
self.model_path = model_path
self.model_type = model_type if model_type != '' else self.check_model_type()
self.temperature = temperature
self.top_k = top_k
self.top_p = top_p
self.silence = silence

self._init()
self.model = self.get_model(
model_path=model_path, device=device, force_reload=force_reload, temperature=temperature, top_k=top_k,
top_p=top_p, optimize=optimize, silence=silence)
model_path=model_path, model_type=self.model_type, device=device, force_reload=force_reload,
temperature=temperature, top_k=top_k, top_p=top_p, optimize=optimize, silence=silence)
self.device = self.model.device

def _init(self):
if 'xlnet' in self.model_path:
self.model_type = 'xlnet'
elif 'gpt2' in self.model_path:
self.model_type = 'gpt2'
else:
self.model_type = ''
def check_model_type(self):
if 'xlnet' in self.model_path.lower():
return 'xlnet'
elif 'gpt2' in self.model_path.lower():
return 'gpt2'
return ''

def insert(self, data):
if not data:
Expand Down Expand Up @@ -188,7 +190,7 @@ def insert(self, data):
return results[0]

@classmethod
def get_model(cls, model_path, device='cuda', force_reload=False, temperature=1.0, top_k=None, top_p=0.0,
def get_model(cls, model_path, model_type, device='cuda', force_reload=False, temperature=1.0, top_k=None, top_p=0.0,
optimize=None, silence=True):
return init_context_word_embs_sentence_model(model_path, device, force_reload, temperature, top_k, top_p,
optimize=optimize, silence=silence)
return init_context_word_embs_sentence_model(model_path, model_type, device, force_reload, temperature, top_k,
top_p, optimize=optimize, silence=silence)
48 changes: 25 additions & 23 deletions nlpaug/augmenter/word/context_word_embs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
CONTEXT_WORD_EMBS_MODELS = {}


def init_context_word_embs_model(model_path, device, force_reload=False, temperature=1.0, top_k=None, top_p=None,
def init_context_word_embs_model(model_path, model_type, device, force_reload=False, temperature=1.0, top_k=None, top_p=None,
optimize=None, silence=True):
global CONTEXT_WORD_EMBS_MODELS

Expand All @@ -28,13 +28,13 @@ def init_context_word_embs_model(model_path, device, force_reload=False, tempera
CONTEXT_WORD_EMBS_MODELS[model_name].silence = silence
return CONTEXT_WORD_EMBS_MODELS[model_name]

if 'distilbert' in model_path.lower():
if model_type == 'distilbert':
model = nml.DistilBert(model_path, device=device, temperature=temperature, top_k=top_k, top_p=top_p, silence=silence)
elif 'roberta' in model_path.lower():
elif model_type == 'roberta':
model = nml.Roberta(model_path, device=device, temperature=temperature, top_k=top_k, top_p=top_p, silence=silence)
elif 'bert' in model_path.lower():
elif model_type == 'bert':
model = nml.Bert(model_path, device=device, temperature=temperature, top_k=top_k, top_p=top_p, silence=silence)
elif 'xlnet' in model_path.lower():
elif model_type == 'xlnet':
model = nml.XlNet(model_path, device=device, temperature=temperature, top_k=top_k, top_p=top_p, optimize=optimize,
silence=silence)
else:
Expand All @@ -43,7 +43,6 @@ def init_context_word_embs_model(model_path, device, force_reload=False, tempera
CONTEXT_WORD_EMBS_MODELS[model_name] = model
return model


class ContextualWordEmbsAug(WordAugmenter):
# https://arxiv.org/pdf/1805.06201.pdf, https://arxiv.org/pdf/2003.02245.pdf
"""
Expand All @@ -52,6 +51,9 @@ class ContextualWordEmbsAug(WordAugmenter):
:param str model_path: Model name or model path. It used transformers to load the model. Tested
'bert-base-uncased', 'bert-base-cased', 'distilbert-base-uncased', 'roberta-base', 'distilroberta-base',
'xlnet-base-cased'.
:param str model_type: Type of model. For BERT model, use 'bert'. For XLNet model, use 'xlnet'. For DistilBERT
model, use 'distilbert'. For RoBERTa model, use 'roberta'. If no value is provided, will determine from
model name.
:param str action: Either 'insert or 'substitute'. If value is 'insert', a new word will be injected to random
position according to contextual word embeddings calculation. If value is 'substitute', word will be replaced
according to contextual embeddings calculation
Expand Down Expand Up @@ -82,7 +84,7 @@ class ContextualWordEmbsAug(WordAugmenter):
>>> aug = naw.ContextualWordEmbsAug()
"""

def __init__(self, model_path='bert-base-uncased', action="substitute", temperature=1.0, top_k=100, top_p=None,
def __init__(self, model_path='bert-base-uncased', model_type='', action="substitute", temperature=1.0, top_k=100, top_p=None,
name='ContextualWordEmbs_Aug', aug_min=1, aug_max=10, aug_p=0.3, stopwords=None,
device='cpu', force_reload=False, optimize=None, stopwords_regex=None,
verbose=0, silence=True,):
Expand All @@ -91,15 +93,15 @@ def __init__(self, model_path='bert-base-uncased', action="substitute", temperat
device=device, stopwords=stopwords, verbose=verbose, stopwords_regex=stopwords_regex,
include_detail=False, parallelable=True)
self.model_path = model_path
self.model_type = model_type if model_type != '' else self.check_model_type()
self.temperature = temperature
self.top_k = top_k
self.top_p = top_p
self.silence = silence

self._init()
self.model = self.get_model(
model_path=model_path, device=device, force_reload=force_reload, temperature=temperature, top_k=top_k,
top_p=top_p, optimize=optimize, silence=silence)
model_path=model_path, model_type=self.model_type, device=device, force_reload=force_reload, temperature=temperature,
top_k=top_k, top_p=top_p, optimize=optimize, silence=silence)
# Override stopwords
if stopwords is not None and self.model_type in ['xlnet', 'roberta']:
stopwords = [self.stopwords]
Expand All @@ -111,17 +113,16 @@ def __init__(self, model_path='bert-base-uncased', action="substitute", temperat
"""
self.max_num_token = self.model.get_max_num_token()

def _init(self):
if 'xlnet' in self.model_path:
self.model_type = 'xlnet'
elif 'distilbert' in self.model_path:
self.model_type = 'distilbert'
elif 'roberta' in self.model_path:
self.model_type = 'roberta'
elif 'bert' in self.model_path:
self.model_type = 'bert'
else:
self.model_type = ''
def check_model_type(self):
if 'xlnet' in self.model_path.lower():
return 'xlnet'
elif 'distilbert' in self.model_path.lower():
return 'distilbert'
elif 'roberta' in self.model_path.lower():
return 'roberta'
elif 'bert' in self.model_path.lower():
return 'bert'
return ''

def is_stop_words(self, token):
if self.model_type in ['bert', 'distilbert']:
Expand Down Expand Up @@ -449,6 +450,7 @@ def substitute(self, data):
return augmented_texts[0]

@classmethod
def get_model(cls, model_path, device='cuda', force_reload=False, temperature=1.0, top_k=None, top_p=0.0,
def get_model(cls, model_path, model_type, device='cuda', force_reload=False, temperature=1.0, top_k=None, top_p=0.0,
optimize=None, silence=True):
return init_context_word_embs_model(model_path, device, force_reload, temperature, top_k, top_p, optimize, silence)
return init_context_word_embs_model(model_path, model_type, device, force_reload, temperature, top_k, top_p,
optimize, silence)
5 changes: 5 additions & 0 deletions test/augmenter/word/test_context_word_embs.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ def test_fast_tokenizer(self):
aug.augment("Мозг — это машина которая пытается снизить ошибку в прогнозе.")
self.assertTrue(True)

def test_model_type(self):
aug = naw.ContextualWordEmbsAug(model_path="blinoff/roberta-base-russian-v0", model_type='roberta', force_reload=True)
aug.augment("Мозг — это машина которая пытается снизить ошибку в прогнозе.")
self.assertTrue(True)

def test_contextual_word_embs(self):
# self.execute_by_device('cuda')
self.execute_by_device('cpu')
Expand Down

0 comments on commit b06037b

Please sign in to comment.