-
Notifications
You must be signed in to change notification settings - Fork 3
/
word_level_augment.py
265 lines (224 loc) · 8.11 KB
/
word_level_augment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
# coding=utf-8
# Copyright 2019 The Google UDA Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Word level augmentations including Replace words with uniform random words or TF-IDF based word replacement.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
import json
import math
import string
from absl import flags
import numpy as np
import logging
# absl flag container; the actual flags are defined by the entry-point script.
FLAGS = flags.FLAGS
# Set of ASCII-printable characters used to sanitize strings before logging.
printable = set(string.printable)


def filter_unicode(st):
  """Return *st* with every non-ASCII-printable character removed."""
  return "".join(ch for ch in st if ch in printable)
class EfficientRandomGen(object):
  """A base class that generate multiple random numbers at the same time.

  Drawing numpy random values one at a time is slow, so a large batch is
  cached and consumed back-to-front; the caches refill automatically when
  exhausted. Subclasses must provide `reset_token_list` (and thereby
  `token_list` / `token_ptr`) before calling `get_random_token`.
  """

  def reset_random_prob(self):
    """Generate many random numbers at the same time and cache them."""
    batch_size = 100000
    self.random_prob_cache = np.random.random(size=(batch_size,))
    self.random_prob_ptr = batch_size - 1

  def get_random_prob(self):
    """Get a random number."""
    ptr = self.random_prob_ptr
    prob = self.random_prob_cache[ptr]
    # Refill once the cache is exhausted; otherwise just step the cursor.
    if ptr == 0:
      self.reset_random_prob()
    else:
      self.random_prob_ptr = ptr - 1
    return prob

  def get_random_token(self):
    """Get a random token."""
    ptr = self.token_ptr
    tok = self.token_list[ptr]
    # Resample a fresh token cache (subclass hook) when this one runs out.
    if ptr == 0:
      self.reset_token_list()
    else:
      self.token_ptr = ptr - 1
    return tok
class UnifRep(EfficientRandomGen):
  """Uniformly replace word with random words in the vocab."""

  def __init__(self, token_prob, vocab):
    # token_prob: per-token probability of being replaced.
    # vocab: mapping whose keys are the candidate replacement tokens.
    self.token_prob = token_prob
    self.vocab_size = len(vocab)
    self.vocab = vocab
    self.reset_token_list()
    self.reset_random_prob()

  def __call__(self, example):
    # Augments word_list_a (and word_list_b when text_b is present) in place.
    example.word_list_a = self.replace_tokens(example.word_list_a)
    if example.text_b:
      example.word_list_b = self.replace_tokens(example.word_list_b)
    return example

  def replace_tokens(self, tokens):
    """Replace tokens randomly."""
    # Very short sentences (< 3 tokens) are left untouched.
    if len(tokens) >= 3:
      # Log roughly one in a thousand sentences for spot-checking.
      show_example = np.random.random() < 0.001
      if show_example:
        logging.info("before augment: {:s}".format(
            filter_unicode(" ".join(tokens))))
      for i in range(len(tokens)):
        if self.get_random_prob() < self.token_prob:
          tokens[i] = self.get_random_token()
      if show_example:
        logging.info("after augment: {:s}".format(
            filter_unicode(" ".join(tokens))))
    return tokens

  def reset_token_list(self):
    """Generate many random tokens at the same time and cache them."""
    # Bug fix: dict.keys() returns a view in Python 3, and
    # np.random.shuffle requires a mutable sequence — materialize a list.
    self.token_list = list(self.vocab.keys())
    self.token_ptr = len(self.token_list) - 1
    np.random.shuffle(self.token_list)
def Convert(string):
  """Split a space-delimited string into a list of tokens.

  NOTE(review): the parameter name shadows the imported `string` module;
  kept unchanged for backward compatibility with keyword callers.
  """
  # str.split already returns a list; the previous list(...) wrapper was
  # redundant.
  return string.split(" ")
def get_data_stats(examples):
  """Compute the IDF score for each word. Then compute the TF-IDF score.

  Args:
    examples: objects with a `text_a` string and an optional `text_b`
      (falsy when absent); both are tokenized by splitting on single spaces.

  Returns:
    dict with keys:
      "idf": word -> log(num_examples / doc_freq(word)).
      "tf_idf": word -> TF-IDF mass accumulated over the whole corpus
        (term frequency normalized by sentence length, times IDF).
  """
  word_doc_freq = collections.defaultdict(int)
  # Compute IDF: count each word at most once per example.
  for example in examples:
    cur_sent = example.text_a.split(" ")
    if example.text_b:
      cur_sent += example.text_b.split(" ")
    for word in set(cur_sent):
      word_doc_freq[word] += 1
  idf = {}
  for word in word_doc_freq:
    idf[word] = math.log(len(examples) * 1. / word_doc_freq[word])
  # Compute TF-IDF, accumulated across all examples.
  tf_idf = {}
  for example in examples:
    cur_sent = example.text_a.split(" ")
    if example.text_b:
      cur_sent += example.text_b.split(" ")
    for word in cur_sent:
      if word not in tf_idf:
        tf_idf[word] = 0
      tf_idf[word] += 1. / len(cur_sent) * idf[word]
  return {
      "idf": idf,
      "tf_idf": tf_idf,
  }
class TfIdfWordRep(EfficientRandomGen):
  """TF-IDF Based Word Replacement.

  Low-TF-IDF (uninformative) words in a sentence are replaced with higher
  probability, and replacement tokens are sampled with probability
  proportional to max(tf_idf) - tf_idf, so frequent/uninformative words
  are drawn more often as substitutes.
  """

  def __init__(self, token_prob, data_stats):
    # token_prob: target average per-token replacement probability.
    # data_stats: dict with "idf" and "tf_idf" maps (see get_data_stats).
    super(TfIdfWordRep, self).__init__()
    self.token_prob = token_prob
    self.data_stats = data_stats
    self.idf = data_stats["idf"]
    self.tf_idf = data_stats["tf_idf"]
    # Copy before sorting so the caller's dict is never mutated.
    data_stats = copy.deepcopy(data_stats)
    tf_idf_items = data_stats["tf_idf"].items()
    tf_idf_items = sorted(tf_idf_items, key=lambda item: -item[1])
    self.tf_idf_keys = []
    self.tf_idf_values = []
    for key, value in tf_idf_items:
      self.tf_idf_keys += [key]
      self.tf_idf_values += [value]
    # Sampling distribution over replacement tokens: inverted TF-IDF,
    # normalized to sum to 1 (lower-scoring words get more mass).
    self.normalized_tf_idf = np.array(self.tf_idf_values)
    self.normalized_tf_idf = (self.normalized_tf_idf.max()
                              - self.normalized_tf_idf)
    self.normalized_tf_idf = (self.normalized_tf_idf
                              / self.normalized_tf_idf.sum())
    self.reset_token_list()
    self.reset_random_prob()

  def get_replace_prob(self, all_words):
    """Compute the probability of replacing tokens in a sentence."""
    # Per-sentence TF-IDF score for each word occurring in this sentence.
    cur_tf_idf = collections.defaultdict(int)
    for word in all_words:
      cur_tf_idf[word] += 1. / len(all_words) * self.idf[word]
    replace_prob = []
    for word in all_words:
      replace_prob += [cur_tf_idf[word]]
    # Invert: the lowest-scoring word gets the highest replacement chance.
    replace_prob = np.array(replace_prob)
    replace_prob = np.max(replace_prob) - replace_prob
    # Scale so the expected number of replacements is roughly
    # token_prob * len(all_words).
    replace_prob = (replace_prob / replace_prob.sum() *
                    self.token_prob * len(all_words))
    return replace_prob

  def __call__(self, example):
    # Augments the example; note text_a/text_b are rebound from strings
    # to token lists by replace_tokens.
    # Log roughly one in a thousand examples for spot-checking.
    if self.get_random_prob() < 0.001:
      show_example = True
    else:
      show_example = False
    all_words = copy.deepcopy(Convert(example.text_a))
    if example.text_b:
      all_words += Convert(example.text_b)
    if show_example:
      logging.info("before tf_idf_unif aug: {:s}".format(
          filter_unicode(" ".join(all_words))))
    # replace_prob is indexed over text_a tokens followed by text_b tokens.
    replace_prob = self.get_replace_prob(all_words)
    example.text_a = self.replace_tokens(
        Convert(example.text_a),
        replace_prob[:len(Convert(example.text_a))]
    )
    if example.text_b:
      # NOTE: example.text_a is already a token list here, so
      # len(example.text_a) equals the token count of the original text_a
      # and the slice offset lines up with get_replace_prob's order.
      example.text_b = self.replace_tokens(
          Convert(example.text_b),
          replace_prob[len(example.text_a):]
      )
    if show_example:
      all_words = copy.deepcopy(example.text_a)
      if example.text_b:
        all_words += example.text_b
      logging.info("after tf_idf_unif aug: {:s}".format(
          filter_unicode(" ".join(all_words))))
    return example

  def replace_tokens(self, word_list, replace_prob):
    """Replace tokens in a sentence."""
    # Each token i is swapped for a sampled token with prob replace_prob[i].
    for i in range(len(word_list)):
      if self.get_random_prob() < replace_prob[i]:
        word_list[i] = self.get_random_token()
    return word_list

  def reset_token_list(self):
    # Pre-sample a cache of replacement tokens (with replacement) from the
    # inverted-TF-IDF distribution; consumed by get_random_token.
    cache_len = len(self.tf_idf_keys)
    token_list_idx = np.random.choice(
        cache_len, (cache_len,), p=self.normalized_tf_idf)
    self.token_list = []
    for idx in token_list_idx:
      self.token_list += [self.tf_idf_keys[idx]]
    self.token_ptr = len(self.token_list) - 1
    logging.info("sampled token list: {:s}".format(
        filter_unicode(" ".join(self.token_list))))
def word_level_augment(
    examples, aug_ops, vocab, data_stats):
  """Word level augmentations. Used before augmentation.

  Args:
    examples: list of examples to augment in place.
    aug_ops: op spec string like "unif-0.1" or "tf_idf-0.2"; falsy or
      unrecognized specs leave the examples untouched.
    vocab: replacement vocabulary (used by the "unif" op).
    data_stats: IDF/TF-IDF statistics (used by the "tf_idf" op).

  Returns:
    The (possibly augmented) examples list.
  """
  if not aug_ops:
    return examples
  op = None
  if aug_ops.startswith("unif"):
    logging.info("\n>>Using augmentation {}".format(aug_ops))
    op = UnifRep(float(aug_ops.split("-")[1]), vocab)
  elif aug_ops.startswith("tf_idf"):
    logging.info("\n>>Using augmentation {}".format(aug_ops))
    op = TfIdfWordRep(float(aug_ops.split("-")[1]), data_stats)
  if op is not None:
    for idx, example in enumerate(examples):
      examples[idx] = op(example)
  return examples