device = input.device
x = input.new(batch_size * max_len, *old_shape[1:])
x.fill_(value)
index = []
for i, l in enumerate(lengths):
index.extend(range(i * max_len, i * max_len + l))
index = th.tensor(index).to(device)
return scatter_row(x, index, input).view(batch_size, max_len, *old_shape[1:])
After Change
old_shape = input.shape
device = input.device
if not is_tensor(lengths):
lengths = th.tensor(lengths, dtype=th.int64, device=device)
else:
lengths = lengths.to(device)
max_len = as_scalar(lengths.max())
if l_min is not None:
max_len = builtins.max(max_len, l_min)
batch_size = len(lengths)
x = input.new(batch_size * max_len, *old_shape[1:])
x.fill_(value)
index = th.ones(len(input), dtype=th.int64, device=device)
cum_lengths = th.cumsum(lengths, 0)
index[cum_lengths[:-1]] += (max_len - lengths[:-1])
index = th.cumsum(index, 0) - 1
x[index] = input
return x.view(batch_size, max_len, *old_shape[1:])