Skip to content

Commit

Permalink
merge pull 120
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhitingHu committed Apr 5, 2019
2 parents d7a7385 + 5cc5126 commit 28f9172
Show file tree
Hide file tree
Showing 9 changed files with 34 additions and 42 deletions.
4 changes: 2 additions & 2 deletions examples/bert/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ Run the following cmd to this end:
python prepare_data.py --task=MRPC
[--max_seq_length=128]
[--vocab_file=bert_pretrained_models/uncased_L-12_H-768_A-12/vocab.txt]
[--tfrecords_output_dir=data/MRPC]
[--tfrecord_output_dir=data/MRPC]
```
- `task`: Specifies the dataset name to preprocess. BERT provides default support for `{'CoLA', 'MNLI', 'MRPC', 'XNLI', 'SST'}` data.
- `max_seq_length`: The maxium length of sequence. This includes BERT special tokens that will be automatically added. Longer sequence will be trimmed.
- `vocab_file`: Path to a vocabary file used for tokenization.
- `tfrecords_output_dir`: The output path where the resulting TFRecord files will be put in. Be default, it is set to `data/{task}` where `{task}` is the (upper-cased) dataset name specified in `--task` above. So in the above cmd, the TFRecord files are output to `data/MRPC`.
- `tfrecord_output_dir`: The output path where the resulting TFRecord files will be put in. Be default, it is set to `data/{task}` where `{task}` is the (upper-cased) dataset name specified in `--task` above. So in the above cmd, the TFRecord files are output to `data/MRPC`.

**Outcome of the Preprocessing**:
- The preprocessing will output 3 TFRecord data files `{train.tf_record, eval.tf_record, test.tf_record}` in the specified output directory.
Expand Down
20 changes: 11 additions & 9 deletions examples/bert/prepare_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Produces TFRecords files and modifies data configuration file
"""Produces TFRecord files and modifies data configuration file
"""

from __future__ import absolute_import
Expand Down Expand Up @@ -41,8 +41,10 @@
"max_seq_length", 128,
"The maxium length of sequence, longer sequence will be trimmed.")
flags.DEFINE_string(
"tfrecords_output_dir", 'data/MRPC',
"The output directory where the TFRecords files will be generated.")
"tfrecord_output_dir", None,
"The output directory where the TFRecord files will be generated. "
"By default it will be set to 'data/{task}'. E.g.: if "
"task is 'MRPC', it will be set as 'data/MRPC'")
flags.DEFINE_bool(
"do_lower_case", True,
"Whether to lower case the input text. Should be True for uncased "
Expand All @@ -68,11 +70,11 @@ def prepare_data():
data_dir = 'data/{}'.format(
task_datasets_rename[FLAGS.task])

if FLAGS.tfrecords_output_dir is None:
tfrecords_output_dir = data_dir
if FLAGS.tfrecord_output_dir is None:
tfrecord_output_dir = data_dir
else:
tfrecords_output_dir = FLAGS.tfrecords_output_dir
tx.utils.maybe_create_dir(tfrecords_output_dir)
tfrecord_output_dir = FLAGS.tfrecord_output_dir
tx.utils.maybe_create_dir(tfrecord_output_dir)

processors = {
"COLA": data_utils.ColaProcessor,
Expand All @@ -91,13 +93,13 @@ def prepare_data():
vocab_file=FLAGS.vocab_file,
do_lower_case=FLAGS.do_lower_case)

# Produces TFRecords files
# Produces TFRecord files
data_utils.prepare_TFRecord_data(
processor=processor,
tokenizer=tokenizer,
data_dir=data_dir,
max_seq_length=FLAGS.max_seq_length,
output_dir=tfrecords_output_dir)
output_dir=tfrecord_output_dir)
modify_config_data(FLAGS.max_seq_length, num_train_data, num_classes)

def modify_config_data(max_seq_length, num_train_data, num_classes):
Expand Down
2 changes: 1 addition & 1 deletion examples/bert/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ def prepare_TFRecord_data(processor, tokenizer,
max_seq_length: Max sequence length.
batch_size: mini-batch size.
model: `train`, `eval` or `test`.
output_dir: The directory to save the TFRecords in.
output_dir: The directory to save the TFRecord in.
"""
label_list = processor.get_labels()

Expand Down
2 changes: 1 addition & 1 deletion texar/data/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@
from texar.data.data.multi_aligned_data import *
from texar.data.data.data_iterators import *
from texar.data.data.dataset_utils import *
from texar.data.data.tfrecords_data import *
from texar.data.data.tfrecord_data import *
6 changes: 3 additions & 3 deletions texar/data/data/multi_aligned_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@
from texar.utils.dtypes import is_str, is_callable
from texar.data.data.text_data_base import TextDataBase
from texar.data.data.scalar_data import ScalarData
from texar.data.data.tfrecords_data import TFRecordData
from texar.data.data.tfrecord_data import TFRecordData
from texar.data.data.mono_text_data import _default_mono_text_dataset_hparams
from texar.data.data.scalar_data import _default_scalar_dataset_hparams
from texar.data.data.tfrecords_data import _default_tfrecord_dataset_hparams
from texar.data.data.tfrecord_data import _default_tfrecord_dataset_hparams
from texar.data.data.mono_text_data import MonoTextData
from texar.data.data_utils import count_file_lines
from texar.data.data import dataset_utils as dsutils
Expand Down Expand Up @@ -132,7 +132,7 @@ class MultiAlignedData(TextDataBase):
'datasets': [
{'files': 'd.txt', 'vocab_file': 'v.d', 'data_name': 'm'},
{
'files': 'd.tfrecords',
'files': 'd.tfrecord',
'data_type': 'tf_record',
"feature_original_types": {
'image': ['tf.string', 'FixedLenFeature']
Expand Down
12 changes: 6 additions & 6 deletions texar/data/data/multi_aligned_data_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,15 @@ def _int64_feature(value):
feature = {
"number1": _int64_feature(128),
"number2": _int64_feature(512),
"text": _bytes_feature("This is a sentence for TFRecords 词 词 。")
"text": _bytes_feature("This is a sentence for TFRecord 词 词 。")
}
data_example = tf.train.Example(
features=tf.train.Features(feature=feature))
tfrecords_file = tempfile.NamedTemporaryFile(suffix=".tfrecords")
with tf.python_io.TFRecordWriter(tfrecords_file.name) as writer:
tfrecord_file = tempfile.NamedTemporaryFile(suffix=".tfrecord")
with tf.python_io.TFRecordWriter(tfrecord_file.name) as writer:
writer.write(data_example.SerializeToString())
tfrecords_file.flush()
self._tfrecords_file = tfrecords_file
tfrecord_file.flush()
self._tfrecord_file = tfrecord_file

# Construct database
self._hparams = {
Expand Down Expand Up @@ -120,7 +120,7 @@ def _int64_feature(value):
"data_name": "label"
},
{ # dataset 4
"files": self._tfrecords_file.name,
"files": self._tfrecord_file.name,
"feature_original_types": {
'number1': ['tf.int64', 'FixedLenFeature'],
'number2': ['tf.int64', 'FixedLenFeature'],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Data class that supports reading TFRecords data and data type converting.
Data class that supports reading TFRecord data and data type converting.
"""

from __future__ import absolute_import
Expand Down Expand Up @@ -105,7 +105,7 @@ class TFRecordData(DataBase):
#
# # 'image_raw' is a list of image data bytes in this
# # example.
# 'image_raw': ['...'],
# 'image_raw': [...],
# }
# }
Expand Down Expand Up @@ -211,13 +211,11 @@ def default_hparams():
.. code-block:: python
...
feature_original_types = {
"input_ids": ["tf.int64", "FixedLenFeature", 128],
"label_ids": ["tf.int64", "FixedLenFeature"],
"name_lists": ["tf.string", "VarLenFeature"],
}
...
"feature_convert_types" : dict, optional
Specifies dtype converting after reading the data files. This
Expand All @@ -238,12 +236,10 @@ def default_hparams():
.. code-block:: python
...
feature_convert_types = {
"input_ids": "tf.int32",
"label_ids": "tf.int32",
}
...
"image_options" : dict, optional
Specifies the image feature name and performs image resizing,
Expand Down Expand Up @@ -277,27 +273,21 @@ def default_hparams():
.. code-block:: python
...
dataset: {
...
"num_shards": 2,
"shard_id": 0,
...
"shard_id": 0
}
...
For gpu 1:
.. code-block:: python
...
dataset: {
...
"num_shards": 2,
"shard_id": 1,
...
"shard_id": 1
}
...
Also refer to `examples/bert` for a use case.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,11 @@ def _image_example(image_string, image_shape, label):
cat_in_snow: (213, 320, 3),
williamsburg_bridge: (239, 194),
}
_tfrecords_filepath = os.path.join(
_tfrecord_filepath = os.path.join(
self._test_dir,
'test.tfrecords')
'test.tfrecord')
# Prepare Validation data
with tf.python_io.TFRecordWriter(_tfrecords_filepath) as writer:
with tf.python_io.TFRecordWriter(_tfrecord_filepath) as writer:
for image_path, label in _toy_image_labels_valid.items():

with open(image_path, 'rb') as fid:
Expand Down Expand Up @@ -136,7 +136,7 @@ def _image_example(image_string, image_shape, label):
"batch_size": 1,
"shuffle": False,
"dataset": {
"files": _tfrecords_filepath,
"files": _tfrecord_filepath,
"feature_original_types": _feature_original_types,
"feature_convert_types": self._feature_convert_types,
"image_options": [_image_options],
Expand Down
4 changes: 2 additions & 2 deletions texar/data/data_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ def decode(self, data, items):
items.
Args:
data: The TFRecords data(serialized example) to decode.
data: The TFRecord data(serialized example) to decode.
items: A list of strings, each of which is the name of the resulting
tensors to retrieve.
Expand All @@ -609,7 +609,7 @@ def decode(self, data, items):
dtypes.get_tf_dtype(value[0]))})
decoded_data = tf.parse_single_example(data, feature_description)

# Handle TFRecords containing images
# Handle TFRecord containing images
if isinstance(self._image_options, dict):
self._decode_image_str_byte(
self._image_options,
Expand Down

0 comments on commit 28f9172

Please sign in to comment.