-
Notifications
You must be signed in to change notification settings - Fork 60
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add new reader and test script for SQUAD dataset * Add squad MRC dataset parser, including new ontology and testcase * format with black; revise annotations * add squad data example * add new word to spelling checklist * replace MRCanswer with Phrase; substitute Passage with Document Co-authored-by: Hector <[email protected]>
- Loading branch information
1 parent
5913ef6
commit a365f76
Showing
8 changed files
with
253 additions
and
0 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -138,3 +138,4 @@ embeddings | |
docstrings | ||
numpy | ||
jsonpickle | ||
crowdworkers |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright 2019 The Forte Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
# Copyright 2019 The Forte Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import os | ||
import json | ||
from typing import Any, Iterator, Dict, Set, Tuple | ||
|
||
from forte.data.data_pack import DataPack | ||
from forte.data.base_reader import PackReader | ||
from ft.onto.base_ontology import Document, MRCQuestion, Phrase | ||
from ftx.onto.race_qa import Passage | ||
|
||
__all__ = [ | ||
"SquadReader", | ||
] | ||
|
||
|
||
class SquadReader(PackReader):
    r"""Reader for the Stanford Question Answering Dataset (SQuAD).

    SQuAD is a reading-comprehension dataset consisting of questions posed
    by crowdworkers on a set of Wikipedia articles, where the answer to
    every question is a segment of text (a span) of the corresponding
    article. The dataset can be downloaded at
    https://rajpurkar.github.io/SQuAD-explorer/.

    ``SquadReader`` reads each paragraph of the dataset into a separate
    data pack; the questions are concatenated after the paragraph text.
    The paragraph is annotated as a ``Passage`` (and the whole text as a
    ``Document``), each question as an ``MRCQuestion``, and each answer
    as a ``Phrase`` span attached to its question.
    """

    def _collect(self, file_path: str) -> Iterator[Any]:  # type: ignore
        r"""Iterate over every paragraph in the SQuAD JSON file.

        Args:
            file_path: Path to the SQuAD JSON file.

        Yields:
            ``(title + paragraph_index, qas, context)`` tuples — a unique
            pack name, the list of question/answer dicts, and the
            paragraph's context string.
        """
        with open(file_path, "r", encoding="utf8", errors="ignore") as file:
            jsonf = json.load(file)
            for dic in jsonf["data"]:
                title = dic["title"]
                # Enumerate paragraphs so each pack name is unique
                # within an article.
                for cnt, qa_dic in enumerate(dic["paragraphs"]):
                    yield title + str(cnt), qa_dic["qas"], qa_dic["context"]

    def _cache_key_function(self, text_file: str) -> str:
        # The cache key is simply the file's base name.
        return os.path.basename(text_file)

    def _parse_pack(
        self, qa_dict: Tuple[str, list, str]
    ) -> Iterator[DataPack]:
        r"""Build one ``DataPack`` from a ``(name, qas, context)`` triple.

        The pack text is the paragraph context followed by each answerable
        question on its own line. Answer spans index into the context,
        which occupies the beginning of the pack text.

        Args:
            qa_dict: Tuple of pack name, question/answer dicts and the
                paragraph context, as produced by :meth:`_collect`.

        Yields:
            A single annotated ``DataPack`` for the paragraph.
        """
        title, qas, context = qa_dict
        context_end = len(context)
        # Questions are appended after the context plus a "\n" separator.
        offset = context_end + 1
        text = context

        pack = DataPack()  # one data pack per paragraph context
        for qa in qas:
            # SQuAD v2.0 marks unanswerable questions with "is_impossible";
            # v1.x files lack the key entirely, so default to answerable
            # (the original `qa["is_impossible"]` raised KeyError on v1.x).
            if qa.get("is_impossible", False):
                continue
            ques_text = qa["question"]
            text += "\n" + ques_text
            ques_end = offset + len(ques_text)
            question = MRCQuestion(pack, offset, ques_end)
            question.qid = qa["id"]
            offset = ques_end + 1
            for ans in qa["answers"]:
                ans_text = ans["text"]
                # "answer_start" is a character offset into the context,
                # which starts at position 0 of the pack text.
                ans_start = ans["answer_start"]
                answer = Phrase(pack, ans_start, ans_start + len(ans_text))
                question.answers.append(answer)

        pack.set_text(text)

        passage = Passage(pack, 0, context_end)
        Document(pack, 0, len(pack.text))

        passage.passage_id = title
        pack.pack_name = title
        yield pack

    @classmethod
    def default_configs(cls):
        # NOTE(review): SQuAD ships as JSON; ".txt" looks inherited from a
        # plain-text reader — confirm whether ".json" is intended here.
        return {"file_ext": ".txt"}

    def record(self, record_meta: Dict[str, Set[str]]):
        r"""Add the output type record of `SquadReader`, which is
        `ft.onto.base_ontology.Document` with an empty set, to
        :attr:`forte.data.data_pack.Meta.record`.

        Args:
            record_meta: The field in the data pack for type records that
                needs to be filled in for consistency checking.
        """
        record_meta["ft.onto.base_ontology.Document"] = set()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
# Copyright 2019 The Forte Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
""" | ||
Unit tests for SquadReader. | ||
""" | ||
import json | ||
import os | ||
import unittest | ||
from typing import Iterable | ||
|
||
from forte.data.data_pack import DataPack | ||
from forte.datasets.mrc.squad_reader import SquadReader | ||
from forte.pipeline import Pipeline | ||
from ft.onto.base_ontology import MRCQuestion | ||
from ftx.onto.race_qa import Passage | ||
|
||
|
||
class SquadReaderTest(unittest.TestCase):
    """Unit tests for :class:`SquadReader` on the bundled SQuAD sample."""

    def setUp(self):
        # Sample file lives four directories above this test module.
        base = os.path.dirname(os.path.realpath(__file__))
        self.dataset_path = os.path.abspath(
            os.path.join(
                base,
                *([os.path.pardir] * 4),
                "data_samples/squad_v2.0/dev-v2.0-sample.json"
            )
        )

    def test_reader_no_replace_test(self):
        # Read the sample dataset with no replacements applied.
        pipeline = Pipeline()
        pipeline.set_reader(SquadReader())
        pipeline.initialize()

        data_packs: Iterable[DataPack] = pipeline.process_dataset(
            self.dataset_path
        )

        # Rebuild the expected {pack_name: paragraph_dict} mapping
        # directly from the JSON file.
        with open(
            self.dataset_path, "r", encoding="utf8", errors="ignore"
        ) as file:
            raw = json.load(file)
        expected_file_dict = {}
        for article in raw["data"]:
            title = article["title"]
            for idx, paragraph in enumerate(article["paragraphs"]):
                # paragraph holds both "qas" and "context".
                expected_file_dict[title + str(idx)] = paragraph

        count_packs = 0
        for pack in data_packs:
            count_packs += 1
            expected = expected_file_dict[pack.pack_name]
            expected_context = expected["context"]

            # Exactly one Passage per pack, covering the context.
            passages = list(pack.get(Passage))
            self.assertEqual(len(passages), 1)
            self.assertEqual(passages[0].text, expected_context)

            expected_text: str = expected_context
            for idx, question in enumerate(pack.get(MRCQuestion)):
                qa = expected["qas"][idx]
                self.assertEqual(question.text, qa["question"])

                answers = qa["answers"]
                if not isinstance(answers, list):
                    answers = [answers]
                for got, want in zip(question.answers, answers):
                    self.assertEqual(got.text, want["text"])

                expected_text += "\n" + qa["question"]

            # Pack text is the context followed by every question.
            self.assertEqual(pack.text, expected_text)


if __name__ == "__main__":
    unittest.main()