def test_fast(L, B, N, train=False):
    """Time the mean forward + backward pass of an SRUCell on the GPU.

    Builds a random CUDA input of shape (L, B, N) and a zero CUDA tensor of
    shape (B, N). The zero tensor is passed as the cell's second argument,
    presumably the initial state (TODO confirm). The function then runs
    ``T`` iterations of ``cell(a, c)``, each followed by a backward pass on
    the sum of the first output. It prints the mean seconds per iteration.

    ``T``, ``reset_seed``, ``SRUCell`` and ``Variable`` are not defined in
    this function; they must be provided at module level.

    Args:
        L, B, N: sizes of the random input. Presumably these are sequence
            length, batch size and hidden size (TODO confirm). ``N`` is also
            used as both the input and hidden size of the cell.
        train: if True the cell is put in training mode, otherwise in
            evaluation mode. The backward pass runs in both cases.
    """
    reset_seed()
    a = Variable(torch.randn(L, B, N).float().cuda())
    c = Variable(torch.zeros(B, N).float().cuda())
    cell = SRUCell(N, N, dropout=0, use_tanh=False).cuda()
    # nn.Module instances start in training mode, so train=False has to
    # switch to evaluation mode explicitly for the flag to have any effect.
    if train:
        cell.train()
    else:
        cell.eval()
    # CUDA kernels launch asynchronously: drain queued work before starting
    # the clock and after the loop, so the interval covers exactly the loop.
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(T):
        out = cell(a, c)
        out[0].sum().backward()
    torch.cuda.synchronize()
    print("test_fast: {:.6f}".format((time.time() - start) / T))
# After change:
def test_fast(L, B, N, train=False):
    """Time the mean SRUCell iteration on the GPU (backward only if training).

    Builds a random CUDA input of shape (L, B, N), scaled by 0.1, and a zero
    CUDA tensor of shape (B, N). The zero tensor is passed as the cell's
    second argument, presumably the initial state (TODO confirm). The
    function then runs ``T`` iterations of ``cell(a, c)``. When ``train`` is
    True, each iteration also zeroes the cell's gradients and back-propagates
    the mean of the first output. It prints the mean seconds per iteration.

    ``T``, ``reset_seed``, ``SRUCell`` and ``Variable`` are not defined in
    this function; they must be provided at module level.

    Args:
        L, B, N: sizes of the random input. Presumably these are sequence
            length, batch size and hidden size (TODO confirm). ``N`` is also
            used as both the input and hidden size of the cell.
        train: if True, use training mode and include the backward pass.
            Otherwise use evaluation mode and time the forward pass only.
    """
    reset_seed()
    # NOTE(review): the 0.1 scale presumably keeps activations small; the
    # reason is not stated in this code.
    a = Variable(torch.randn(L, B, N).float().cuda() * 0.1)
    c = Variable(torch.zeros(B, N).float().cuda())
    cell = SRUCell(N, N, dropout=0, use_tanh=False).cuda()
    # nn.Module instances start in training mode, so train=False has to
    # switch to evaluation mode explicitly for the flag to have any effect.
    if train:
        cell.train()
    else:
        cell.eval()
    # CUDA kernels launch asynchronously: drain queued work before starting
    # the clock and after the loop, so the interval covers exactly the loop.
    torch.cuda.synchronize()
    start = time.time()
    # Each iteration reads one element of the first output into ``tot``.
    # NOTE(review): on older PyTorch releases this full indexing yields a
    # Python number, which forces a device-to-host copy per iteration.
    # Confirm this is meant to be part of the measurement.
    tot = 0
    for _ in range(T):
        out = cell(a, c)
        tot += out[0].data[-1, -1, -1]
        if train:
            # Clear gradients so they don't accumulate across iterations.
            cell.zero_grad()
            out[0].mean().backward()
    torch.cuda.synchronize()
    print("test_fast: {:.6f}".format((time.time() - start) / T))