Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates AntonymAug to correct augmentation word selection criteria. #167

Merged
merged 2 commits into from
Oct 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 38 additions & 6 deletions nlpaug/augmenter/word/antonym.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ def __init__(self, name='Antonym_Aug', aug_min=1, aug_max=10, aug_p=0.3, lang='e
def skip_aug(self, token_idxes, tokens):
results = []
for token_idx in token_idxes:
# Based on https://arxiv.org/pdf/1809.02079.pdf for Antonyms,
# We choose only tokens which are Verbs, Adjectives, Adverbs
if tokens[token_idx][1] not in ['VB', 'VBD', 'VBZ', 'VBG', 'VBN', 'VBP',
'JJ', 'JJR', 'JJS',
'RB', 'RBR', 'RBS']:
continue

# Some word does not come with synonym/ antony. It will be excluded in lucky draw.
if tokens[token_idx][1] in ['DT']:
continue
Expand All @@ -66,9 +73,28 @@ def _get_aug_idxes(self, tokens):
code=WarningCode.WARNING_CODE_002, msg=WarningMessage.NO_WORD)
exception.output()
return None
if len(word_idxes) < aug_cnt:
aug_cnt = len(word_idxes)
aug_idexes = self.sample(word_idxes, aug_cnt)

aug_idexes = []
for aug_idx in word_idxes:
word_poses = PartOfSpeech.constituent2pos(tokens[aug_idx][1])
candidates = []
if word_poses is None or len(word_poses) == 0:
# Use every possible words as the mapping does not defined correctly
candidates.extend(self.model.predict(tokens[aug_idx][0]))
else:
for word_pos in word_poses:
candidates.extend(self.model.predict(tokens[aug_idx][0], pos=word_pos))

candidates = [c for c in candidates if c.lower() != tokens[aug_idx][0].lower()]

if len(candidates) > 0:
candidate = self.sample(candidates, 1)[0]
aug_idexes.append((aug_idx, candidate))

if len(aug_idexes) < aug_cnt:
aug_cnt = len(aug_idexes)

aug_idexes = self.sample(aug_idexes, aug_cnt)
return aug_idexes

def get_candidates(self, tokens, token_idx):
Expand All @@ -91,7 +117,13 @@ def substitute(self, data):

pos = self.model.pos_tag(doc.get_original_tokens())

aug_idxes = self._get_aug_idxes(pos)
aug_candidates = self._get_aug_idxes(pos)
if aug_candidates is None or len(aug_candidates) == 0:
if self.include_detail:
return data, []
return data

aug_idxes, candidates = zip(*aug_candidates)
if aug_idxes is None or len(aug_idxes) == 0:
if self.include_detail:
return data, []
Expand All @@ -101,14 +133,14 @@ def substitute(self, data):
# Skip if no augment for word
if aug_idx not in aug_idxes:
continue

candidates = self.get_candidates(pos, aug_idx)

if len(candidates) > 0:
candidate = self.sample(candidates, 1)[0]
candidate = candidate.replace("_", " ").replace("-", " ").lower()
substitute_token = self.align_capitalization(original_token, candidate)

if aug_idx == 0:
substitute_token = self.align_capitalization(original_token, substitute_token)

Expand Down
2 changes: 1 addition & 1 deletion nlpaug/model/word_dict/wordnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, lang, is_synonym=True):
import nltk
from nltk.corpus import wordnet
except ModuleNotFoundError:
raise ModuleNotFoundError('Missed nltk library. Install transfomers by `pip install nltk`')
raise ModuleNotFoundError('Missed nltk library. Install nltk by `pip install nltk`')

try:
# Check whether wordnet package is downloaded
Expand Down
11 changes: 11 additions & 0 deletions test/augmenter/word/test_antonym.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def setUpClass(cls):

def test_substitute(self):
texts = [
'Older people feel more youthful when they also feel in control.',
'Good bad',
'Heart patients may benefit more from exercise than healthy people.',
'Beer first or wine, either way might not be fine.'
Expand All @@ -29,6 +30,16 @@ def test_substitute(self):
augmented_text = aug.augment(text)
self.assertNotEqual(text, augmented_text)

def test_unable_to_substitute(self):
texts = [
'Insomnia, sleep apnea diagnoses up sharply in U.S. Army.'
]

for aug in self.augs:
for text in texts:
augmented_text = aug.augment(text)
self.assertEqual(text, augmented_text)

def test_skip_punctuation(self):
text = '. . . . ! ? # @'

Expand Down