-
Notifications
You must be signed in to change notification settings - Fork 0
/
create.py
48 lines (34 loc) · 1.61 KB
/
create.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from argparse import ArgumentParser
import csv
import datasets
from tqdm import tqdm
def _label_number(example):
    """Return the 1-4 label number for a single ASNQ example.

    Computed as ``sentence_in_long_answer * 2 + short_answer_in_sentence + 1``.
    Assuming both fields are 0/1 booleans (TODO confirm against the dataset
    schema), this yields 1 when neither flag is set and 4 when both are set.
    """
    return example['sentence_in_long_answer'] * 2 + example['short_answer_in_sentence'] + 1


def _load_filters(path):
    """Return the set of stripped lines in the file at *path*, or None if *path* is None."""
    if path is None:
        return None
    # A set gives O(1) membership tests in the per-example loop; the original
    # list made every lookup a linear scan over all filter lines.
    with open(path, encoding="utf-8") as fi:
        return {line.strip() for line in fi}


def main(args):
    """Create ASNQ-challenging dataset.

    Loads the ``asnq`` dataset with ``datasets.load_dataset`` and writes one
    tab-separated row per kept example of ``args.split``:
    ``[question_id, question, sentence, is_positive]``. ``question_id`` is a
    dense integer id assigned per distinct question in order of first
    appearance. ``is_positive`` is 1 when the label number is 4, else 0.

    Args:
        args: parsed CLI namespace with attributes
            ``split`` (dataset split name),
            ``output_file`` (path of the TSV to write),
            ``filter`` (optional path to a file of questions, one per line;
            only examples whose stripped question appears there are kept), and
            ``reduced`` (bool; when set, rows with label number 1 are dropped).
    """
    asnq = datasets.load_dataset('asnq')
    split = asnq[args.split]
    filters = _load_filters(args.filter)

    question_ids = {}
    # newline="" is required by the csv module so that its "\r\n" line
    # terminator is not translated again (e.g. into "\r\r\n" on Windows).
    with open(args.output_file, "w", newline="", encoding="utf-8") as fo:
        writer = csv.writer(fo, delimiter="\t", quoting=csv.QUOTE_MINIMAL, quotechar='"')
        for example in tqdm(split, total=len(split), desc="Processing..."):
            question = example['question']
            # Filter lines were stripped, so compare against the stripped question.
            if filters is not None and question.strip() not in filters:
                continue
            label_number = _label_number(example)
            if args.reduced and label_number == 1:
                continue
            # setdefault evaluates len() before inserting, so new questions get
            # the next unused id and repeated questions keep their first id.
            question_id = question_ids.setdefault(question, len(question_ids))
            writer.writerow([question_id, question, example['sentence'], int(label_number == 4)])
if __name__ == "__main__":
    # Command-line entry point: parse the options and hand the namespace to main().
    cli = ArgumentParser()
    cli.add_argument('--split', type=str, choices=['train', 'validation'], default='train', required=False)
    cli.add_argument('--output_file', type=str, required=True)
    cli.add_argument('--filter', type=str, default=None, required=False)
    cli.add_argument('--reduced', action="store_true")
    main(cli.parse_args())