t = x.resize(x.size()[0], x.size()[1], 1, n)
mean = torch.mean(t, 3).repeat(1, 1, x.size()[2], x.size()[3])
// Compute the biased variance: torch.var returns the unbiased estimate, so rescale it by (n - 1) / n.
var = torch.var(t, 3).repeat(1, 1, x.size()[2], x.size()[3]) * ((n - 1) / float(n))
res = (x - mean) / torch.sqrt(var + 1e-5)
// TODO: Check if you need to add scaling and shifting here
return res
After Change
mean = torch.mean(t, 3).expand_as(x)
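// NOTE(review): torch.var returns the unbiased variance and no (n - 1) / n rescaling is applied below, so this is not the biased variance; confirm that is intended.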
var = torch.var(t, 3).expand_as(x)
scale_broadcast = self.scale.unsqueeze(1).unsqueeze(1).unsqueeze(0)
scale_broadcast = scale_broadcast.expand_as(x)
shift_broadcast = self.shift.unsqueeze(1).unsqueeze(1).unsqueeze(0)
shift_broadcast = shift_broadcast.expand_as(x)