diff --git a/examples/bert/README.md b/examples/bert/README.md index 487d293c..3644c903 100644 --- a/examples/bert/README.md +++ b/examples/bert/README.md @@ -41,12 +41,12 @@ Run the following cmd to this end: python prepare_data.py --task=MRPC [--max_seq_length=128] [--vocab_file=bert_pretrained_models/uncased_L-12_H-768_A-12/vocab.txt] - [--tfrecords_output_dir=data/MRPC] + [--tfrecord_output_dir=data/MRPC] ``` - `task`: Specifies the dataset name to preprocess. BERT provides default support for `{'CoLA', 'MNLI', 'MRPC', 'XNLI', 'SST'}` data. - `max_seq_length`: The maxium length of sequence. This includes BERT special tokens that will be automatically added. Longer sequence will be trimmed. - `vocab_file`: Path to a vocabary file used for tokenization. -- `tfrecords_output_dir`: The output path where the resulting TFRecord files will be put in. Be default, it is set to `data/{task}` where `{task}` is the (upper-cased) dataset name specified in `--task` above. So in the above cmd, the TFRecord files are output to `data/MRPC`. +- `tfrecord_output_dir`: The output path where the resulting TFRecord files will be put in. By default, it is set to `data/{task}` where `{task}` is the (upper-cased) dataset name specified in `--task` above. So in the above cmd, the TFRecord files are output to `data/MRPC`. **Outcome of the Preprocessing**: - The preprocessing will output 3 TFRecord data files `{train.tf_record, eval.tf_record, test.tf_record}` in the specified output directory. diff --git a/examples/bert/prepare_data.py b/examples/bert/prepare_data.py index 757fe49b..e93f94c8 100644 --- a/examples/bert/prepare_data.py +++ b/examples/bert/prepare_data.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Produces TFRecords files and modifies data configuration file +"""Produces TFRecord files and modifies data configuration file """ from __future__ import absolute_import @@ -41,8 +41,10 @@ "max_seq_length", 128, "The maxium length of sequence, longer sequence will be trimmed.") flags.DEFINE_string( - "tfrecords_output_dir", 'data/MRPC', - "The output directory where the TFRecords files will be generated.") + "tfrecord_output_dir", None, + "The output directory where the TFRecord files will be generated. " + "By default it will be set to 'data/{task}'. E.g.: if " + "task is 'MRPC', it will be set as 'data/MRPC'") flags.DEFINE_bool( "do_lower_case", True, "Whether to lower case the input text. Should be True for uncased " @@ -68,11 +70,11 @@ def prepare_data(): data_dir = 'data/{}'.format( task_datasets_rename[FLAGS.task]) - if FLAGS.tfrecords_output_dir is None: - tfrecords_output_dir = data_dir + if FLAGS.tfrecord_output_dir is None: + tfrecord_output_dir = data_dir else: - tfrecords_output_dir = FLAGS.tfrecords_output_dir - tx.utils.maybe_create_dir(tfrecords_output_dir) + tfrecord_output_dir = FLAGS.tfrecord_output_dir + tx.utils.maybe_create_dir(tfrecord_output_dir) processors = { "COLA": data_utils.ColaProcessor, @@ -91,13 +93,13 @@ def prepare_data(): vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case) - # Produces TFRecords files + # Produces TFRecord files data_utils.prepare_TFRecord_data( processor=processor, tokenizer=tokenizer, data_dir=data_dir, max_seq_length=FLAGS.max_seq_length, - output_dir=tfrecords_output_dir) + output_dir=tfrecord_output_dir) modify_config_data(FLAGS.max_seq_length, num_train_data, num_classes) def modify_config_data(max_seq_length, num_train_data, num_classes): diff --git a/examples/bert/utils/data_utils.py b/examples/bert/utils/data_utils.py index 30938714..f72f8c86 100644 --- a/examples/bert/utils/data_utils.py +++ b/examples/bert/utils/data_utils.py @@ -459,7 +459,7 @@ def prepare_TFRecord_data(processor, 
tokenizer, max_seq_length: Max sequence length. batch_size: mini-batch size. model: `train`, `eval` or `test`. - output_dir: The directory to save the TFRecords in. + output_dir: The directory to save the TFRecord files in. """ label_list = processor.get_labels() diff --git a/texar/data/data/__init__.py b/texar/data/data/__init__.py index 153568e4..20f7b8d8 100644 --- a/texar/data/data/__init__.py +++ b/texar/data/data/__init__.py @@ -29,4 +29,4 @@ from texar.data.data.multi_aligned_data import * from texar.data.data.data_iterators import * from texar.data.data.dataset_utils import * -from texar.data.data.tfrecords_data import * +from texar.data.data.tfrecord_data import * diff --git a/texar/data/data/multi_aligned_data.py b/texar/data/data/multi_aligned_data.py index ba9393a8..76239cf2 100644 --- a/texar/data/data/multi_aligned_data.py +++ b/texar/data/data/multi_aligned_data.py @@ -29,10 +29,10 @@ from texar.utils.dtypes import is_str, is_callable from texar.data.data.text_data_base import TextDataBase from texar.data.data.scalar_data import ScalarData -from texar.data.data.tfrecords_data import TFRecordData +from texar.data.data.tfrecord_data import TFRecordData from texar.data.data.mono_text_data import _default_mono_text_dataset_hparams from texar.data.data.scalar_data import _default_scalar_dataset_hparams -from texar.data.data.tfrecords_data import _default_tfrecord_dataset_hparams +from texar.data.data.tfrecord_data import _default_tfrecord_dataset_hparams from texar.data.data.mono_text_data import MonoTextData from texar.data.data_utils import count_file_lines from texar.data.data import dataset_utils as dsutils @@ -132,7 +132,7 @@ class MultiAlignedData(TextDataBase): 'datasets': [ {'files': 'd.txt', 'vocab_file': 'v.d', 'data_name': 'm'}, { - 'files': 'd.tfrecords', + 'files': 'd.tfrecord', 'data_type': 'tf_record', "feature_original_types": { 'image': ['tf.string', 'FixedLenFeature'] diff --git a/texar/data/data/multi_aligned_data_test.py 
b/texar/data/data/multi_aligned_data_test.py index c80dc83e..2347dce7 100644 --- a/texar/data/data/multi_aligned_data_test.py +++ b/texar/data/data/multi_aligned_data_test.py @@ -80,15 +80,15 @@ def _int64_feature(value): feature = { "number1": _int64_feature(128), "number2": _int64_feature(512), - "text": _bytes_feature("This is a sentence for TFRecords 词 词 。") + "text": _bytes_feature("This is a sentence for TFRecord 词 词 。") } data_example = tf.train.Example( features=tf.train.Features(feature=feature)) - tfrecords_file = tempfile.NamedTemporaryFile(suffix=".tfrecords") - with tf.python_io.TFRecordWriter(tfrecords_file.name) as writer: + tfrecord_file = tempfile.NamedTemporaryFile(suffix=".tfrecord") + with tf.python_io.TFRecordWriter(tfrecord_file.name) as writer: writer.write(data_example.SerializeToString()) - tfrecords_file.flush() - self._tfrecords_file = tfrecords_file + tfrecord_file.flush() + self._tfrecord_file = tfrecord_file # Construct database self._hparams = { @@ -120,7 +120,7 @@ def _int64_feature(value): "data_name": "label" }, { # dataset 4 - "files": self._tfrecords_file.name, + "files": self._tfrecord_file.name, "feature_original_types": { 'number1': ['tf.int64', 'FixedLenFeature'], 'number2': ['tf.int64', 'FixedLenFeature'], diff --git a/texar/data/data/tfrecords_data.py b/texar/data/data/tfrecord_data.py similarity index 96% rename from texar/data/data/tfrecords_data.py rename to texar/data/data/tfrecord_data.py index 825a3836..aebccde4 100644 --- a/texar/data/data/tfrecords_data.py +++ b/texar/data/data/tfrecord_data.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -Data class that supports reading TFRecords data and data type converting. +Data class that supports reading TFRecord data and data type converting. 
""" from __future__ import absolute_import @@ -105,7 +105,7 @@ class TFRecordData(DataBase): # # # 'image_raw' is a list of image data bytes in this # # example. - # 'image_raw': ['...'], + # 'image_raw': [...], # } # } @@ -211,13 +211,11 @@ def default_hparams(): .. code-block:: python - ... feature_original_types = { "input_ids": ["tf.int64", "FixedLenFeature", 128], "label_ids": ["tf.int64", "FixedLenFeature"], "name_lists": ["tf.string", "VarLenFeature"], } - ... "feature_convert_types" : dict, optional Specifies dtype converting after reading the data files. This @@ -238,12 +236,10 @@ def default_hparams(): .. code-block:: python - ... feature_convert_types = { "input_ids": "tf.int32", "label_ids": "tf.int32", } - ... "image_options" : dict, optional Specifies the image feature name and performs image resizing, @@ -277,27 +273,21 @@ def default_hparams(): .. code-block:: python - ... dataset: { ... "num_shards": 2, - "shard_id": 0, - ... + "shard_id": 0 } - ... For gpu 1: .. code-block:: python - ... dataset: { ... "num_shards": 2, - "shard_id": 1, - ... + "shard_id": 1 } - ... Also refer to `examples/bert` for a use case. 
diff --git a/texar/data/data/tfrecords_data_test.py b/texar/data/data/tfrecord_data_test.py similarity index 97% rename from texar/data/data/tfrecords_data_test.py rename to texar/data/data/tfrecord_data_test.py index a2129646..18dd6ecc 100644 --- a/texar/data/data/tfrecords_data_test.py +++ b/texar/data/data/tfrecord_data_test.py @@ -104,11 +104,11 @@ def _image_example(image_string, image_shape, label): cat_in_snow: (213, 320, 3), williamsburg_bridge: (239, 194), } - _tfrecords_filepath = os.path.join( + _tfrecord_filepath = os.path.join( self._test_dir, - 'test.tfrecords') + 'test.tfrecord') # Prepare Validation data - with tf.python_io.TFRecordWriter(_tfrecords_filepath) as writer: + with tf.python_io.TFRecordWriter(_tfrecord_filepath) as writer: for image_path, label in _toy_image_labels_valid.items(): with open(image_path, 'rb') as fid: @@ -136,7 +136,7 @@ def _image_example(image_string, image_shape, label): "batch_size": 1, "shuffle": False, "dataset": { - "files": _tfrecords_filepath, + "files": _tfrecord_filepath, "feature_original_types": _feature_original_types, "feature_convert_types": self._feature_convert_types, "image_options": [_image_options], diff --git a/texar/data/data_decoders.py b/texar/data/data_decoders.py index 3bb6f107..6ee931f6 100644 --- a/texar/data/data_decoders.py +++ b/texar/data/data_decoders.py @@ -582,7 +582,7 @@ def decode(self, data, items): items. Args: - data: The TFRecords data(serialized example) to decode. + data: The TFRecord data (serialized example) to decode. items: A list of strings, each of which is the name of the resulting tensors to retrieve. @@ -609,7 +609,7 @@ def decode(self, data, items): dtypes.get_tf_dtype(value[0]))}) decoded_data = tf.parse_single_example(data, feature_description) - # Handle TFRecords containing images + # Handle TFRecord data containing images if isinstance(self._image_options, dict): self._decode_image_str_byte( self._image_options,