order[i*batch_size:(i+1)*batch_size]])
yield list(itertools.chain.from_iterable(
_dialog(dialog_indices[o])\
for o in order[i*batch_size:(i+1)*batch_size]))
@staticmethod
def _dialog_indices(data):
After Change
@overrides
def batch_generator(self, batch_size: int, data_type: str = "train", shuffle: bool = True) -> Generator:
if batch_size != 1:
raise RuntimeError("Dialogs currently only support batch size of 1")
dialogs = self._dialogs(self.data[data_type])
num_dialogs = len(dialogs)
order = list(range(num_dialogs))
if shuffle: