g = self.gate_act(g)
begin_g = 0 // index of first scalar gate capsule
z = y.new_empty((y.size(0), size_out, y.size(2)))
begin_y = 0 // index of first capsule
for n, mul in enumerate(self.repr_out):
After Change
def forward(self, input, difference_matrix, relative_mask=None):
y = self.conv(input, difference_matrix, relative_mask) // [batch, channel, N]
if self.scalar_act is None and self.gate_act is None:
return y
has_batch = difference_matrix.dim() == 4
if not has_batch:
difference_matrix = difference_matrix.unsqueeze(0)
y = y.unsqueeze(0)
batch, N, _M, _ = difference_matrix.size()
size_out = sum(mul * (2 * l + 1) for l, mul in enumerate(self.repr_out))
if self.gate_act is not None:
g = y[:, size_out:]
g = self.gate_act(g)
begin_g = 0 // index of first scalar gate capsule
z = y.new_empty(batch, size_out, N)
begin_y = 0 // index of first capsule
for l, mul in enumerate(self.repr_out):
if mul == 0: