diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py
index 46073e5f31..613d41fc77 100644
--- a/fairseq/data/iterators.py
+++ b/fairseq/data/iterators.py
@@ -6,12 +6,21 @@
 import itertools
 import math
 import os
-
+import time
 import numpy as np
 import torch
-
+import queue
+import logging
+from threading import Thread
 from . import data_utils
 
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+# Object used by BackgroundConsumer to signal the source is exhausted
+# to the main thread.
+_sentinel = object()
+
 
 class CountingIterator(object):
     """Wrapper around an iterable that maintains the iteration count.
@@ -178,11 +187,14 @@ class EpochBatchIterator(EpochBatchIterating):
             (default: 0).
         epoch (int, optional): the epoch to start the iterator from
             (default: 1).
+        buffer_size (int, optional): the number of batches to keep ready in the
+            queue. Helps speed up data loading. When buffer_size is zero, the
+            default torch.utils.data.DataLoader preloading is used.
     """
 
     def __init__(
         self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0,
-        num_workers=0, epoch=1,
+        num_workers=0, epoch=1, buffer_size=0
     ):
         assert isinstance(dataset, torch.utils.data.Dataset)
         self.dataset = dataset
@@ -192,6 +204,7 @@ def __init__(
         self.num_shards = num_shards
         self.shard_id = shard_id
         self.num_workers = num_workers
+        self.buffer_size = buffer_size
         self.epoch = max(epoch, 1)  # we use 1-based indexing for epochs
         self.shuffle = True
 
@@ -307,16 +320,22 @@ def shuffle_batches(batches, seed):
         if self.num_workers > 0:
             os.environ['PYTHONWARNINGS'] = 'ignore:semaphore_tracker:UserWarning'
 
-        return CountingIterator(
-            torch.utils.data.DataLoader(
-                self.dataset,
-                collate_fn=self.collate_fn,
-                batch_sampler=batches[offset:],
-                num_workers=self.num_workers,
-            ),
-            start=offset,
+        # Create data loader
+        itr = torch.utils.data.DataLoader(
+            self.dataset,
+            collate_fn=self.collate_fn,
+            batch_sampler=batches[offset:],
+            num_workers=self.num_workers,
         )
+        # Wrap with a BufferedIterator if needed
+        if self.buffer_size > 0:
+            itr = BufferedIterator(self.buffer_size, itr)
+
+        # Wrap with CountingIterator
+        itr = CountingIterator(itr, start=offset)
+        return itr
+
 
 
 class GroupedIterator(object):
     """Wrapper around an iterable that returns groups (chunks) of items.
@@ -382,3 +401,55 @@ def __iter__(self):
 
     def __next__(self):
         return next(self.itr)[1]
+
+
+class BackgroundConsumer(Thread):
+    def __init__(self, queue, source):
+        Thread.__init__(self)
+
+        self._queue = queue
+        self._source = source
+
+    def run(self):
+        for item in self._source:
+            self._queue.put(item)
+
+        # Signal the consumer we are done.
+        self._queue.put(_sentinel)
+
+
+class BufferedIterator(object):
+    def __init__(self, size, iterable):
+        self._queue = queue.Queue(size)
+        self._iterable = iterable
+
+        self._consumer = BackgroundConsumer(self._queue, iterable)
+        self._consumer.daemon = True
+        self._consumer.start()
+
+        self.start_time = time.time()
+        self.warning_time = None
+
+    def __iter__(self):
+        return self
+
+    def __len__(self):
+        return len(self._iterable)
+
+    def __next__(self):
+        # Notify the user if there is a data loading bottleneck
+        if self._queue.qsize() < 2:
+            if time.time() - self.start_time > 5 * 60:
+                if self.warning_time is None or time.time() - self.warning_time > 15 * 60:
+                    logger.info(
+                        "Data loading buffer is empty or nearly empty. This may "
+                        "indicate a data loading bottleneck, and increasing the "
+                        "number of workers may help."
+                    )
+                    self.warning_time = time.time()
+
+        # Get next example
+        item = self._queue.get(True)
+        if item is _sentinel:
+            raise StopIteration()
+        return item
diff --git a/fairseq/options.py b/fairseq/options.py
index b54e6bd761..66bec3b339 100644
--- a/fairseq/options.py
+++ b/fairseq/options.py
@@ -317,6 +317,8 @@ def add_dataset_args(parser, train=False, gen=False):
     parser.add_argument('--dataset-impl', metavar='FORMAT',
                         choices=get_available_dataset_impl(),
                         help='output dataset implementation')
+    group.add_argument('--data-buffer-size', default=0, type=int, metavar='N',
+                       help='Number of batches to preload')
     if train:
         group.add_argument('--train-subset', default='train', metavar='SPLIT',
                            help='data subset to use for training (e.g. train, valid, test)')
diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py
index dd27727050..69c2e69d68 100644
--- a/fairseq/tasks/fairseq_task.py
+++ b/fairseq/tasks/fairseq_task.py
@@ -117,6 +117,7 @@ def get_batch_iterator(
         shard_id=0,
         num_workers=0,
         epoch=1,
+        buffer_size=0
     ):
         """
         Get an iterator that yields batches of data from the given dataset.
@@ -191,6 +192,7 @@ def get_batch_iterator(
             shard_id=shard_id,
             num_workers=num_workers,
             epoch=epoch,
+            buffer_size=buffer_size,
         )
         self.dataset_to_epoch_iter[dataset] = epoch_iter
         return epoch_iter
diff --git a/fairseq/trainer.py b/fairseq/trainer.py
index a2f9d4c7db..9fd3af455c 100644
--- a/fairseq/trainer.py
+++ b/fairseq/trainer.py
@@ -294,6 +294,7 @@ def get_train_iterator(
             shard_id=self.data_parallel_rank if shard_batch_itr else 0,
             num_workers=self.args.num_workers,
             epoch=epoch,
+            buffer_size=self.args.data_buffer_size,
         )
 
     def get_valid_iterator(
@@ -315,6 +316,7 @@ def get_valid_iterator(
             num_shards=self.data_parallel_world_size,
             shard_id=self.data_parallel_rank,
             num_workers=self.args.num_workers,
+            buffer_size=self.args.data_buffer_size,
         )
 
     def begin_epoch(self, epoch):
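For reference, the new `BufferedIterator` can be exercised on its own, outside of training. The sketch below is illustrative only and is not part of the patch; `slow_source` is a hypothetical stand-in for a `torch.utils.data.DataLoader` that produces batches slowly, and the example assumes the patch is applied so that `BufferedIterator` is importable from `fairseq.data.iterators`.

```python
# Illustrative sketch (not part of the patch): a background thread fills a
# bounded queue from a slow source while the main thread consumes from it.
import time

from fairseq.data.iterators import BufferedIterator


def slow_source(n):
    # Hypothetical stand-in for a DataLoader: yield n dummy "batches",
    # sleeping briefly to mimic slow data loading.
    for i in range(n):
        time.sleep(0.01)
        yield i


# Keep up to 4 items ready in the queue ahead of the consumer.
itr = BufferedIterator(4, slow_source(16))
for batch in itr:
    print(batch)
```

During training the same wrapping is applied automatically when `--data-buffer-size N` is set (the default of 0 keeps the stock DataLoader behaviour): the producer runs as a daemon thread, a sentinel object marks the end of the source, and a log message is emitted if the buffer stays nearly empty for an extended period, which usually indicates that more data loading workers are needed.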