for path in paths:
    cur_dataset = torch.load(path)
    logger.info("Loading dataset from %s, number of examples: %d" %
                (path, len(cur_dataset)))
    cur_dataset.fields = self.fields
    cur_iter = OrderedIterator(
        dataset=cur_dataset,
After Change
paths = self._paths
if self.is_train and self.repeat:
    # Cycle through the shards indefinitely.
    paths = cycle(paths)
for path in paths:
    for batch in self._iter_dataset(path):
        yield batch
        num_batches += 1
if self.is_train and not self.repeat and \
        num_batches % self.num_batches_multiple != 0:
    # When the dataset is not repeated, we might need to ensure that
    # the number of returned batches is a multiple of a given value.
    # This is important for multi-GPU training to ensure that all
    # workers have the same number of batches to process.
    for path in paths:
        for batch in self._iter_dataset(path):
            yield batch
            num_batches += 1
            if num_batches % self.num_batches_multiple == 0:
                return
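
# --- Illustration only: a minimal, self-contained sketch of the new iteration
# pattern. `ShardedBatchIter` and the list-based "shards" below are hypothetical
# stand-ins for the lazy dataset iterator and its on-disk shards; they exist
# purely to demonstrate the two behaviours described in the comments above:
# `itertools.cycle` for indefinite repetition, and topping the batch count up
# to a multiple of `num_batches_multiple` so multi-GPU workers stay in step.
from itertools import cycle


class ShardedBatchIter:
    def __init__(self, shards, is_train=True, repeat=False,
                 num_batches_multiple=1):
        self.shards = shards
        self.is_train = is_train
        self.repeat = repeat
        self.num_batches_multiple = num_batches_multiple

    def _iter_dataset(self, shard):
        # The real code would load the shard from disk and batch it here.
        return iter(shard)

    def __iter__(self):
        num_batches = 0
        shards = self.shards
        if self.is_train and self.repeat:
            # Cycle through the shards indefinitely.
            shards = cycle(shards)
        for shard in shards:
            for batch in self._iter_dataset(shard):
                yield batch
                num_batches += 1
        if self.is_train and not self.repeat and \
                num_batches % self.num_batches_multiple != 0:
            # Top the count up to the next multiple by re-yielding batches.
            for shard in self.shards:
                for batch in self._iter_dataset(shard):
                    yield batch
                    num_batches += 1
                    if num_batches % self.num_batches_multiple == 0:
                        return


# Three shards holding 2 + 2 + 1 = 5 batches; with num_batches_multiple=4
# the iterator re-yields 3 extra batches so every worker processes 8.
demo = ShardedBatchIter([["a1", "a2"], ["b1", "b2"], ["c1"]],
                        repeat=False, num_batches_multiple=4)
assert len(list(demo)) == 8
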
def max_tok_len(new, count, sofar):
    """
    In token batching scheme, the number of sequences is limited
    such that the total number of src/tgt tokens (including padding)