
adding a buffered iterator
Summary:
Torch's DataLoader keeps a buffer of only 2 ready batches, which cannot be changed. This causes a data loading bottleneck when data preparation time fluctuates.

Adding BufferedIterator, a generic wrapper for an iterator that implements a buffer using a queue filled by a background thread.

Adding FairseqTask support, and using it in EpochBatchIterator, the default batch iterator (controlled by buffer_size, which defaults to 0).
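
For illustration, a minimal usage sketch of the new wrapper around an arbitrary Python iterator. The slow_source generator below is hypothetical (a stand-in for a source whose preparation time fluctuates), and the import assumes a fairseq checkout that includes this commit:

import time
from fairseq.data.iterators import BufferedIterator

def slow_source(n):
    # Hypothetical source whose per-item preparation time fluctuates.
    for i in range(n):
        time.sleep(0.1 if i % 10 == 0 else 0.01)
        yield i

# A background thread keeps up to 10 ready items queued while the consumer
# is busy, smoothing out fluctuations in preparation time.
itr = BufferedIterator(10, slow_source(100))
for item in itr:
    pass  # the training step would go here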

Reviewed By: myleott

Differential Revision: D21261026

fbshipit-source-id: 23d4bc6181fe1f9a7ee7ad7d18491594725c0f53
Gil Keren authored and facebook-github-bot committed Apr 29, 2020
1 parent dd518ef commit 4115317
Showing 4 changed files with 88 additions and 11 deletions.
93 changes: 82 additions & 11 deletions fairseq/data/iterators.py
@@ -6,12 +6,21 @@
import itertools
import math
import os

import time
import numpy as np
import torch

import queue
import logging
from threading import Thread
from . import data_utils

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Object used by BackgroundConsumer to signal to the main thread
# that the source is exhausted.
_sentinel = object()


class CountingIterator(object):
"""Wrapper around an iterable that maintains the iteration count.
@@ -178,11 +187,14 @@ class EpochBatchIterator(EpochBatchIterating):
(default: 0).
epoch (int, optional): the epoch to start the iterator from
(default: 1).
buffer_size (int, optional): the number of batches to keep ready in the
queue. Helps speed up data loading. When buffer_size is zero, the
default torch.utils.data.DataLoader preloading is used.
"""

def __init__(
self, dataset, collate_fn, batch_sampler, seed=1, num_shards=1, shard_id=0,
num_workers=0, epoch=1,
num_workers=0, epoch=1, buffer_size=0
):
assert isinstance(dataset, torch.utils.data.Dataset)
self.dataset = dataset
@@ -192,6 +204,7 @@ def __init__(
self.num_shards = num_shards
self.shard_id = shard_id
self.num_workers = num_workers
self.buffer_size = buffer_size

self.epoch = max(epoch, 1) # we use 1-based indexing for epochs
self.shuffle = True
@@ -307,16 +320,22 @@ def shuffle_batches(batches, seed):
if self.num_workers > 0:
os.environ['PYTHONWARNINGS'] = 'ignore:semaphore_tracker:UserWarning'

return CountingIterator(
torch.utils.data.DataLoader(
self.dataset,
collate_fn=self.collate_fn,
batch_sampler=batches[offset:],
num_workers=self.num_workers,
),
start=offset,
# Create data loader
itr = torch.utils.data.DataLoader(
self.dataset,
collate_fn=self.collate_fn,
batch_sampler=batches[offset:],
num_workers=self.num_workers,
)

# Wrap with a BufferedIterator if needed
if self.buffer_size > 0:
itr = BufferedIterator(self.buffer_size, itr)

# Wrap with CountingIterator
itr = CountingIterator(itr, start=offset)
return itr


class GroupedIterator(object):
"""Wrapper around an iterable that returns groups (chunks) of items.
@@ -382,3 +401,55 @@ def __iter__(self):

def __next__(self):
return next(self.itr)[1]


class BackgroundConsumer(Thread):
def __init__(self, queue, source):
Thread.__init__(self)

self._queue = queue
self._source = source

def run(self):
for item in self._source:
self._queue.put(item)

# Signal the main thread that the source is exhausted.
self._queue.put(_sentinel)


class BufferedIterator(object):
def __init__(self, size, iterable):
self._queue = queue.Queue(size)
self._iterable = iterable

self._consumer = BackgroundConsumer(self._queue, iterable)
self._consumer.daemon = True
self._consumer.start()

self.start_time = time.time()
self.warning_time = None

def __iter__(self):
return self

def __len__(self):
return len(self._iterable)

def __next__(self):
# Notify the user if there is a data loading bottleneck
if self._queue.qsize() < 2:
if time.time() - self.start_time > 5 * 60:
if self.warning_time is None or time.time() - self.warning_time > 15 * 60:
logger.info(
"Data loading buffer is empty or nearly empty. This may "
"indicate a data loading bottleneck, and increasing the "
"number of workers may help."
)
self.warning_time = time.time()

# Get next example
item = self._queue.get(True)
if item is _sentinel:
raise StopIteration()
return item
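
As a further illustration of the buffer_size path above, a self-contained sketch that mirrors how EpochBatchIterator wraps its DataLoader when buffer_size > 0. The toy TensorDataset is hypothetical, and the imports assume a fairseq checkout that includes this commit:

import torch
from torch.utils.data import DataLoader, TensorDataset
from fairseq.data.iterators import BufferedIterator, CountingIterator

# Hypothetical toy dataset; any torch Dataset works here.
dataset = TensorDataset(torch.arange(100, dtype=torch.float32))
itr = DataLoader(dataset, batch_size=8, num_workers=0)

# Same wrapping order as in EpochBatchIterator above.
itr = BufferedIterator(5, itr)        # keep up to 5 ready batches buffered
itr = CountingIterator(itr, start=0)

for batch in itr:
    pass  # one training step per batch
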
2 changes: 2 additions & 0 deletions fairseq/options.py
@@ -317,6 +317,8 @@ def add_dataset_args(parser, train=False, gen=False):
parser.add_argument('--dataset-impl', metavar='FORMAT',
choices=get_available_dataset_impl(),
help='output dataset implementation')
group.add_argument('--data-buffer-size', default=0, type=int, metavar='N',
help='Number of batches to preload')
if train:
group.add_argument('--train-subset', default='train', metavar='SPLIT',
help='data subset to use for training (e.g. train, valid, test)')
2 changes: 2 additions & 0 deletions fairseq/tasks/fairseq_task.py
@@ -117,6 +117,7 @@ def get_batch_iterator(
shard_id=0,
num_workers=0,
epoch=1,
buffer_size=0
):
"""
Get an iterator that yields batches of data from the given dataset.
@@ -191,6 +192,7 @@
shard_id=shard_id,
num_workers=num_workers,
epoch=epoch,
buffer_size=buffer_size,
)
self.dataset_to_epoch_iter[dataset] = epoch_iter
return epoch_iter
2 changes: 2 additions & 0 deletions fairseq/trainer.py
@@ -294,6 +294,7 @@ def get_train_iterator(
shard_id=self.data_parallel_rank if shard_batch_itr else 0,
num_workers=self.args.num_workers,
epoch=epoch,
buffer_size=self.args.data_buffer_size,
)

def get_valid_iterator(
@@ -315,6 +316,7 @@
num_shards=self.data_parallel_world_size,
shard_id=self.data_parallel_rank,
num_workers=self.args.num_workers,
buffer_size=self.args.data_buffer_size,
)

def begin_epoch(self, epoch):
